Spaces:
Sleeping
Sleeping
| """ | |
| REST API for the Vietnam Heritage RAG System | |
| """ | |
| import uuid | |
| import json | |
| import os | |
| import sys | |
| from datetime import datetime | |
| import google.generativeai as genai | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| # Add service directory to sys.path to allow imports | |
| sys.path.append(os.path.join(os.path.dirname(__file__), 'service')) | |
| from rewrite import QueryRewriter | |
# Flask application with cross-origin requests enabled.
app = Flask(__name__)
CORS(app)

# Shared QueryRewriter instance (it provides ask_with_context).
rewriter = QueryRewriter()

# Raw value of the `islog` environment variable (None when unset).
islog = os.environ.get('islog')

# In-memory record of evaluation runs, kept so charts can be built from it.
metrics_log = []

# Configure the Gemini client only when an API key is available.
GENAI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GENAI_API_KEY:
    genai.configure(api_key=GENAI_API_KEY)
| def _safe_json_parse(text): | |
| """Parse chuỗi JSON, cố gắng trích block {} đầu tiên nếu có thêm text.""" | |
| try: | |
| return json.loads(text) | |
| except Exception: | |
| pass | |
| start = text.find("{") | |
| end = text.rfind("}") | |
| if start != -1 and end != -1 and end > start: | |
| try: | |
| return json.loads(text[start : end + 1]) | |
| except Exception: | |
| return None | |
| return None | |
def _as_unit_score(value):
    """Coerce a score to a float clamped to [0, 1]; missing/null counts as 0."""
    return max(0.0, min(1.0, float(value or 0)))


def _as_flag(value):
    """Interpret a JSON-ish value as a boolean, including strings like "false"."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def evaluate_answer_llm(question: str, answer: str, history_message):
    """Ask the Gemini model to score an answer for relevance, accuracy and hallucination.

    Args:
        question: The user's question.
        answer: The chatbot answer to evaluate.
        history_message: Optional iterable of earlier chat messages; each one
            is read through ``.get("content")``. Falsy means "no history".

    Returns:
        A dict whose ``"status"`` is one of:

        * ``"skipped"`` - ``GENAI_API_KEY`` is not set (with a ``"reason"``).
        * ``"ok"`` - carries ``"timestamp"`` (UTC, ISO 8601 with a ``Z``
          suffix), ``"rag_relevance"`` and ``"answer_accuracy"`` clamped to
          [0, 1], ``"hallucination"`` (bool) and ``"notes"``.
        * ``"error"`` - any failure; ``"error"`` holds the exception text.
    """
    # Local import: the file only imports `datetime` from the datetime module.
    from datetime import timezone

    if not GENAI_API_KEY:
        return {
            "status": "skipped",
            "reason": "missing_gemini_api_key",
        }

    try:
        model = genai.GenerativeModel("gemini-2.5-flash")

        # Content may be missing or non-string; normalise before joining.
        history_text = (
            "\n".join(str(m.get("content") or "") for m in history_message)
            if history_message
            else ""
        )

        prompt = (
            "You are an evaluator for a RAG chatbot."
            " Return JSON with keys: rag_relevance (0-1), answer_accuracy (0-1), hallucination (bool), notes (string)."
            " Evaluate strictly from question and answer (and chat history if provided)."
            " rag_relevance measures how well retrieved context seems relevant to the question."
            " answer_accuracy measures factual correctness and completeness."
            " hallucination is true if the answer includes unrelated, fabricated, or off-topic info."
            f"\nQuestion: {question}\nAnswer: {answer}\nHistory: {history_text}\nReturn JSON only."
        )

        resp = model.generate_content(prompt)
        parsed = _safe_json_parse(resp.text)
        # Only a non-empty JSON object carries the keys read below.
        if not parsed or not isinstance(parsed, dict):
            raise ValueError("LLM did not return valid JSON")

        return {
            "status": "ok",
            # Same textual format as before, without the deprecated utcnow().
            "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "rag_relevance": _as_unit_score(parsed.get("rag_relevance", 0)),
            "answer_accuracy": _as_unit_score(parsed.get("answer_accuracy", 0)),
            # bool("false") would be True, so strings are interpreted explicitly.
            "hallucination": _as_flag(parsed.get("hallucination", False)),
            "notes": parsed.get("notes", "") or "",
        }
    except Exception as e:
        # Evaluation is best-effort: report the failure instead of raising.
        return {
            "status": "error",
            "error": str(e),
        }
def ask_api():
    """Answer the latest chat message via ``rewriter.ask_with_context``.

    Request body (JSON)::

        {
            "messages": [
                {"content": "earlier message"},
                {"content": "the question to answer"}
            ]
        }

    The last entry of ``messages`` supplies the question; up to five entries
    before it are passed along as history.

    Response (200)::

        {
            "id": "<uuid4>",
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "<answer>"},
                    "finish_reason": "stop"
                }
            ],
            "evaluation": "evaluation"
        }

    Errors:
        400 - the body is not a JSON object, ``messages`` is missing or
        empty, or the last message has no content.
        500 - any other exception; its text is returned in ``"error"``.
    """
    try:
        # silent=True turns an unparsable body into None rather than an
        # exception, so it can be reported as a client error below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                "error": "Request body must be a JSON object"
            }), 400

        all_messages = data.get("messages")
        if not isinstance(all_messages, list) or not all_messages:
            return jsonify({
                "error": "'messages' must be a non-empty list"
            }), 400

        # Up to five messages preceding the latest one serve as history.
        history_message = all_messages[-6:-1]

        last_message = all_messages[-1]
        question = last_message.get("content") if isinstance(last_message, dict) else None
        if not question:
            return jsonify({
                "error": "'question' cannot be empty"
            }), 400

        answer = rewriter.ask_with_context(question, history_message)

        # Automatic LLM evaluation (evaluate_answer_llm feeding metrics_log)
        # is currently disabled, so "evaluation" below is a literal
        # placeholder string rather than a real evaluation result.
        return jsonify({
            "id": str(uuid.uuid4()),
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": answer
                    },
                    "finish_reason": "stop"
                }
            ],
            "evaluation": "evaluation"
        }), 200
    except Exception as e:
        return jsonify({
            "error": str(e),
            "status": "error"
        }), 500
def lstmodel():
    """Return the fixed list of model descriptors."""
    model_ids = ("Model-1", "Model-2")
    models = [
        {"id": model_id, "object": "model", "owned_by": "owner"}
        for model_id in model_ids
    ]
    return jsonify({"object": "list", "data": models}), 200
def health_check():
    """Report that the service is up."""
    payload = {
        "status": "healthy",
        "service": "Vietnam Heritage RAG API",
    }
    return jsonify(payload), 200
def home():
    """Return a JSON description of the API and an example request.

    The documented request body uses the ``messages`` list that the ask
    handler actually reads (it takes the question from the last message's
    ``content``), so the example can be sent as-is.
    """
    return jsonify({
        "message": "Vietnam Heritage AI REST API",
        "version": "1.0.0",
        "endpoints": {
            "POST /ask": {
                "description": "Ask a question about Vietnamese heritage",
                "body": {
                    "messages": [
                        {"role": "user", "content": "Your question here"}
                    ]
                }
            },
            "GET /health": "Health check endpoint",
            "GET /": "API documentation",
            "GET /lstmodel": "List available models"
        },
        "example": {
            "url": "/ask",
            "method": "POST",
            "body": {
                "messages": [
                    {"role": "user", "content": "Nguyễn Trãi là ai?"}
                ]
            }
        }
    }), 200
def get_metrics():
    """Return the evaluation log and summary averages for frontend charts."""

    def _mean(values):
        # An empty sample averages to 0 rather than dividing by zero.
        return sum(values) / len(values) if values else 0

    # Only entries whose evaluation finished with status "ok" are averaged.
    ok_evals = [
        entry["evaluation"]
        for entry in metrics_log
        if entry.get("evaluation", {}).get("status") == "ok"
    ]

    relevance = [ev.get("rag_relevance", 0) for ev in ok_evals]
    accuracy = [ev.get("answer_accuracy", 0) for ev in ok_evals]
    flags = [ev.get("hallucination", False) for ev in ok_evals]
    flagged = sum(1 for flag in flags if flag)

    summary = {
        "total": len(metrics_log),
        "avg_rag_relevance": _mean(relevance),
        "avg_answer_accuracy": _mean(accuracy),
        "hallucination_rate": (flagged / len(flags)) if flags else 0,
    }

    return jsonify({
        "summary": summary,
        "data": metrics_log,
    }), 200
def reset_history():
    """Rebind the module-level ``history`` to an empty list and confirm."""
    # NOTE(review): nothing else visible in this file reads `history`;
    # confirm whether another module relies on it.
    global history
    history = []

    confirmation = {
        "message": "History reset successfully",
        "status": "success",
    }
    return jsonify(confirmation), 200
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))

    # The server binds to all interfaces, and Flask's debug mode exposes an
    # interactive debugger that can execute arbitrary code. Debug is therefore
    # off unless explicitly requested through the FLASK_DEBUG variable.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')

    banner = "=" * 60
    print(banner)
    print("🚀 Vietnam Heritage RAG API")
    print(banner)
    print(f"📍 Server: http://localhost:{port}")
    print("📝 Endpoints:")
    print(f" POST http://localhost:{port}/ask")
    print(f" GET http://localhost:{port}/health")
    print(f" GET http://localhost:{port}/")
    print(banner)

    app.run(host='0.0.0.0', port=port, debug=debug)