matsuap's picture
Upload folder using huggingface_hub
ce4fba9 verified
import logging
from datetime import datetime, timezone
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
from sqlalchemy.orm import Session

from api.auth import get_current_user
from models import db_models
from models.schemas import ReportGenerateRequest, ReportResponse, ReportFormatSuggestionResponse
from core.database import get_db, SessionLocal
from api.websocket_routes import manager
from services.report_service import report_service
from core import constants
# All endpoints in this module are mounted under /api/reports and grouped
# under the "reports" tag.
router = APIRouter(tags=["reports"], prefix="/api/reports")
logger = logging.getLogger(name=__name__)
def _extract_report_title(content: str, format_key) -> str:
    """Derive a display title from generated report *content*.

    Uses the first non-blank line with its leading Markdown heading markers
    (``#``) removed. Falls back to ``"Report {format_key}"`` when fewer than
    3 characters remain.
    """
    first_line = next((line for line in content.splitlines() if line.strip()), "")
    # Strip only the leading heading markers so titles such as "C# Basics"
    # or "Issue #42" keep their inner '#' characters.
    title = first_line.strip().lstrip("#").strip()
    if len(title) < 3:
        return f"Report {format_key}"
    # NOTE(review): no length cap is applied; confirm the Report.title column
    # can hold an arbitrarily long first line.
    return title


def _mark_report_failed(db, report_id: int, error: Exception) -> None:
    """Best-effort: set report *report_id* to status "failed" with *error*'s text.

    Any exception raised while doing so is logged rather than propagated, so
    the caller can still notify the user.
    """
    try:
        # The session may be unusable after the original error (e.g. it came
        # from flush/commit); SQLAlchemy rejects new queries until rollback.
        db.rollback()
        db_report = db.query(db_models.Report).filter(db_models.Report.id == report_id).first()
        if db_report:
            db_report.status = "failed"
            db_report.error_message = str(error)
            db.commit()
    except Exception:
        logger.exception("Could not mark report %s as failed", report_id)


async def run_report_generation(report_id: int, request: ReportGenerateRequest, user_id: int):
    """Background task for report generation.

    Loads report ``report_id`` and asks ``report_service`` for content using
    the fields of ``request``. It then stores the content with status
    ``"completed"`` and sends a ``{"type": "report", ...}`` result to the
    WebSocket connection ``user_{user_id}``.

    If generation or storage fails, the report is marked ``"failed"`` with the
    error message and an error is sent over the WebSocket. A failure while
    *notifying* is only logged; it never changes the stored status. If the
    report no longer exists, the task returns without doing anything.
    """
    connection_id = f"user_{user_id}"
    db = SessionLocal()
    try:
        try:
            db_report = db.query(db_models.Report).filter(db_models.Report.id == report_id).first()
            if not db_report:
                return
            content = await report_service.generate_report(
                file_key=request.file_key,
                text_input=request.text_input,
                format_key=request.format_key,
                custom_prompt=request.custom_prompt,
                language=request.language,
            )
            if not content:
                raise Exception("AI failed to generate report content")
            # Titles created from a file ("Report-<name>") are kept; any other
            # title is replaced by one derived from the generated content.
            if not db_report.title or "Report-" not in db_report.title:
                db_report.title = _extract_report_title(content, request.format_key)
            db_report.content = content
            db_report.status = "completed"
            db.commit()
            result = {
                "type": "report",
                "id": db_report.id,
                "status": "completed",
                "title": db_report.title,
            }
        except Exception as e:
            logger.exception("Background report generation failed: %s", e)
            _mark_report_failed(db, report_id, e)
            try:
                await manager.send_error(connection_id, f"Report generation failed: {str(e)}")
            except Exception:
                logger.exception("Could not send failure notification for report %s", report_id)
            return

        # The report is already committed as "completed"; a notification
        # failure here must not be reported as a generation failure.
        try:
            await manager.send_result(connection_id, result)
        except Exception:
            logger.exception("Report %s completed but WebSocket notification failed", report_id)
    finally:
        db.close()
@router.get("/config")
async def get_report_config():
    """Describe the options a client can pick from when requesting a report.

    The response has two keys: ``formats`` (``constants.REPORT_FORMAT_OPTIONS``)
    and ``languages`` (``constants.LANGUAGES``).
    """
    available_formats = constants.REPORT_FORMAT_OPTIONS
    available_languages = constants.LANGUAGES
    return dict(formats=available_formats, languages=available_languages)
@router.get("/suggest-formats", response_model=ReportFormatSuggestionResponse)
async def suggest_formats(
    file_key: Optional[str] = None,
    text_input: Optional[str] = None,
    language: str = "Japanese",
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Get 4 AI-suggested report formats based on content.

    The content is the stored file ``file_key`` and/or ``text_input``. Both
    are forwarded, together with ``language``, to
    ``report_service.generate_format_suggestions``.

    Raises:
        HTTPException(403): if ``file_key`` is given but no ``Source`` with
            that key belongs to the current user.
    """
    if file_key:
        # Same ownership check as /generate. Without it, any authenticated
        # user could have the AI analyse another user's file.
        owned_source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == file_key,
            db_models.Source.user_id == current_user.id,
        ).first()
        if not owned_source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")

    suggestions = await report_service.generate_format_suggestions(
        file_key=file_key,
        text_input=text_input,
        language=language,
    )
    return {"suggestions": suggestions}
@router.post("/generate", response_model=ReportResponse)
async def generate_report(
    request: ReportGenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Initiates report generation in the background.

    Creates a ``Report`` row with status ``"processing"`` for the current
    user. It then schedules ``run_report_generation`` for it and immediately
    returns the new record.

    Raises:
        HTTPException(403): if ``request.file_key`` is given but no ``Source``
            with that key belongs to the current user.
    """
    source_id = None
    file_base = None
    if request.file_key:
        source = db.query(db_models.Source).filter(
            db_models.Source.s3_key == request.file_key,
            db_models.Source.user_id == current_user.id,
        ).first()
        if not source:
            raise HTTPException(status_code=403, detail="Not authorized to access this file")
        source_id = source.id
        # File name without directory and last extension: "u/1/notes.pdf" -> "notes".
        file_base = request.file_key.split('/')[-1].rsplit('.', 1)[0]

    # "Report-<file>" titles are kept by the background task; any other
    # title is later replaced by one taken from the generated content.
    if file_base:
        title = f"Report-{file_base}"
    else:
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted UTC wall-clock string is unchanged.
        created_at = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')
        title = f"Report {request.format_key} {created_at}"

    # Create the initial processing record.
    db_report = db_models.Report(
        title=title,
        format_key=request.format_key,
        user_id=current_user.id,
        source_id=source_id,
        status="processing",
    )
    db.add(db_report)
    db.commit()
    db.refresh(db_report)

    # Hand the slow AI call off to a background task.
    background_tasks.add_task(run_report_generation, db_report.id, request, current_user.id)
    return db_report
@router.get("/list", response_model=List[ReportResponse])
async def list_reports(
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Lists all reports for the current user, newest first (by ``created_at``).

    Raises:
        HTTPException(500): if loading or serializing the reports fails. The
            underlying error is logged server-side and not sent to the client.
    """
    try:
        reports = db.query(db_models.Report).filter(
            db_models.Report.user_id == current_user.id
        ).order_by(db_models.Report.created_at.desc()).all()
        return [ReportResponse.model_validate(r) for r in reports]
    except Exception as e:
        # str(e) can expose SQL / driver internals, so keep it in the log only.
        logger.exception("Failed to list reports for user %s", current_user.id)
        raise HTTPException(status_code=500, detail="Failed to list reports") from e
@router.get("/{report_id}", response_model=ReportResponse)
async def get_report(
    report_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Fetch one report owned by the requesting user.

    Raises:
        HTTPException(404): when no report with ``report_id`` belongs to the
            current user. Another user's report looks the same as a missing one.
    """
    report_model = db_models.Report
    owned_report = (
        db.query(report_model)
        .filter(report_model.id == report_id, report_model.user_id == current_user.id)
        .first()
    )
    if not owned_report:
        raise HTTPException(status_code=404, detail="Report not found")
    return ReportResponse.model_validate(owned_report)
@router.delete("/{report_id}")
async def delete_report(
    report_id: int,
    current_user: db_models.User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Permanently remove one report owned by the requesting user.

    Returns a confirmation message on success.

    Raises:
        HTTPException(404): when no report with ``report_id`` belongs to the
            current user.
    """
    report_model = db_models.Report
    target = (
        db.query(report_model)
        .filter(report_model.id == report_id, report_model.user_id == current_user.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Report not found")

    db.delete(target)
    db.commit()
    return dict(message="Report deleted successfully")