# Backend snapshot from commit 760fa78 (blob 353b9f4)
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Optional
import json
import re

from fastapi import FastAPI, Depends, HTTPException, Query, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import func, or_
from sqlmodel import Session, select

from app.constants import ARXIV_OPTIONS
from app.database import create_db_and_tables, ensure_appsettings_schema, ensure_paper_schema, get_session
from app.models import Paper, PaperRead, AppSettings
from app.services.arxiv_bot import get_arxiv_bot, run_daily_fetch
from app.services.dify_client import (
    get_dify_client,
    DifyClientError,
    DifyEntityTooLargeError,
    DifyTimeoutError,
    DifyRateLimitError,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: prepare the database before serving requests.

    Creates any missing tables, then applies the in-place schema fixes for
    the AppSettings and Paper tables. Order matters: tables must exist
    before the schema helpers run. Nothing to clean up after the yield.
    """
    create_db_and_tables()
    ensure_appsettings_schema()
    ensure_paper_schema()
    yield
# Application instance; `lifespan` performs DB setup/migrations at startup.
app = FastAPI(
    title="Paper Insight API",
    description="API for fetching and summarizing arXiv papers focused on Autoregressive DiT and KV Cache Compression",
    version="0.1.0",
    lifespan=lifespan,
)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec for credentialed requests (browsers reject a
# wildcard origin when credentials are sent). If cookies/auth headers are
# ever used cross-origin, list explicit origins instead — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files (directory is created on first run if absent).
static_path = Path(__file__).parent / "static"
static_path.mkdir(exist_ok=True)
app.mount("/static", StaticFiles(directory=static_path), name="static")
@app.get("/health")
def health_check():
    """Liveness probe — confirms the API process is up and responding."""
    payload = {"status": "healthy"}
    return payload
@app.get("/constants")
def get_constants():
    """Expose static application constants (arXiv category options)."""
    return dict(arxiv_options=ARXIV_OPTIONS)
@app.get("/settings", response_model=AppSettings)
def get_settings(session: Session = Depends(get_session)):
    """Return the singleton settings row, creating defaults on first access."""
    existing = session.get(AppSettings, 1)
    if existing is not None:
        return existing
    # First call ever: persist a default settings row under the fixed id.
    created = AppSettings(id=1)
    session.add(created)
    session.commit()
    session.refresh(created)
    return created
def _parse_focus_keywords(raw_focus: str) -> List[str]:
    """Derive a de-duplicated keyword list from a research-focus string.

    Two input formats are supported:
    - semicolon-separated keywords: "kv cache; dit" -> ["kv cache", "dit"]
    - arXiv-style boolean queries: terms are split on OR/AND, then stripped
      of surrounding parentheses, field prefixes (all:/abs:/ti:) and quotes.

    Order is preserved; duplicates are dropped (first occurrence wins) for
    both formats.
    """
    raw_focus = raw_focus.strip()
    if not raw_focus:
        return []
    if ";" in raw_focus:
        keywords = [k.strip() for k in re.split(r"[;]+", raw_focus) if k.strip()]
    else:
        keywords = []
        for part in re.split(r"\bOR\b|\bAND\b", raw_focus, flags=re.IGNORECASE):
            cleaned = part.strip()
            if not cleaned:
                continue
            # Strip grouping parentheses left over from the query syntax.
            cleaned = re.sub(r"^[()]+|[()]+$", "", cleaned).strip()
            # Strip arXiv search field prefixes such as all:, abs:, ti:.
            cleaned = re.sub(r"^(?:all|abs|ti):", "", cleaned, flags=re.IGNORECASE).strip()
            cleaned = cleaned.strip('"').strip()
            if cleaned:
                keywords.append(cleaned)
    seen = set()
    deduped = []
    for keyword in keywords:
        if keyword not in seen:
            deduped.append(keyword)
            seen.add(keyword)
    return deduped


@app.put("/settings", response_model=AppSettings)
def update_settings(new_settings: AppSettings, session: Session = Depends(get_session)):
    """Update application settings.

    Persists the editable fields from the request body and recomputes the
    cached ``focus_keywords`` list from ``research_focus`` (see
    ``_parse_focus_keywords`` for the accepted formats). Creates the
    singleton settings row if it does not exist yet.
    """
    settings = session.get(AppSettings, 1)
    if not settings:
        # First write before any read: create the singleton row.
        settings = AppSettings(id=1)
        session.add(settings)
    settings.research_focus = new_settings.research_focus
    settings.system_prompt = new_settings.system_prompt
    settings.arxiv_categories = new_settings.arxiv_categories
    # Recompute the derived keyword list from the free-form focus string.
    if new_settings.research_focus:
        settings.focus_keywords = _parse_focus_keywords(new_settings.research_focus)
    else:
        settings.focus_keywords = []
    session.add(settings)
    session.commit()
    session.refresh(settings)
    return settings
@app.get("/papers", response_model=List[PaperRead])
def get_papers(
    session: Session = Depends(get_session),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    min_score: Optional[float] = Query(None, ge=0, le=10),
    processed_only: bool = Query(False),
):
    """List papers, paginated, with optional processed/score filters.

    Papers whose triage status is "skipped" are never returned; untriaged
    rows (NULL status) remain visible.
    """
    statement = select(Paper).where(
        or_(Paper.processing_status.is_(None), Paper.processing_status != "skipped")
    )
    if processed_only:
        statement = statement.where(Paper.is_processed == True)  # noqa: E712 (SQL expr)
    if min_score is not None:
        statement = statement.where(Paper.relevance_score >= min_score)
    # Processed papers first, then by relevance (unscored last), then recency.
    statement = (
        statement.order_by(
            Paper.is_processed.desc(),
            Paper.relevance_score.desc().nulls_last(),
            Paper.published.desc(),
        )
        .offset(skip)
        .limit(limit)
    )
    return session.exec(statement).all()
@app.get("/papers/{paper_id}", response_model=PaperRead)
def get_paper(paper_id: int, session: Session = Depends(get_session)):
    """Fetch one paper by primary key; 404 if it does not exist."""
    record = session.get(Paper, paper_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Paper not found")
    return record
@app.get("/papers/arxiv/{arxiv_id}", response_model=PaperRead)
def get_paper_by_arxiv_id(arxiv_id: str, session: Session = Depends(get_session)):
    """Fetch one paper by its arXiv identifier; 404 if unknown."""
    statement = select(Paper).where(Paper.arxiv_id == arxiv_id)
    match = session.exec(statement).first()
    if match is None:
        raise HTTPException(status_code=404, detail="Paper not found")
    return match
@app.post("/papers/fetch")
def fetch_papers(
    background_tasks: BackgroundTasks,
    session: Session = Depends(get_session),
):
    """Schedule the daily arXiv fetch as a background task and return at once.

    The session dependency is unused here but kept for interface stability.
    """
    background_tasks.add_task(run_daily_fetch)
    return {"message": "Paper fetch started in background"}
@app.post("/papers/{paper_id}/process")
async def process_paper(paper_id: int, session: Session = Depends(get_session)):
    """Run the LLM analysis on one paper (no-op if already processed)."""
    target = session.get(Paper, paper_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Paper not found")
    if target.is_processed:
        return {"message": "Paper already processed", "paper_id": paper_id}
    succeeded = await get_arxiv_bot().process_paper(session, target)
    if not succeeded:
        raise HTTPException(status_code=500, detail="Failed to process paper")
    return {"message": "Paper processed successfully", "paper_id": paper_id}
@app.get("/papers/{paper_id}/process/stream")
async def process_paper_stream(paper_id: int, session: Session = Depends(get_session)):
    """
    Process a paper with streaming response for real-time updates.
    Returns Server-Sent Events (SSE) with the following event types:
    - thinking: R1 reasoning process (thought field)
    - answer: Partial answer content
    - progress: Processing progress updates
    - result: Final structured analysis result
    - error: Error information
    - done: Stream completion signal

    NOTE(review): status transitions are committed inside the generator; if
    the client disconnects mid-stream, none of the except clauses fire
    (generator close raises GeneratorExit, a BaseException), so the paper
    may be left in "processing" — confirm whether a cleanup path is needed.
    """
    paper = session.get(Paper, paper_id)
    if not paper:
        raise HTTPException(status_code=404, detail="Paper not found")

    async def generate_events():
        """Generate SSE events for paper analysis."""
        try:
            # Update paper status: mark in-flight before any streaming starts.
            paper.processing_status = "processing"
            session.add(paper)
            session.commit()
            # Send initial progress event
            yield f"event: progress\ndata: {json.dumps({'status': 'started', 'message': '开始分析论文...'})}\n\n"
            dify_client = get_dify_client()
            thought_parts = []    # accumulated R1 "thinking" text chunks
            answer_parts = []     # accumulated answer text chunks
            final_outputs = None  # structured outputs from workflow_finished, if any
            async for event in dify_client.analyze_paper_stream(
                paper.title,
                paper.abstract,
                user_id=f"paper-{paper_id}",
            ):
                # Handle thought (R1 thinking process)
                if event.thought:
                    thought_parts.append(event.thought)
                    yield f"event: thinking\ndata: {json.dumps({'thought': event.thought})}\n\n"
                # Handle answer chunks
                if event.answer:
                    answer_parts.append(event.answer)
                    yield f"event: answer\ndata: {json.dumps({'answer': event.answer})}\n\n"
                # Handle workflow events
                if event.event == "workflow_started":
                    yield f"event: progress\ndata: {json.dumps({'status': 'workflow_started', 'message': 'Dify工作流已启动'})}\n\n"
                elif event.event == "node_started":
                    node_title = event.data.get("data", {}).get("title", "")
                    if node_title:
                        yield f"event: progress\ndata: {json.dumps({'status': 'node_started', 'message': f'执行节点: {node_title}'})}\n\n"
                elif event.event == "workflow_finished":
                    if event.outputs:
                        final_outputs = event.outputs
            # Process final result: prefer structured workflow outputs and
            # fall back to parsing the concatenated streamed answer text.
            if final_outputs:
                result = dify_client._parse_outputs(final_outputs, "".join(thought_parts))
            elif answer_parts:
                result = dify_client._parse_answer("".join(answer_parts), "".join(thought_parts))
            else:
                raise DifyClientError("No output received from Dify workflow")
            # Convert to LLMAnalysis for database storage
            analysis = dify_client.to_llm_analysis(result)
            # Update paper with results
            from datetime import datetime
            paper.summary_zh = analysis.summary_zh
            paper.relevance_score = analysis.relevance_score
            paper.relevance_reason = analysis.relevance_reason
            paper.heuristic_idea = analysis.heuristic_idea
            paper.is_processed = True
            paper.processed_at = datetime.utcnow()
            # Papers scoring below 5 are triaged as "skipped" and thereby
            # hidden from /papers listings and /stats.
            if analysis.relevance_score >= 5:
                paper.processing_status = "processed"
            else:
                paper.processing_status = "skipped"
            session.add(paper)
            session.commit()
            # Send final result
            result_data = {
                "summary_zh": result.summary_zh,
                "relevance_score": result.relevance_score,
                "relevance_reason": result.relevance_reason,
                "technical_mapping": {
                    "token_vs_patch": result.technical_mapping.token_vs_patch,
                    "temporal_logic": result.technical_mapping.temporal_logic,
                    "frequency_domain": result.technical_mapping.frequency_domain,
                },
                "heuristic_idea": result.heuristic_idea,
                "thought_process": result.thought_process,
            }
            yield f"event: result\ndata: {json.dumps(result_data, ensure_ascii=False)}\n\n"
            yield f"event: done\ndata: {json.dumps({'status': 'completed'})}\n\n"
        # Each known Dify failure mode maps to a distinct SSE error code;
        # in every case the paper is marked "failed" before reporting.
        except DifyEntityTooLargeError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'entity_too_large', 'message': str(e)})}\n\n"
        except DifyTimeoutError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'timeout', 'message': str(e)})}\n\n"
        except DifyRateLimitError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'rate_limit', 'message': str(e)})}\n\n"
        except DifyClientError as e:
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'dify_error', 'message': str(e)})}\n\n"
        except Exception as e:
            # Last-resort handler: surface unexpected failures to the client.
            paper.processing_status = "failed"
            session.add(paper)
            session.commit()
            yield f"event: error\ndata: {json.dumps({'error': 'unknown', 'message': str(e)})}\n\n"

    return StreamingResponse(
        generate_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
@app.get("/stats")
def get_stats(session: Session = Depends(get_session)):
    """Get statistics about papers.

    Uses SQL COUNT aggregates instead of loading every Paper row into
    memory, so the endpoint stays cheap as the table grows. The response
    shape is unchanged: total, processed, high-relevance (score >= 9) and
    pending counts, all excluding triage-"skipped" papers.
    """
    # Papers marked "skipped" by triage are excluded from every statistic;
    # untriaged rows (NULL status) are included.
    visible = or_(Paper.processing_status.is_(None), Paper.processing_status != "skipped")

    def _count(*conditions) -> int:
        # One COUNT(*) round-trip per statistic.
        statement = select(func.count()).select_from(Paper).where(visible, *conditions)
        return session.exec(statement).one()

    total = _count()
    processed = _count(Paper.is_processed == True)  # noqa: E712 (SQL expression)
    # NULL relevance_score rows fall out of the >= comparison, matching the
    # original truthiness check on p.relevance_score.
    high_relevance = _count(Paper.is_processed == True, Paper.relevance_score >= 9)  # noqa: E712
    return {
        "total_papers": total,
        "processed_papers": processed,
        "high_relevance_papers": high_relevance,
        "pending_processing": total - processed,
    }