File size: 6,313 Bytes
c386114 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
api_server.py - FastAPI backend (RAG + vLLM)
Run: python api_server.py
Port: 8000
"""
import os, re, json
from typing import List, Dict, Any
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import httpx
import asyncio
# CONFIG
# Chat-completions endpoint that the /query handler POSTs the prompt to.
VLLM_URL = "http://localhost:8001/v1/chat/completions"
# Directory passed to chromadb.PersistentClient when the module loads.
CHROMA_DB_PATH = "./chroma_db"
# Model name passed to SentenceTransformer for embedding incoming queries.
EMBED_MODEL = "intfloat/multilingual-e5-large"
# Default for QueryRequest.top_k (number of cases retrieved per query).
TOP_K = 5
# LOAD
# These statements run at import time: importing this module loads the
# embedding model and opens the ChromaDB collection before any request is served.
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL)
print("Loading ChromaDB...")
# Telemetry is switched off through the Settings object passed to the client.
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH, settings=Settings(anonymized_telemetry=False))
# NOTE(review): get_collection presumably fails if "hdmt_cases" has not been
# created yet, which would abort startup - confirm against the ChromaDB docs.
collection = chroma_client.get_collection("hdmt_cases")
print("Loading cross-reference maps...")
# Best-effort load: the service still starts with empty maps when the file is
# missing or unusable, but only the expected failure modes are caught so that
# KeyboardInterrupt / SystemExit are no longer swallowed.
try:
    with open("cross_ref_maps.json", "r", encoding="utf-8") as f:
        cross_ref = json.load(f)
    fpga_map = cross_ref.get("fpga", {})
    pin_map = cross_ref.get("pin", {})
except (OSError, ValueError, AttributeError) as exc:
    # OSError: file missing or unreadable.
    # ValueError: invalid JSON or undecodable bytes.
    # AttributeError: top-level JSON value is not an object, so it has no .get().
    print(f"Warning: cross_ref_maps.json not loaded ({exc}); using empty maps.")
    fpga_map, pin_map = {}, {}
# FASTAPI
app = FastAPI(title="HDMT RAG API", version="2.0")
# NOTE(review): CORS is configured with wildcards for origins, methods and
# headers; restrict allow_origins if this API is reachable outside a trusted network.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Request body for POST /query.
class QueryRequest(BaseModel):
    # Free-text question; it is embedded for retrieval and quoted in the prompt.
    query: str
    # Earlier chat turns; the /query handler reads the "role" and "content"
    # keys of the last four entries.
    history: List[Dict[str, str]] = []
    # Number of similar cases to retrieve; defaults to the module-level TOP_K.
    top_k: int = TOP_K
# Response body returned by POST /query.
class QueryResponse(BaseModel):
    # Text produced by the LLM, or the raw-data fallback when the LLM call fails.
    answer: str
    # Retrieved cases, each with "id", "similarity", "metadata" and "document".
    sources: List[Dict[str, Any]]
    # Mean similarity of the sources scaled to a percentage.
    confidence: float
# Component designator: 1-3 capital letters, 1-4 digits and an optional
# trailing capital letter, delimited by word boundaries (e.g. "U12", "TP5A").
_COMPONENT_RE = re.compile(r'\b([A-Z]{1,3}\d{1,4}[A-Z]?)\b')


def extract_components(text: str) -> List[str]:
    """Return the distinct component designators found in *text*.

    The previous ``list(set(...))`` produced an order that could differ from
    one process to the next, which made the generated prompt unstable.
    Designators are now returned in order of first appearance.

    Args:
        text: Text to scan. Non-string values are converted with ``str()``.

    Returns:
        Unique designators, each appearing once, in first-occurrence order.
    """
    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(_COMPONENT_RE.findall(str(text))))
def get_xref(components: List[str]) -> str:
    """Build the cross-reference text for a set of component designators.

    Each designator is looked up in the module-level ``pin_map`` and
    ``fpga_map``; every hit contributes one "- <name>: ..." line.

    Args:
        components: Designators to look up. A plain string is also accepted
            and is scanned with ``extract_components``; ``build_prompt``
            forwards the ``comps_found`` metadata value unchanged, and
            iterating a string directly would test single characters.
            ``None`` or an empty value yields the "no information" text.

    Returns:
        The hit lines joined by newlines, or "Khong co thong tin." when
        nothing matched.
    """
    if not components:
        return "Khong co thong tin."
    if isinstance(components, str):
        components = extract_components(components)

    info = []
    for c in components:
        if c in pin_map:
            pin_entry = pin_map[c]
            # .get() keeps one incomplete map entry from failing the whole request.
            info.append(f"- {c}: {pin_entry.get('desc', 'N/A')} (pins {pin_entry.get('pins', 'N/A')})")
        if c in fpga_map:
            fpga_entry = fpga_map[c]
            info.append(f"- {c}: FPGA {fpga_entry.get('fpga', 'N/A')}, PMU {fpga_entry.get('pmu', 'N/A')}")
    return "\n".join(info) if info else "Khong co thong tin."
def build_prompt(query: str, cases: List[Dict]) -> str:
    """Assemble the LLM prompt from the user question and the retrieved cases.

    Args:
        query: The user's question; it is quoted in the prompt and scanned for
            component designators.
        cases: Retrieved cases; each must have a "metadata" mapping whose
            fields are rendered with "N/A" for anything missing.

    Returns:
        The full prompt text: instructions, cross-reference info for the
        question, and one section per case.
    """
    sections = []
    for idx, case in enumerate(cases, 1):
        meta = case["metadata"]
        tech_info = get_xref(meta.get("comps_found", []))
        sections.append(f"""CASE #{idx}:
Mo ta loi: {meta.get('failure_desc', 'N/A')}
Board: {meta.get('board_type', 'N/A')}
Ket qua: {meta.get('result', 'N/A')}
Hanh dong: {meta.get('action_taken', 'N/A')}
Linh kien: {meta.get('components', 'N/A')}
BKM: {meta.get('bkm_procedure', 'N/A')}
BKM Components: {meta.get('bkm_components', 'N/A')}
Priority Replace: {meta.get('priority_replace', 'N/A')}
Best Actions: {meta.get('best_actions', 'N/A')}
Thong tin ky thuat:
{tech_info}
""")
    # Every section already ends with a newline, so a plain join matches the
    # text produced by appending them one after another.
    case_text = "".join(sections)
    query_xref = get_xref(extract_components(query))
    return f"""Ban la chuyen gia debug HDMT. Duoi day la cac truong hop lich su tuong tu.
CAU HOI: "{query}"
THONG TIN KY THUAT TU SO DO:
{query_xref}
CAC CASES LICH SU ({len(cases)} cases):
{case_text}
YEU CAU: Hay phan tich va tra loi bang tieng Viet co dau:
1. PHAN TICH LOI: Giai thich loi gi, lien quan board/kenh nao.
2. QUY TRINH DEBUG (BKM): Viet tung buoc cu the tu BKM_Procedure.
3. LINH KIEN THEO BKM: Lie ke day du tu BKM_Focus_Components.
4. THONG KE THUC TE - LINH KIEN THAY NHIEU NHAT: Tu Priority Replace.
5. HANH DONG HIEU QUA NHAT: Tu Best Actions Weighted.
6. KET LUAN: Nen lam gi truoc, gi sau. Neu BKM khong hieu qua thi fallback theo stats.
QUAN TRONG:
- PHAI tach biet ro: Linh kien BKM khac voi Linh kien tu stats
- PHAI dung so lieu cu the (pass rate %, so lan)
- PHAI viet tieng Viet CO DAU
- KHONG dung emoji, icon
- KHONG gop chung BKM va Stats
"""
# Status endpoint: reports fixed service labels plus the live number of
# entries in the ChromaDB collection.
@app.get("/")
async def root():
    info = {"status": "HDMT RAG API", "model": "Qwen2.5-72B-AWQ", "db": "ChromaDB"}
    info["cases"] = collection.count()
    return info
@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a question with retrieval-augmented generation.

    Steps: embed the question, fetch the most similar cases from the ChromaDB
    collection, build the prompt, and ask the vLLM server for an answer. If
    the LLM call fails for any reason, the answer falls back to a plain-text
    dump of the top three retrieved cases instead of an error response.

    Args:
        req: The question, optional chat history (only the last four turns
            that have both "role" and "content" are forwarded), and the
            number of cases to retrieve.

    Returns:
        QueryResponse with the answer text, the retrieved sources, and a
        confidence value (mean similarity as a percentage, limited to the
        range 0-99.9).
    """
    # 1. Embed query
    q_emb = embed_model.encode([f"query: {req.query}"], normalize_embeddings=True).tolist()[0]

    # 2. Search. A top_k below 1 is raised to 1 instead of being forwarded
    # to the vector store.
    results = collection.query(
        query_embeddings=[q_emb],
        n_results=max(1, req.top_k),
        include=["documents", "metadatas", "distances"],
    )
    sources = [
        {
            "id": case_id,
            "similarity": round(1.0 - distance, 3),
            # Missing metadata/document would otherwise crash prompt building
            # (m.get on None) or the slice below.
            "metadata": metadata or {},
            "document": (document or "")[:500],
        }
        for case_id, distance, metadata, document in zip(
            results["ids"][0],
            results["distances"][0],
            results["metadatas"][0],
            results["documents"][0],
        )
    ]

    # 3. Build prompt
    prompt = build_prompt(req.query, sources)

    # 4. Call vLLM. History entries lacking "role" or "content" are skipped
    # rather than failing the request with a KeyError.
    messages = [
        {"role": h["role"], "content": h["content"]}
        for h in req.history[-4:]
        if "role" in h and "content" in h
    ]
    messages.append({"role": "user", "content": prompt})
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            r = await client.post(VLLM_URL, json={
                "model": "Qwen/Qwen2.5-72B-Instruct-AWQ",
                "messages": messages, "max_tokens": 4000, "temperature": 0.3, "stream": False
            })
            r.raise_for_status()
            answer = r.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Deliberate best-effort boundary: any LLM failure (network, HTTP
        # status, unexpected payload) degrades to the raw retrieved data.
        parts = [f"Loi LLM: {str(e)}\n\nDu lieu tho:\n"]
        for i, s in enumerate(sources[:3], 1):
            m = s["metadata"]
            parts.append(f"\nCASE {i} (tuong dong: {s['similarity']}):\n")
            parts.append(f"- Loi: {m.get('failure_desc', 'N/A')}\n")
            parts.append(f"- Hanh dong: {m.get('action_taken', 'N/A')}\n")
            parts.append(f"- BKM: {m.get('bkm_procedure', 'N/A')}\n")
        answer = "".join(parts)

    conf = sum(s["similarity"] for s in sources) / len(sources) * 100 if sources else 0
    # 1 - distance goes negative whenever a distance exceeds 1, so the lower
    # bound is clamped as well as the upper one.
    return QueryResponse(answer=answer, sources=sources, confidence=round(max(0.0, min(conf, 99.9)), 1))
@app.post("/feedback")
async def feedback(case_id: str, component: str, result: str):
    """Append one feedback record to feedback_log.jsonl and echo it back.

    Args:
        case_id: Identifier of the case the feedback refers to.
        component: Component the feedback is about.
        result: Outcome reported by the user.

    Returns:
        A dict with "status" set to "recorded" and the stored record under
        "data".
    """
    # Function-scope import keeps this block self-contained.
    import time

    # The event loop clock used previously is monotonic, so its values are not
    # wall-clock timestamps. time.time() gives seconds since the Unix epoch.
    fb = {"case_id": case_id, "component": component, "result": result, "ts": time.time()}
    with open("feedback_log.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(fb) + "\n")
    return {"status": "recorded", "data": fb}
@app.get("/stats")
async def stats():
    """Return the collection size plus the precomputed stats summary, if any.

    Returns:
        A dict with "total_cases" and, when stats_summary.json can be read
        and parsed, its content under "stats".
    """
    total = collection.count()
    try:
        with open("stats_summary.json", "r", encoding="utf-8") as f:
            summary = json.load(f)
    except (OSError, ValueError):
        # Missing, unreadable or malformed summary file: report the count
        # only. Other exceptions are no longer silently swallowed.
        return {"total_cases": total}
    return {"total_cases": total, "stats": summary}
# Script entry point: start the server only when this file is run directly.
if __name__ == "__main__":
    # uvicorn is imported here, so it is only needed when running as a script.
    import uvicorn

    print("HDMT RAG API - http://localhost:8000")
    # Binds to all interfaces on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|