File size: 4,862 Bytes
792ad00
b6e32c9
792ad00
 
b6e32c9
792ad00
 
 
 
b6e32c9
 
792ad00
 
 
 
 
b6e32c9
 
 
 
792ad00
b6e32c9
 
792ad00
b6e32c9
792ad00
 
 
 
 
 
b6e32c9
792ad00
b6e32c9
 
792ad00
 
b6e32c9
 
 
 
 
 
 
792ad00
 
b6e32c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f281e
c8065f3
77f281e
 
 
 
 
 
 
 
b6e32c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792ad00
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e32c9
792ad00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import logging
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List
from datetime import datetime

from api.auth import get_current_user
from models import db_models
from models.schemas import MindMapGenerateRequest, MindMapResponse
from core.database import get_db, SessionLocal
from api.websocket_routes import manager
from services.mindmap_service import mindmap_service

# Router for the mind map endpoints defined below; all routes are mounted
# under the /api/mindmaps prefix and grouped under the "mindmaps" tag.
router = APIRouter(
    prefix="/api/mindmaps",
    tags=["mindmaps"],
)
logger = logging.getLogger(__name__)

async def _record_generation_failure(db: Session, mindmap_id: int, connection_id: str, error: Exception) -> None:
    """Best-effort: mark the mind map as failed and notify the user.

    Any ``Exception`` raised while recording the failure or sending the
    notification is logged and swallowed, so it cannot mask the original error.
    """
    try:
        # If the original failure came from flush/commit, the session rejects
        # further queries until it is rolled back; a rollback with no active
        # transaction is harmless.
        db.rollback()
        db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
        if db_mindmap:
            db_mindmap.status = "failed"
            db_mindmap.error_message = str(error)
            db.commit()
    except Exception:
        logger.exception("Could not mark mindmap %s as failed", mindmap_id)

    try:
        await manager.send_error(connection_id, f"Mind map generation failed: {str(error)}")
    except Exception:
        logger.exception("Could not send mindmap failure notification to %s", connection_id)


async def run_mindmap_generation(mindmap_id: int, request: MindMapGenerateRequest, user_id: int):
    """Background task for mind map generation.

    Opens its own session via ``SessionLocal`` and closes it when done.

    Args:
        mindmap_id: Primary key of the ``MindMap`` row to fill in.
        request: Original generation request; ``file_key`` and ``text_input``
            are forwarded to ``mindmap_service.generate_mindmap``.
        user_id: Owner of the mind map; notifications go to the WebSocket
            connection id ``user_{user_id}``.

    Success path: stores ``mermaid_code``, sets status "completed", commits,
    then sends a "mindmap" result message.

    Failure path: if loading, generation or persistence fails, the error is
    logged, the row (if it still exists) is marked "failed" with the error
    text, and an error message is sent.

    A failure to deliver the success notification is only logged. It does
    not change the already-committed "completed" status.
    """
    db = SessionLocal()
    connection_id = f"user_{user_id}"
    try:
        try:
            db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
            if not db_mindmap:
                # Row no longer exists (presumably deleted before the task ran).
                return

            # Call AI service
            mermaid_code = await mindmap_service.generate_mindmap(
                file_key=request.file_key,
                text_input=request.text_input
            )

            if not mermaid_code:
                raise Exception("AI failed to generate mind map code")

            db_mindmap.mermaid_code = mermaid_code
            db_mindmap.status = "completed"
            db.commit()

            payload = {
                "type": "mindmap",
                "id": db_mindmap.id,
                "status": "completed",
                "title": db_mindmap.title
            }
        except Exception as e:
            logger.exception("Background mindmap generation failed: %s", e)
            await _record_generation_failure(db, mindmap_id, connection_id, e)
            return

        # Notify via WebSocket, outside the failure handler: the row is already
        # committed as "completed", so a delivery error must not flip it to "failed".
        try:
            await manager.send_result(connection_id, payload)
        except Exception:
            logger.exception("Could not send mindmap result to %s", connection_id)
    finally:
        db.close()

def _derive_mindmap_title(file_key, requested_title):
    """Pick the display title for a new mind map.

    Priority: 1. file-based name, 2. user-supplied title (skipped when it is
    the literal "string", presumably the API docs' example placeholder),
    3. a UTC timestamp.
    """
    from datetime import timezone

    # Last path segment of the storage key with its final extension removed.
    file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
    logger.debug("file_base: %s", file_base)

    if file_base:
        return f"Mind Map-{file_base}"
    if requested_title and requested_title != "string":
        return requested_title
    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # the formatted string is identical.
    return f"Mind Map {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"


@router.post("/generate", response_model=MindMapResponse)
async def generate_mindmap(
    request: MindMapGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Initiates mind map generation in the background.

    If ``request.file_key`` is given, it must match the ``s3_key`` of a
    ``Source`` owned by the current user; that source is linked to the new
    mind map. A ``MindMap`` row is created with status "processing", and
    ``run_mindmap_generation`` is scheduled as a background task.

    Returns:
        The new mind map in its "processing" state, with ``message`` set to
        "Mind map generation started".

    Raises:
        HTTPException: 403 if ``file_key`` does not match a source owned by
            the current user.
    """
    source_id = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id

    # Create initial processing record
    db_mindmap = db_models.MindMap(
        title=_derive_mindmap_title(request.file_key, request.title),
        user_id=current_user.id,
        source_id=source_id,
        status="processing"
    )
    db.add(db_mindmap)
    db.commit()
    db.refresh(db_mindmap)

    # Offload to background task
    background_tasks.add_task(run_mindmap_generation, db_mindmap.id, request, current_user.id)

    # return processing state
    resp = MindMapResponse.model_validate(db_mindmap)
    resp.message = "Mind map generation started"
    return resp

@router.get("/list", response_model=List[MindMapResponse])
async def list_mindmaps(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Lists all mind maps for the current user.

    Results are ordered by ``created_at`` descending (newest first).

    Raises:
        HTTPException: 500 if querying or serializing the mind maps fails.
            The underlying error is logged server-side and is not echoed to
            the client, so internal/database details are not exposed.
    """
    # Read the id up front so the error handler does not touch ORM state.
    user_id = current_user.id
    try:
        mindmaps = db.query(db_models.MindMap).filter(
            db_models.MindMap.user_id == user_id
        ).order_by(db_models.MindMap.created_at.desc()).all()

        return [MindMapResponse.model_validate(m) for m in mindmaps]
    except Exception as e:
        logger.exception("Failed to list mind maps for user %s", user_id)
        raise HTTPException(status_code=500, detail="Failed to list mind maps") from e

@router.delete("/{mindmap_id}")
async def delete_mindmap(
    mindmap_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Deletes a specific mind map owned by the current user.

    Ownership is part of the lookup itself, so a mind map that belongs to
    another user is reported exactly like a missing one.

    Raises:
        HTTPException: 404 if no mind map with this id belongs to the user.
    """
    owned_mindmap = (
        db.query(db_models.MindMap)
        .filter_by(id=mindmap_id, user_id=current_user.id)
        .first()
    )
    if not owned_mindmap:
        raise HTTPException(status_code=404, detail="Mind map not found")

    db.delete(owned_mindmap)
    db.commit()
    return {"message": "Mind map deleted successfully"}