File size: 4,862 Bytes
792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 792ad00 b6e32c9 77f281e c8065f3 77f281e b6e32c9 792ad00 b6e32c9 792ad00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | import logging
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List
from datetime import datetime
from api.auth import get_current_user
from models import db_models
from models.schemas import MindMapGenerateRequest, MindMapResponse
from core.database import get_db, SessionLocal
from api.websocket_routes import manager
from services.mindmap_service import mindmap_service
# Module-level logger shared by the endpoints and the background task below.
logger = logging.getLogger(__name__)
# Every mind-map endpoint is mounted under /api/mindmaps and grouped under the "mindmaps" tag.
router = APIRouter(prefix="/api/mindmaps", tags=["mindmaps"])
async def run_mindmap_generation(mindmap_id: int, request: MindMapGenerateRequest, user_id: int):
    """Background task for mind map generation.

    Opens its own ``SessionLocal()`` session (always closed on exit), asks
    ``mindmap_service`` for Mermaid code and stores it on the ``MindMap`` row
    with status ``"completed"``, then pushes a result message to the user's
    WebSocket connection. If generation or persisting fails, the row is
    marked ``"failed"`` with the error text and an error message is pushed to
    the same connection instead.

    Args:
        mindmap_id: Primary key of the ``MindMap`` row created by the endpoint.
        request: The original generation request; ``file_key`` and
            ``text_input`` are forwarded to the AI service.
        user_id: Owner of the mind map; used to build the connection id
            ``"user_<id>"``.
    """
    db = SessionLocal()
    connection_id = f"user_{user_id}"
    try:
        db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
        if not db_mindmap:
            # Row no longer exists (e.g. deleted before the task ran): nothing to update or report.
            return
        # Call AI service
        mermaid_code = await mindmap_service.generate_mindmap(
            file_key=request.file_key,
            text_input=request.text_input,
        )
        if not mermaid_code:
            raise RuntimeError("AI failed to generate mind map code")
        db_mindmap.mermaid_code = mermaid_code
        db_mindmap.status = "completed"
        db.commit()
    except Exception as e:
        logger.exception("Background mindmap generation failed: %s", e)
        # If the failure came from a flush/commit, the session refuses further
        # work until rolled back; without this the status update below raises too.
        db.rollback()
        try:
            db_mindmap = db.query(db_models.MindMap).filter(db_models.MindMap.id == mindmap_id).first()
            if db_mindmap:
                db_mindmap.status = "failed"
                db_mindmap.error_message = str(e)
                db.commit()
        except Exception:
            db.rollback()
            logger.exception("Could not mark mindmap %s as failed", mindmap_id)
        await manager.send_error(connection_id, f"Mind map generation failed: {str(e)}")
    else:
        # Notify via WebSocket. Kept out of the generation try-block so that a
        # notification failure cannot overwrite an already-committed
        # "completed" status with "failed".
        try:
            await manager.send_result(connection_id, {
                "type": "mindmap",
                "id": db_mindmap.id,
                "status": "completed",
                "title": db_mindmap.title
            })
        except Exception:
            logger.exception("Failed to notify %s about mindmap %s", connection_id, mindmap_id)
    finally:
        db.close()
@router.post("/generate", response_model=MindMapResponse)
async def generate_mindmap(
    request: MindMapGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Initiates mind map generation in the background.

    If ``request.file_key`` is given, it must match a ``Source`` owned by the
    current user. A ``MindMap`` row is created with status ``"processing"``,
    ``run_mindmap_generation`` is scheduled with its id, and the new record is
    returned immediately with ``message`` set.

    Raises:
        HTTPException: 403 if ``file_key`` does not belong to the current user.
    """
    from datetime import timezone  # only needed for the timestamp fallback title

    source_id = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id

    # Display name from the key: last path segment with its final extension removed.
    file_base = request.file_key.split('/')[-1].rsplit('.', 1)[0] if request.file_key else None
    logger.debug("file_base: %s", file_base)

    # Priority: 1. File-based name, 2. User title (ignoring the literal 'string',
    # presumably the API docs' placeholder value), 3. UTC timestamp.
    if file_base:
        title = f"Mind Map-{file_base}"
    elif request.title and request.title != "string":
        title = request.title
    else:
        title = f"Mind Map {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"

    # Create initial processing record
    db_mindmap = db_models.MindMap(
        title=title,
        user_id=current_user.id,
        source_id=source_id,
        status="processing"
    )
    db.add(db_mindmap)
    db.commit()
    db.refresh(db_mindmap)

    # Offload to a background task; only the id is passed and the task reloads the row itself.
    background_tasks.add_task(run_mindmap_generation, db_mindmap.id, request, current_user.id)

    # Return the "processing" record; completion is reported later over the WebSocket.
    resp = MindMapResponse.model_validate(db_mindmap)
    resp.message = "Mind map generation started"
    return resp
@router.get("/list", response_model=List[MindMapResponse])
async def list_mindmaps(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Lists all mind maps for the current user.

    Returns the user's ``MindMap`` rows ordered by ``created_at`` descending
    (newest first), each validated into a ``MindMapResponse``.

    Raises:
        HTTPException: 500 if querying or validation fails. The underlying
            error is logged server-side instead of being echoed to the client,
            so database/driver details are not exposed.
    """
    try:
        mindmaps = (
            db.query(db_models.MindMap)
            .filter(db_models.MindMap.user_id == current_user.id)
            .order_by(db_models.MindMap.created_at.desc())
            .all()
        )
        return [MindMapResponse.model_validate(m) for m in mindmaps]
    except Exception as e:
        logger.exception("Failed to list mind maps for user %s", current_user.id)
        raise HTTPException(status_code=500, detail="Failed to list mind maps") from e
@router.delete("/{mindmap_id}")
async def delete_mindmap(
    mindmap_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Delete a single mind map owned by the current user.

    The lookup is scoped to ``current_user.id``, so another user's mind map is
    reported as not found rather than deleted.

    Raises:
        HTTPException: 404 if no matching mind map exists for this user.
    """
    ownership_filter = (
        db_models.MindMap.id == mindmap_id,
        db_models.MindMap.user_id == current_user.id,
    )
    target = db.query(db_models.MindMap).filter(*ownership_filter).first()

    if not target:
        raise HTTPException(status_code=404, detail="Mind map not found")

    db.delete(target)
    db.commit()
    return {"message": "Mind map deleted successfully"}
|