import logging
from datetime import datetime, timezone
from typing import List

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.orm import Session

from api.auth import get_current_user
from models import db_models
from models.schemas import MindMapGenerateRequest, MindMapResponse
from core.database import get_db, SessionLocal
from api.websocket_routes import manager
from services.mindmap_service import mindmap_service


logger = logging.getLogger(__name__)

# Every endpoint defined in this module is mounted under /api/mindmaps.
router = APIRouter(tags=["mindmaps"], prefix="/api/mindmaps")


async def _record_mindmap_failure(db, mindmap_id: int, connection_id: str, error: Exception) -> None:
    """Best-effort: mark the MindMap row as failed and push an error to the user.

    Problems while recording or notifying are logged rather than raised, so
    they cannot mask the original generation error.
    """
    try:
        # SQLAlchemy requires a rollback after a failed flush/commit before the
        # session can run new queries; it also discards any half-applied
        # "completed" state from the failed attempt.
        db.rollback()
        db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
        if db_mindmap:
            db_mindmap.status = "failed"
            db_mindmap.error_message = str(error)
            db.commit()
    except Exception:
        logger.exception("Could not persist failure status for mindmap %s", mindmap_id)

    try:
        await manager.send_error(connection_id, f"Mind map generation failed: {str(error)}")
    except Exception:
        logger.exception("Could not notify %s about failed mindmap %s", connection_id, mindmap_id)


async def run_mindmap_generation(mindmap_id: int, request: MindMapGenerateRequest, user_id: int):
    """Background task for mind map generation.

    Opens its own ``SessionLocal`` session (the request-scoped one is gone by
    the time this runs), asks ``mindmap_service`` for Mermaid code, and
    persists the outcome on the ``MindMap`` row.

    Args:
        mindmap_id: Primary key of the ``MindMap`` row to fill in.
        request: Original generation request; ``file_key`` and ``text_input``
            are forwarded to ``mindmap_service.generate_mindmap``.
        user_id: Owner of the mind map; the websocket connection notified is
            ``f"user_{user_id}"``.

    On success the row gets ``mermaid_code`` and status ``"completed"``, then
    a ``"mindmap"`` result is sent over the websocket. On generation or
    persistence failure the row gets status ``"failed"`` plus
    ``error_message``, and an error is sent instead. If the row no longer
    exists, nothing happens.
    """
    db = SessionLocal()
    connection_id = f"user_{user_id}"
    try:
        try:
            db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
            if not db_mindmap:
                return
            # Read before commit: if SessionLocal uses the default
            # expire_on_commit=True, touching attributes after commit would
            # issue an extra reload query.
            title = db_mindmap.title

            mermaid_code = await mindmap_service.generate_mindmap(
                file_key=request.file_key,
                text_input=request.text_input,
            )
            if not mermaid_code:
                raise RuntimeError("AI failed to generate mind map code")

            db_mindmap.mermaid_code = mermaid_code
            db_mindmap.status = "completed"
            db.commit()
        except Exception as e:
            logger.exception("Background mindmap generation failed: %s", e)
            await _record_mindmap_failure(db, mindmap_id, connection_id, e)
            return

        # Notify only after the result is committed, and outside the failure
        # handler: a websocket error (e.g. the client disconnected) must not
        # flip a successfully generated mind map to "failed".
        try:
            await manager.send_result(connection_id, {
                "type": "mindmap",
                "id": mindmap_id,
                "status": "completed",
                "title": title,
            })
        except Exception:
            logger.exception("Could not notify %s about completed mindmap %s", connection_id, mindmap_id)
    finally:
        db.close()


def _derive_mindmap_title(file_key, requested_title) -> str:
    """Pick a display title for a new mind map.

    Preference order: the uploaded file's base name (directory and last
    extension stripped), then an explicit ``requested_title``, then a UTC
    timestamp. ``file_key`` and ``requested_title`` may be None or empty.
    """
    file_base = file_key.split('/')[-1].rsplit('.', 1)[0] if file_key else None
    logger.debug("file_base: %s", file_base)

    if file_base:
        return f"Mind Map-{file_base}"
    # NOTE(review): "string" looks like the placeholder value API docs UIs
    # pre-fill for str fields; it is treated as "no title given".
    if requested_title and requested_title != "string":
        return requested_title
    return f"Mind Map {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"


@router.post("/generate", response_model=MindMapResponse)
async def generate_mindmap(
    request: MindMapGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Initiates mind map generation in the background.

    Creates a ``MindMap`` row with status ``"processing"``, schedules
    ``run_mindmap_generation`` for it, and returns the new row with
    ``message`` set.

    Raises:
        HTTPException: 403 if ``request.file_key`` is given but no ``Source``
            with that S3 key belongs to the current user.
    """
    source_id = None
    if request.file_key:
        # Ownership check: the file must be one of the caller's own sources.
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id,
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id

    title = _derive_mindmap_title(request.file_key, request.title)

    db_mindmap = db_models.MindMap(
        title=title,
        user_id=current_user.id,
        source_id=source_id,
        status="processing",
    )
    db.add(db_mindmap)
    db.commit()
    db.refresh(db_mindmap)

    # The task runs after the response is sent and uses its own DB session.
    background_tasks.add_task(run_mindmap_generation, db_mindmap.id, request, current_user.id)

    resp = MindMapResponse.model_validate(db_mindmap)
    resp.message = "Mind map generation started"
    return resp


@router.get("/list", response_model=List[MindMapResponse])
async def list_mindmaps(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Lists all mind maps for the current user, newest first (by ``created_at``).

    Raises:
        HTTPException: 500 if querying or serializing fails. The underlying
            error is logged server-side instead of being echoed to the
            client, since it may contain SQL or driver internals.
    """
    user_id = current_user.id
    try:
        mindmaps = (
            db.query(db_models.MindMap)
            .filter(db_models.MindMap.user_id == user_id)
            .order_by(db_models.MindMap.created_at.desc())
            .all()
        )
        return [MindMapResponse.model_validate(m) for m in mindmaps]
    except Exception as e:
        logger.exception("Failed to list mind maps for user %s", user_id)
        raise HTTPException(status_code=500, detail="Failed to list mind maps") from e


@router.delete("/{mindmap_id}")
async def delete_mindmap(
    mindmap_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Deletes one of the current user's mind maps.

    The lookup is restricted to rows owned by ``current_user``, so a mind map
    that exists but belongs to another user is reported exactly like a
    missing one.

    Raises:
        HTTPException: 404 if no mind map with ``mindmap_id`` is owned by the user.
    """
    owned_match = db.query(db_models.MindMap).filter(
        db_models.MindMap.id == mindmap_id,
        db_models.MindMap.user_id == current_user.id,
    )
    target = owned_match.first()

    if not target:
        raise HTTPException(status_code=404, detail="Mind map not found")

    db.delete(target)
    db.commit()
    return {"message": "Mind map deleted successfully"}