from contextlib import asynccontextmanager
from typing import List, Optional
import re
import json

from fastapi import FastAPI, Depends, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import StreamingResponse
from pathlib import Path
from sqlmodel import Session, select
from sqlalchemy import or_

from app.database import create_db_and_tables, ensure_appsettings_schema, ensure_paper_schema, get_session
from app.models import Paper, PaperRead, AppSettings
from app.services.arxiv_bot import get_arxiv_bot, run_daily_fetch
from app.services.dify_client import (
    get_dify_client,
    DifyClientError,
    DifyEntityTooLargeError,
    DifyTimeoutError,
    DifyRateLimitError,
)
from app.constants import ARXIV_OPTIONS
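
# Startup hook: create tables and apply the lightweight schema checks
# (see the ensure_*_schema helpers in app.database) before serving requests.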
@asynccontextmanager
async def lifespan(app: FastAPI):
    create_db_and_tables()
    ensure_appsettings_schema()
    ensure_paper_schema()
    yield


app = FastAPI(
    title="Paper Insight API",
    description="API for fetching and summarizing arXiv papers focused on Autoregressive DiT and KV Cache Compression",
    version="0.1.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
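# CORS is left wide open (any origin, method, and header) so the API can be
# called from any frontend host.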
# Mount static files
static_path = Path(__file__).parent / "static"
static_path.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_path), name="static")


# NOTE: the route paths and response_model arguments in the decorators below are
# assumptions reconstructed from the function names and the imported models
# (List, PaperRead); they are not confirmed by the original source.
@app.get("/health")
def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}


@app.get("/constants")
def get_constants():
    """Get application constants."""
    return {"arxiv_options": ARXIV_OPTIONS}


@app.get("/settings")
def get_settings(session: Session = Depends(get_session)):
    """Get application settings."""
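    # Settings are stored as a single row with id=1; create it on first access.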
    settings = session.get(AppSettings, 1)
    if not settings:
        settings = AppSettings(id=1)
        session.add(settings)
        session.commit()
        session.refresh(settings)
    return settings


@app.put("/settings")
def update_settings(new_settings: AppSettings, session: Session = Depends(get_session)):
    """Update application settings."""
    settings = session.get(AppSettings, 1)
    if not settings:
        settings = AppSettings(id=1)
        session.add(settings)
    settings.research_focus = new_settings.research_focus
    settings.system_prompt = new_settings.system_prompt
    settings.arxiv_categories = new_settings.arxiv_categories
    # Parse focus keywords
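    # Keywords are either split on ";" or, for an arXiv-style query, split on
    # OR/AND with field prefixes, parentheses, and quotes stripped.
    # Illustrative example (not from the original source):
    #   'all:"KV cache" AND (abs:compression OR ti:DiT)'
    #   -> ["KV cache", "compression", "DiT"]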
    if new_settings.research_focus:
        raw_focus = new_settings.research_focus.strip()
        if ";" in raw_focus:
            keywords = [
                k.strip() for k in re.split(r"[;]+", raw_focus)
                if k.strip()
            ]
        else:
            parts = re.split(r"\bOR\b|\bAND\b", raw_focus, flags=re.IGNORECASE)
            keywords = []
            for part in parts:
                cleaned = part.strip()
                if not cleaned:
                    continue
                cleaned = re.sub(r"^[()]+|[()]+$", "", cleaned).strip()
                cleaned = re.sub(r"^(?:all|abs|ti):", "", cleaned, flags=re.IGNORECASE).strip()
                cleaned = cleaned.strip('"').strip()
                if cleaned:
                    keywords.append(cleaned)
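        # Deduplicate keywords while preserving their original order.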
        seen = set()
        deduped = []
        for keyword in keywords:
            if keyword not in seen:
                deduped.append(keyword)
                seen.add(keyword)
        keywords = deduped
        settings.focus_keywords = keywords
    else:
        settings.focus_keywords = []
    session.add(settings)
    session.commit()
    session.refresh(settings)
    return settings

@app.get("/papers", response_model=List[PaperRead])
def get_papers(
    session: Session = Depends(get_session),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    min_score: Optional[float] = Query(None, ge=0, le=10),
    processed_only: bool = Query(False),
):
    """Get papers with optional filtering."""
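    # Hide papers that the analysis step explicitly marked "skipped" (low relevance).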
    query = select(Paper).where(
        or_(Paper.processing_status.is_(None), Paper.processing_status != "skipped")
    )
    if processed_only:
        query = query.where(Paper.is_processed == True)
    if min_score is not None:
        query = query.where(Paper.relevance_score >= min_score)
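    # Order: processed papers first, then by relevance score (NULLs last), then newest.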
    query = query.order_by(
        Paper.is_processed.desc(),
        Paper.relevance_score.desc().nulls_last(),
        Paper.published.desc()
    ).offset(skip).limit(limit)
    papers = session.exec(query).all()
    return papers


@app.get("/papers/{paper_id}", response_model=PaperRead)
def get_paper(paper_id: int, session: Session = Depends(get_session)):
    """Get a specific paper by ID."""
    paper = session.get(Paper, paper_id)
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")
    return paper


@app.get("/papers/arxiv/{arxiv_id}", response_model=PaperRead)
def get_paper_by_arxiv_id(arxiv_id: str, session: Session = Depends(get_session)):
    """Get a specific paper by arXiv ID."""
    paper = session.exec(select(Paper).where(Paper.arxiv_id == arxiv_id)).first()
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")
    return paper


@app.post("/papers/fetch")
def fetch_papers(
    background_tasks: BackgroundTasks,
    session: Session = Depends(get_session),
):
    """Trigger paper fetching in the background."""
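    # run_daily_fetch is queued as a FastAPI background task, so this returns immediately.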
    background_tasks.add_task(run_daily_fetch)
    return {"message": "Paper fetch started in background"}


@app.post("/papers/{paper_id}/process")
async def process_paper(paper_id: int, session: Session = Depends(get_session)):
    """Process a specific paper with LLM analysis."""
    paper = session.get(Paper, paper_id)
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")
    if paper.is_processed:
        return {"message": "Paper already processed", "paper_id": paper_id}
    bot = get_arxiv_bot()
    success = await bot.process_paper(session, paper)
    if success:
        return {"message": "Paper processed successfully", "paper_id": paper_id}
    else:
        raise HTTPException(status_code=500, detail="Failed to process paper")


@app.post("/papers/{paper_id}/process/stream")
async def process_paper_stream(paper_id: int, session: Session = Depends(get_session)):
    """
    Process a paper with streaming response for real-time updates.

    Returns Server-Sent Events (SSE) with the following event types:
    - thinking: R1 reasoning process (thought field)
    - answer: Partial answer content
    - progress: Processing progress updates
    - result: Final structured analysis result
    - error: Error information
    - done: Stream completion signal
    """
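    # Wire format of the stream (illustrative; actual values depend on the analysis run):
    #   event: progress
    #   data: {"status": "started", "message": "..."}
    #
    #   event: result
    #   data: {"summary_zh": "...", "relevance_score": 8.5, ...}
    #
    #   event: done
    #   data: {"status": "completed"}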
    paper = session.get(Paper, paper_id)
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")

    async def generate_events():
        """Generate SSE events for paper analysis."""
        try:
            # Update paper status
            paper.processing_status = "processing"
            session.add(paper)
            session.commit()

            # Send initial progress event
            yield f"event: progress\ndata: {json.dumps({'status': 'started', 'message': '开始分析论文...'})}\n\n"

            dify_client = get_dify_client()
            thought_parts = []
            answer_parts = []
            final_outputs = None

            async for event in dify_client.analyze_paper_stream(
                paper.title,
                paper.abstract,
                user_id=f"paper-{paper_id}",
            ):
                # Handle thought (R1 thinking process)
                if event.thought:
                    thought_parts.append(event.thought)
                    yield f"event: thinking\ndata: {json.dumps({'thought': event.thought})}\n\n"

                # Handle answer chunks
                if event.answer:
                    answer_parts.append(event.answer)
                    yield f"event: answer\ndata: {json.dumps({'answer': event.answer})}\n\n"

                # Handle workflow events
                if event.event == "workflow_started":
                    yield f"event: progress\ndata: {json.dumps({'status': 'workflow_started', 'message': 'Dify工作流已启动'})}\n\n"
                elif event.event == "node_started":
                    node_title = event.data.get("data", {}).get("title", "")
                    if node_title:
                        yield f"event: progress\ndata: {json.dumps({'status': 'node_started', 'message': f'执行节点: {node_title}'})}\n\n"
                elif event.event == "workflow_finished":
                    if event.outputs:
                        final_outputs = event.outputs

            # Process final result
            if final_outputs:
                result = dify_client._parse_outputs(final_outputs, "".join(thought_parts))
            elif answer_parts:
                result = dify_client._parse_answer("".join(answer_parts), "".join(thought_parts))
            else:
                raise DifyClientError("No output received from Dify workflow")

            # Convert to LLMAnalysis for database storage
            analysis = dify_client.to_llm_analysis(result)

            # Update paper with results
            from datetime import datetime
            paper.summary_zh = analysis.summary_zh
            paper.relevance_score = analysis.relevance_score
            paper.relevance_reason = analysis.relevance_reason
            paper.heuristic_idea = analysis.heuristic_idea
            paper.is_processed = True
            paper.processed_at = datetime.utcnow()
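            # Papers scoring below 5 are marked "skipped" and hidden from the main paper list.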
            if analysis.relevance_score >= 5:
                paper.processing_status = "processed"
            else:
                paper.processing_status = "skipped"
            session.add(paper)
            session.commit()

            # Send final result
            result_data = {
                "summary_zh": result.summary_zh,
                "relevance_score": result.relevance_score,
                "relevance_reason": result.relevance_reason,
                "technical_mapping": {
                    "token_vs_patch": result.technical_mapping.token_vs_patch,
                    "temporal_logic": result.technical_mapping.temporal_logic,
                    "frequency_domain": result.technical_mapping.frequency_domain,
                },
                "heuristic_idea": result.heuristic_idea,
                "thought_process": result.thought_process,
            }
            yield f"event: result\ndata: {json.dumps(result_data, ensure_ascii=False)}\n\n"
            yield f"event: done\ndata: {json.dumps({'status': 'completed'})}\n\n"

        except DifyEntityTooLargeError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'entity_too_large', 'message': str(e)})}\n\n"
        except DifyTimeoutError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'timeout', 'message': str(e)})}\n\n"
        except DifyRateLimitError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'rate_limit', 'message': str(e)})}\n\n"
        except DifyClientError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'dify_error', 'message': str(e)})}\n\n"
        except Exception as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'unknown', 'message': str(e)})}\n\n"

    return StreamingResponse(
        generate_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )

@app.get("/stats")
def get_stats(session: Session = Depends(get_session)):
    """Get statistics about papers."""
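    # Counts are aggregated in Python over every non-skipped paper row.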
    total = session.exec(
        select(Paper).where(
            or_(Paper.processing_status.is_(None), Paper.processing_status != "skipped")
        )
    ).all()
    processed = [p for p in total if p.is_processed]
    high_relevance = [p for p in processed if p.relevance_score and p.relevance_score >= 9]
    return {
        "total_papers": len(total),
        "processed_papers": len(processed),
        "high_relevance_papers": len(high_relevance),
        "pending_processing": len(total) - len(processed),
    }