Spaces:
Running
Running
| # 文件路径: app/main.py | |
| import sys | |
| import io | |
| import os | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| # 强制 stdout 使用 utf-8,防止 Windows 控制台乱码 | |
| sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from sse_starlette.sse import EventSourceResponse | |
| from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import uvicorn | |
| from app.core.config import settings | |
| from app.services.agent_service import agent_stream | |
| from app.services.chat_service import process_chat_stream, get_eval_data, clear_eval_data | |
| from app.services.vector_service import store_manager | |
| from app.services.auto_evaluation_service import ( | |
| init_auto_evaluation_service, | |
| get_auto_evaluation_service, | |
| EvaluationConfig | |
| ) | |
| from evaluation.evaluation_framework import EvaluationEngine, EvaluationResult, DataRoutingEngine | |
| from datetime import datetime | |
| import uuid | |
| settings.validate() | |
| # === 生命周期管理 === | |
| async def lifespan(app: FastAPI): | |
| """应用生命周期管理""" | |
| from app.services.vector_service import store_manager | |
| # 启动时运行 | |
| print("🚀 Application starting...") | |
| # 仓库数据永久存储,对话记忆纯内存存储(重启自动清空) | |
| yield | |
| # 关闭时运行 | |
| print("🛑 Application shutting down...") | |
| # 清理 GitHub 客户端连接 | |
| from app.utils.github_client import close_github_client | |
| await close_github_client() | |
| # 清理向量存储连接 | |
| await store_manager.close_all() | |
| # 关闭共享的 Qdrant 客户端 | |
| from app.storage.qdrant_store import close_shared_client | |
| await close_shared_client() | |
| print("✅ Cleanup complete") | |
| app = FastAPI(title="GitHub RAG Agent", lifespan=lifespan) | |
| # === 初始化评估引擎 === | |
| from app.utils.llm_client import client | |
| eval_engine = EvaluationEngine(llm_client=client, model_name=settings.default_model_name) | |
| data_router = DataRoutingEngine() | |
| # === 初始化自动评估服务 (Phase 1) === | |
| auto_eval_config = EvaluationConfig( | |
| enabled=True, | |
| use_ragas=False, # Phase 1: 先不用 Ragas,避免额外依赖 | |
| async_evaluation=True, # 异步模式,不阻塞响应 | |
| min_quality_score=0.4, # 最低分数阈值(0.4 = 只拒绝最差的) | |
| min_query_length=10, # 最小 query 长度 | |
| min_answer_length=100, # 最小 answer 长度 | |
| require_repo_url=True, # 必须有仓库 URL | |
| require_code_in_context=True # 上下文必须包含代码 | |
| ) | |
| auto_eval_service = init_auto_evaluation_service( | |
| eval_engine=eval_engine, | |
| data_router=data_router, | |
| config=auto_eval_config | |
| ) | |
| print("✅ Auto Evaluation Service Initialized") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # === 静态文件与前端 === | |
| app.mount("/static", StaticFiles(directory="app"), name="static") | |
| # Vue 3 构建输出的静态资源 (JS/CSS/assets) | |
| import os | |
| FRONTEND_DIST = os.path.join(os.path.dirname(os.path.dirname(__file__)), "frontend-dist") | |
| if os.path.exists(FRONTEND_DIST): | |
| app.mount("/assets", StaticFiles(directory=os.path.join(FRONTEND_DIST, "assets")), name="vue-assets") | |
| async def read_root(): | |
| # 优先使用 Vue 3 构建版本,否则回退到原版 | |
| vue_index = os.path.join(FRONTEND_DIST, "index.html") | |
| if os.path.exists(vue_index): | |
| with open(vue_index, "r", encoding="utf-8") as f: | |
| return f.read() | |
| # 回退到原版前端 | |
| with open("frontend/index.html", "r", encoding="utf-8") as f: | |
| return f.read() | |
| def health_check(): | |
| return {"status": "ok"} | |
| async def get_sessions(): | |
| """获取 session 管理状态""" | |
| return JSONResponse(store_manager.get_stats()) | |
| async def trigger_cleanup(): | |
| """手动触发过期文件清理""" | |
| stats = await store_manager.cleanup_expired_files() | |
| return JSONResponse({"message": "Cleanup completed", "stats": stats}) | |
| async def close_session(session_id: str): | |
| """关闭指定 session""" | |
| await store_manager.close_session(session_id) | |
| return JSONResponse({"message": f"Session {session_id} closed"}) | |
| # === 仓库级 Session API === | |
| async def check_repo_session(request: Request): | |
| """ | |
| 检查仓库是否已有指定语言的索引和报告 | |
| 请求: { "url": "https://github.com/owner/repo", "language": "zh" } | |
| 响应: { | |
| "exists": true/false, | |
| "session_id": "repo_xxx", | |
| "report": "..." (如果存在对应语言的报告), | |
| "has_index": true/false, | |
| "available_languages": ["en", "zh"] | |
| } | |
| """ | |
| from app.utils.session import generate_repo_session_id | |
| data = await request.json() | |
| repo_url = data.get("url", "").strip() | |
| language = data.get("language", "en") | |
| if not repo_url: | |
| return JSONResponse({"error": "Missing URL"}, status_code=400) | |
| # 生成基于仓库的 Session ID | |
| session_id = generate_repo_session_id(repo_url) | |
| # 检查是否存在 | |
| store = store_manager.get_store(session_id) | |
| # 尝试加载上下文 | |
| context = store.load_context() | |
| if context and context.get("repo_url"): | |
| # 存在已分析的仓库 | |
| # 获取指定语言的报告 | |
| report = store.get_report(language) | |
| available_languages = store.get_available_languages() | |
| global_context = context.get("global_context", {}) | |
| has_index = bool(global_context.get("file_tree")) | |
| return JSONResponse({ | |
| "exists": True, | |
| "session_id": session_id, | |
| "repo_url": context.get("repo_url"), | |
| "report": report, # 指定语言的报告,可能为 None | |
| "has_index": has_index, | |
| "available_languages": available_languages, | |
| "requested_language": language, | |
| }) | |
| else: | |
| return JSONResponse({ | |
| "exists": False, | |
| "session_id": session_id, | |
| "has_index": False, | |
| "available_languages": [], | |
| }) | |
| async def analyze(url: str, session_id: str, language: str = "en", regenerate_only: bool = False): | |
| """ | |
| 仓库分析端点 | |
| Args: | |
| url: 仓库 URL | |
| session_id: Session ID | |
| language: 报告语言 ("en" 或 "zh") | |
| regenerate_only: True 时跳过抓取/索引,直接使用已有索引生成新语言报告 | |
| """ | |
| if not session_id: | |
| return {"error": "Missing session_id"} | |
| return EventSourceResponse(agent_stream(url, session_id, language, regenerate_only)) | |
| async def chat(request: Request): | |
| """ | |
| 聊天端点 - 自动评估版本 | |
| 改进点: | |
| 1. 立即返回聊天结果(不阻塞) | |
| 2. 后台异步进行自动评估 | |
| 3. 评估结果自动存储到 evaluation/sft_data/ | |
| """ | |
| data = await request.json() | |
| user_query = data.get("query") | |
| session_id = data.get("session_id") | |
| repo_url = data.get("repo_url", "") | |
| if not user_query: | |
| return {"answer": "Please enter your question"} | |
| if not session_id: | |
| return {"answer": "Session lost"} | |
| # 标记流是否完成 | |
| stream_completed = False | |
| async def chat_stream_with_eval(): | |
| """包装 process_chat_stream,流结束后触发评估""" | |
| nonlocal stream_completed | |
| # 清除旧的评估数据 | |
| clear_eval_data(session_id) | |
| # 执行聊天流 | |
| async for chunk in process_chat_stream(user_query, session_id): | |
| yield chunk | |
| # 流完成后标记 | |
| stream_completed = True | |
| # 流结束后触发评估(此时数据已存储在 chat_service 中) | |
| try: | |
| auto_eval_service = get_auto_evaluation_service() | |
| eval_data = get_eval_data(session_id) | |
| if auto_eval_service and eval_data and eval_data.answer: | |
| print(f"\n📊 [Auto-Eval] Starting evaluation for session {session_id}") | |
| print(f" - Query: {user_query[:50]}...") | |
| print(f" - Context length: {len(eval_data.retrieved_context)} chars") | |
| print(f" - Answer length: {len(eval_data.answer)} chars") | |
| # 异步执行评估(不阻塞流结束) | |
| asyncio.create_task( | |
| auto_eval_service.auto_evaluate_async( | |
| query=user_query, | |
| retrieved_context=eval_data.retrieved_context, | |
| generated_answer=eval_data.answer, | |
| session_id=session_id, | |
| repo_url=repo_url, | |
| language="zh" if any('\u4e00' <= c <= '\u9fff' for c in user_query) else "en" | |
| ) | |
| ) | |
| else: | |
| if not auto_eval_service: | |
| print("⚠️ Auto evaluation service not initialized") | |
| elif not eval_data: | |
| print(f"⚠️ No eval data found for session {session_id}") | |
| elif not eval_data.answer: | |
| print(f"⚠️ Empty answer for session {session_id}") | |
| except Exception as e: | |
| print(f"⚠️ Failed to trigger auto-eval: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| # 返回流 | |
| return StreamingResponse( | |
| chat_stream_with_eval(), | |
| media_type="text/plain" | |
| ) | |
| # ===== Phase 2: 新增评估端点 ===== | |
| async def evaluate(request: Request): | |
| """ | |
| 评估端点: 接收生成结果,进行多维度评估 | |
| POST /evaluate | |
| { | |
| "query": "用户问题", | |
| "retrieved_context": "检索到的文件内容", | |
| "generated_answer": "生成的回答", | |
| "session_id": "会话ID", | |
| "repo_url": "仓库URL(可选)" | |
| } | |
| """ | |
| try: | |
| data = await request.json() | |
| # 提取必需字段 | |
| query = data.get("query") | |
| retrieved_context = data.get("retrieved_context", "") | |
| generated_answer = data.get("generated_answer") | |
| session_id = data.get("session_id", "unknown") | |
| repo_url = data.get("repo_url", "") | |
| if not query or not generated_answer: | |
| return { | |
| "error": "Missing required fields: query, generated_answer", | |
| "status": "failed" | |
| } | |
| # 调用评估引擎获取生成层指标 | |
| generation_metrics = await eval_engine.evaluate_generation( | |
| query=query, | |
| retrieved_context=retrieved_context, | |
| generated_answer=generated_answer | |
| ) | |
| # 构建完整的评估结果对象 | |
| evaluation_result = EvaluationResult( | |
| session_id=session_id, | |
| query=query, | |
| repo_url=repo_url, | |
| timestamp=datetime.now(), | |
| language="en", | |
| generation_metrics=generation_metrics | |
| ) | |
| # 计算综合得分 | |
| evaluation_result.compute_overall_score() | |
| # 数据路由: 根据得分将样本分类 | |
| quality_tier = data_router.route_sample(evaluation_result) | |
| return { | |
| "status": "success", | |
| "evaluation": { | |
| "faithfulness": generation_metrics.faithfulness, | |
| "answer_relevance": generation_metrics.answer_relevance, | |
| "answer_completeness": generation_metrics.answer_completeness, | |
| "overall_score": evaluation_result.overall_score | |
| }, | |
| "quality_tier": quality_tier, | |
| "session_id": session_id | |
| } | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return { | |
| "error": str(e), | |
| "status": "failed" | |
| } | |
| # ===== 自动评估相关端点 ===== | |
| async def get_review_queue(): | |
| """ | |
| 获取需要人工审查的样本列表 | |
| 这些是评估出现异常(自己的分数和Ragas分数差异过大)的样本 | |
| 需要人工判断哪个评估器更准确 | |
| GET /auto-eval/review-queue | |
| """ | |
| try: | |
| auto_eval_service = get_auto_evaluation_service() | |
| if not auto_eval_service: | |
| return {"error": "Auto evaluation service not initialized", "status": "failed"} | |
| queue = auto_eval_service.get_review_queue() | |
| return { | |
| "status": "success", | |
| "queue_size": len(queue), | |
| "samples": [ | |
| { | |
| "index": i, | |
| "query": item["eval_result"].query, | |
| "custom_score": item["custom_score"], | |
| "ragas_score": item["ragas_score"], | |
| "diff": item["diff"], | |
| "quality_tier": item["eval_result"].data_quality_tier.value, | |
| "timestamp": item["timestamp"] | |
| } | |
| for i, item in enumerate(queue) | |
| ] | |
| } | |
| except Exception as e: | |
| return {"error": str(e), "status": "failed"} | |
| async def approve_sample(index: int): | |
| """ | |
| 人工批准某个样本(接受该评估结果) | |
| POST /auto-eval/approve/0 | |
| """ | |
| try: | |
| auto_eval_service = get_auto_evaluation_service() | |
| if not auto_eval_service: | |
| return {"error": "Auto evaluation service not initialized", "status": "failed"} | |
| auto_eval_service.approve_sample(index) | |
| return { | |
| "status": "success", | |
| "message": f"Sample {index} approved and stored" | |
| } | |
| except Exception as e: | |
| return {"error": str(e), "status": "failed"} | |
| async def reject_sample(index: int): | |
| """ | |
| 人工拒绝某个样本(抛弃该评估结果) | |
| POST /auto-eval/reject/0 | |
| """ | |
| try: | |
| auto_eval_service = get_auto_evaluation_service() | |
| if not auto_eval_service: | |
| return {"error": "Auto evaluation service not initialized", "status": "failed"} | |
| auto_eval_service.reject_sample(index) | |
| return { | |
| "status": "success", | |
| "message": f"Sample {index} rejected and removed from queue" | |
| } | |
| except Exception as e: | |
| return {"error": str(e), "status": "failed"} | |
| async def auto_eval_stats(): | |
| """ | |
| 获取自动评估统计信息 | |
| GET /auto-eval/stats | |
| """ | |
| try: | |
| auto_eval_service = get_auto_evaluation_service() | |
| if not auto_eval_service: | |
| return {"error": "Auto evaluation service not initialized", "status": "failed"} | |
| queue = auto_eval_service.get_review_queue() | |
| return { | |
| "status": "success", | |
| "auto_evaluation": { | |
| "enabled": auto_eval_service.config.enabled, | |
| "use_ragas": auto_eval_service.config.use_ragas, | |
| "async_mode": auto_eval_service.config.async_evaluation, | |
| "custom_weight": auto_eval_service.config.custom_weight, | |
| "ragas_weight": auto_eval_service.config.ragas_weight, | |
| "diff_threshold": auto_eval_service.config.diff_threshold | |
| }, | |
| "review_queue_size": len(queue), | |
| "last_update": datetime.now().isoformat() | |
| } | |
| except Exception as e: | |
| return {"error": str(e), "status": "failed"} | |
| async def evaluation_stats(): | |
| """ | |
| 获取评估统计信息 | |
| GET /evaluation/stats | |
| """ | |
| try: | |
| stats = eval_engine.get_statistics() | |
| return { | |
| "status": "success", | |
| "statistics": { | |
| "total_evaluations": stats.get("total_evaluations", 0), | |
| "average_score": stats.get("average_score", 0), | |
| "quality_distribution": stats.get("quality_distribution", {}), | |
| "top_issues": stats.get("top_issues", []) | |
| } | |
| } | |
| except Exception as e: | |
| return { | |
| "error": str(e), | |
| "status": "failed" | |
| } | |
| async def quality_distribution(): | |
| """ | |
| 获取数据质量分布 (用于仪表盘) | |
| GET /dashboard/quality-distribution | |
| """ | |
| try: | |
| distribution = data_router.get_distribution() | |
| return { | |
| "status": "success", | |
| "distribution": { | |
| "gold": distribution.get("gold", 0), | |
| "silver": distribution.get("silver", 0), | |
| "bronze": distribution.get("bronze", 0), | |
| "rejected": distribution.get("rejected", 0), | |
| "corrected": distribution.get("corrected", 0) | |
| }, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| except Exception as e: | |
| return { | |
| "error": str(e), | |
| "status": "failed" | |
| } | |
| async def bad_cases(): | |
| """ | |
| 获取低质量样本 (用于人工审核) | |
| GET /dashboard/bad-cases | |
| """ | |
| try: | |
| bad_samples = data_router.get_bad_samples(limit=10) | |
| return { | |
| "status": "success", | |
| "bad_cases": [ | |
| { | |
| "query": s.get("query", ""), | |
| "issue": s.get("issue", ""), | |
| "score": s.get("score", 0) | |
| } | |
| for s in bad_samples | |
| ], | |
| "total_bad_cases": len(bad_samples) | |
| } | |
| except Exception as e: | |
| return { | |
| "error": str(e), | |
| "status": "failed" | |
| } | |
| if __name__ == "__main__": | |
| # 生产模式建议关掉 reload | |
| uvicorn.run("app.main:app", host=settings.HOST, port=settings.PORT, reload=False) |