CultureBot / app.py
T-Phong
init project
dded31e
"""
REST API cho Vietnam Heritage RAG System
"""
import uuid
import json
import os
import sys
from datetime import datetime
import google.generativeai as genai
from flask import Flask, request, jsonify
from flask_cors import CORS
# Add service directory to sys.path to allow imports
sys.path.append(os.path.join(os.path.dirname(__file__), 'service'))
from rewrite import QueryRewriter
app = Flask(__name__)
CORS(app)
# Khởi tạo QueryRewriter (chứa ask_with_context)
rewriter = QueryRewriter()
islog = os.getenv('islog')
metrics_log = [] # Lưu lại các lần đánh giá để dựng biểu đồ
GENAI_API_KEY = os.getenv("GEMINI_API_KEY")
if GENAI_API_KEY:
genai.configure(api_key=GENAI_API_KEY)
def _safe_json_parse(text):
"""Parse chuỗi JSON, cố gắng trích block {} đầu tiên nếu có thêm text."""
try:
return json.loads(text)
except Exception:
pass
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(text[start : end + 1])
except Exception:
return None
return None
def evaluate_answer_llm(question: str, answer: str, history_message):
"""Gọi LLM để chấm điểm mức liên quan, độ chính xác và mức độ lan man."""
if not GENAI_API_KEY:
return {
"status": "skipped",
"reason": "missing_gemini_api_key",
}
try:
model = genai.GenerativeModel("gemini-2.5-flash")
history_text = "\n".join([m.get("content", "") for m in history_message]) if history_message else ""
prompt = (
"You are an evaluator for a RAG chatbot."
" Return JSON with keys: rag_relevance (0-1), answer_accuracy (0-1), hallucination (bool), notes (string)."
" Evaluate strictly from question and answer (and chat history if provided)."
" rag_relevance measures how well retrieved context seems relevant to the question."
" answer_accuracy measures factual correctness and completeness."
" hallucination is true if the answer includes unrelated, fabricated, or off-topic info."
f"\nQuestion: {question}\nAnswer: {answer}\nHistory: {history_text}\nReturn JSON only."
)
resp = model.generate_content(prompt)
parsed = _safe_json_parse(resp.text)
if not parsed:
raise ValueError("LLM did not return valid JSON")
rag_rel = float(parsed.get("rag_relevance", 0))
acc = float(parsed.get("answer_accuracy", 0))
halluc = bool(parsed.get("hallucination", False))
return {
"status": "ok",
"timestamp": datetime.utcnow().isoformat() + "Z",
"rag_relevance": max(0.0, min(1.0, rag_rel)),
"answer_accuracy": max(0.0, min(1.0, acc)),
"hallucination": halluc,
"notes": parsed.get("notes", "") or "",
}
except Exception as e:
return {
"status": "error",
"error": str(e),
}
@app.route('/v1/chat/completions', methods=['POST'])
def ask_api():
"""
Main endpoint - Gọi ask_with_context
Request body:
{
"question": "Câu hỏi của bạn"
}
Response:
{
"question": "Câu hỏi",
"answer": "Câu trả lời từ RAG"
}
"""
try:
data = request.get_json()
all_messages = data.get("messages", [])
history_message = all_messages[-6:-1]
# if islog == "1":
# for f in history_message:
# print(f)
question = all_messages[-1]["content"]
if not question:
return jsonify({
"error": "'question' cannot be empty"
}), 400
# Gọi ask_with_context
answer = rewriter.ask_with_context(question, history_message)
# Đánh giá tự động bằng LLM
# evaluation = evaluate_answer_llm(question, answer, history_message)
# if evaluation:
# metrics_log.append({
# "question": question,
# "answer": answer,
# "evaluation": evaluation,
# })
# # Giữ kích thước log vừa phải để hiển thị biểu đồ
# if len(metrics_log) > 200:
# del metrics_log[:-200]
return jsonify({
"id": str(uuid.uuid4()),
"object": "chat.completion",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": answer
},
"finish_reason": "stop"
}
],
"evaluation": "evaluation"
}), 200
except Exception as e:
return jsonify({
"error": str(e),
"status": "error"
}), 500
@app.route('/v1/models', methods=['GET'])
def lstmodel():
return jsonify({
"object": "list",
"data": [
{"id": "Model-1", "object": "model", "owned_by": "owner"},
{"id": "Model-2", "object": "model", "owned_by": "owner"}
]
}), 200
@app.route('/health', methods=['GET'])
def health_check():
"""Health check endpoint"""
return jsonify({
"status": "healthy",
"service": "Vietnam Heritage RAG API"
}), 200
@app.route('/', methods=['GET'])
def home():
"""API documentation"""
return jsonify({
"message": "Vietnam Heritage AI REST API",
"version": "1.0.0",
"endpoints": {
"POST /ask": {
"description": "Ask a question about Vietnamese heritage",
"body": {
"question": "Your question here"
}
},
"GET /health": "Health check endpoint",
"GET /": "API documentation",
"GET /lstmodel": "List available models"
},
"example": {
"url": "/ask",
"method": "POST",
"body": {
"question": "Nguyễn Trãi là ai?"
}
}
}), 200
@app.route('/metrics', methods=['GET'])
def get_metrics():
"""Trả về log đánh giá để dựng biểu đồ ở frontend."""
# Tính trung bình nhanh để tiện hiển thị
rag_scores = [m["evaluation"].get("rag_relevance", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
acc_scores = [m["evaluation"].get("answer_accuracy", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
halluc_counts = [m["evaluation"].get("hallucination", False) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
summary = {
"total": len(metrics_log),
"avg_rag_relevance": sum(rag_scores) / len(rag_scores) if rag_scores else 0,
"avg_answer_accuracy": sum(acc_scores) / len(acc_scores) if acc_scores else 0,
"hallucination_rate": (sum(1 for h in halluc_counts if h) / len(halluc_counts)) if halluc_counts else 0,
}
return jsonify({
"summary": summary,
"data": metrics_log,
}), 200
@app.route('/reset', methods=['POST'])
def reset_history():
"""Reset conversation history"""
global history
history = []
return jsonify({
"message": "History reset successfully",
"status": "success"
}), 200
if __name__ == '__main__':
port = int(os.environ.get('PORT', 5000))
print("=" * 60)
print(f"🚀 Vietnam Heritage RAG API")
print("=" * 60)
print(f"📍 Server: http://localhost:{port}")
print(f"📝 Endpoints:")
print(f" POST http://localhost:{port}/ask")
print(f" GET http://localhost:{port}/health")
print(f" GET http://localhost:{port}/")
print("=" * 60)
app.run(host='0.0.0.0', port=port, debug=True)