| """ |
| api_server.py - FastAPI backend (RAG + vLLM) |
| Chạy: python api_server.py |
| Port: 8000 |
| """ |
import asyncio
import json
import os
import re
import time
from typing import Any, Dict, List

import chromadb
import httpx
from chromadb.config import Settings
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
|
|
| |
# Runtime configuration. Every value can be overridden through an environment
# variable; the defaults are the original hard-coded settings.

# Chat-completions endpoint of the vLLM server.
VLLM_URL = os.environ.get("HDMT_VLLM_URL", "http://localhost:8001/v1/chat/completions")
# Directory of the persistent ChromaDB store.
CHROMA_DB_PATH = os.environ.get("HDMT_CHROMA_DB_PATH", "./chroma_db")
# SentenceTransformer model used to embed queries.
EMBED_MODEL = os.environ.get("HDMT_EMBED_MODEL", "intfloat/multilingual-e5-large")
# Default number of cases retrieved per query; clamped to at least 1 because
# the vector store rejects non-positive result counts.
TOP_K = max(1, int(os.environ.get("HDMT_TOP_K", "5")))
|
|
| |
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL)
print("Loading ChromaDB...")
chroma_client = chromadb.PersistentClient(path=CHROMA_DB_PATH, settings=Settings(anonymized_telemetry=False))
collection = chroma_client.get_collection("hdmt_cases")
print("Loading cross-reference maps...")
# The cross-reference maps are optional: when they cannot be loaded, both maps
# fall back to empty dicts. Only the expected failure modes are caught
# (missing/unreadable file, bad encoding or invalid JSON -> ValueError, and a
# top-level JSON value without .get -> AttributeError), and a warning is
# printed instead of failing silently.
try:
    with open("cross_ref_maps.json", "r", encoding="utf-8") as f:
        cross_ref = json.load(f)
    fpga_map = cross_ref.get("fpga", {})
    pin_map = cross_ref.get("pin", {})
except (OSError, ValueError, AttributeError) as exc:
    print(f"Warning: could not load cross_ref_maps.json ({exc}); continuing without cross-references.")
    fpga_map, pin_map = {}, {}
|
|
| |
# FastAPI application hosting the endpoints below.
app = FastAPI(title="HDMT RAG API", version="2.0")

# NOTE(review): CORS is fully open (any origin, method and header). Consider
# restricting allow_origins if this API is reachable outside a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    # Free-text question; it is embedded for retrieval and quoted in the prompt.
    query: str
    # Prior chat turns as {"role": ..., "content": ...} dicts (fresh list per request).
    history: List[Dict[str, str]] = Field(default_factory=list)
    # Number of similar cases to retrieve. Must be >= 1: a non-positive value
    # would be rejected by the vector store and surface as a server error.
    top_k: int = Field(default=TOP_K, ge=1)
|
|
class QueryResponse(BaseModel):
    """Response body returned by ``POST /query``."""

    # Generated answer, or a fallback text with raw case data if the LLM failed.
    answer: str
    # Retrieved cases, each with id, similarity, metadata and a document excerpt.
    sources: List[Dict[str, Any]]
    # Mean retrieval similarity expressed as a percentage.
    confidence: float
|
|
# Tokens that look like component designators, e.g. "R12", "C3", "U105A":
# 1-3 uppercase letters, 1-4 digits, an optional uppercase suffix letter.
_COMPONENT_RE = re.compile(r'\b([A-Z]{1,3}\d{1,4}[A-Z]?)\b')


def extract_components(text: str) -> List[str]:
    """Return the unique component-like tokens found in *text*.

    Duplicates are removed while keeping first-occurrence order, so the result
    (and any prompt built from it) is deterministic across runs; a plain
    ``set`` would yield a hash-randomized order. Non-string input is converted
    with ``str()`` first.
    """
    return list(dict.fromkeys(_COMPONENT_RE.findall(str(text))))
|
|
def get_xref(
    components: List[str],
    pin_lookup: "Dict[str, Any] | None" = None,
    fpga_lookup: "Dict[str, Any] | None" = None,
) -> str:
    """Build the technical cross-reference text for *components*.

    For each component, a pin line (description and pins) is emitted when it
    appears in the pin map, and an FPGA line (FPGA and PMU) when it appears in
    the FPGA map. Map entries missing a field, or that are not dicts, show
    "N/A" instead of failing the whole request.

    Args:
        components: Component designators to look up.
        pin_lookup: Pin map to use; defaults to the module-level ``pin_map``.
        fpga_lookup: FPGA map to use; defaults to the module-level ``fpga_map``.

    Returns:
        Newline-joined "- <component>: ..." lines, or "Khong co thong tin."
        when nothing matched.
    """
    pins = pin_map if pin_lookup is None else pin_lookup
    fpgas = fpga_map if fpga_lookup is None else fpga_lookup
    info = []
    for comp in components:
        if comp in pins:
            entry = pins[comp] if isinstance(pins[comp], dict) else {}
            info.append(f"- {comp}: {entry.get('desc', 'N/A')} (pins {entry.get('pins', 'N/A')})")
        if comp in fpgas:
            entry = fpgas[comp] if isinstance(fpgas[comp], dict) else {}
            info.append(f"- {comp}: FPGA {entry.get('fpga', 'N/A')}, PMU {entry.get('pmu', 'N/A')}")
    return "\n".join(info) if info else "Khong co thong tin."
|
|
def _format_case(index: int, meta: Dict[str, Any]) -> str:
    """Render one retrieved case's metadata as a labelled prompt section."""
    comps = meta.get("comps_found", [])
    # ChromaDB metadata values are normally scalars, so the component list is
    # likely stored as one string; iterating it directly would look up single
    # characters. Re-extract the designators from the text instead.
    if isinstance(comps, str):
        comps = extract_components(comps)
    elif not isinstance(comps, (list, tuple)):
        comps = []
    xref = get_xref(comps)
    return f"""CASE #{index}:
Mo ta loi: {meta.get('failure_desc', 'N/A')}
Board: {meta.get('board_type', 'N/A')}
Ket qua: {meta.get('result', 'N/A')}
Hanh dong: {meta.get('action_taken', 'N/A')}
Linh kien: {meta.get('components', 'N/A')}
BKM: {meta.get('bkm_procedure', 'N/A')}
BKM Components: {meta.get('bkm_components', 'N/A')}
Priority Replace: {meta.get('priority_replace', 'N/A')}
Best Actions: {meta.get('best_actions', 'N/A')}
Thong tin ky thuat:
{xref}
"""


def build_prompt(query: str, cases: List[Dict]) -> str:
    """Build the LLM instruction prompt for *query* from retrieved *cases*.

    Args:
        query: The user's question; quoted verbatim in the prompt and scanned
            for component designators to add cross-reference info.
        cases: Retrieved cases, each a dict whose "metadata" entry holds the
            case fields (a missing or None metadata is treated as empty).

    Returns:
        The full prompt text (Vietnamese, without diacritics) listing the
        cases and the required answer structure.
    """
    # Build all case sections and join once, instead of growing a string in a loop.
    case_text = "".join(
        _format_case(i, c.get("metadata") or {}) for i, c in enumerate(cases, 1)
    )
    query_xref = get_xref(extract_components(query))
    return f"""Ban la chuyen gia debug HDMT. Duoi day la cac truong hop lich su tuong tu.

CAU HOI: "{query}"

THONG TIN KY THUAT TU SO DO:
{query_xref}

CAC CASES LICH SU ({len(cases)} cases):
{case_text}

YEU CAU: Hay phan tich va tra loi bang tieng Viet co dau:
1. PHAN TICH LOI: Giai thich loi gi, lien quan board/kenh nao.
2. QUY TRINH DEBUG (BKM): Viet tung buoc cu the tu BKM_Procedure.
3. LINH KIEN THEO BKM: Liet ke day du tu BKM_Focus_Components.
4. THONG KE THUC TE - LINH KIEN THAY NHIEU NHAT: Tu Priority Replace.
5. HANH DONG HIEU QUA NHAT: Tu Best Actions Weighted.
6. KET LUAN: Nen lam gi truoc, gi sau. Neu BKM khong hieu qua thi fallback theo stats.

QUAN TRONG:
- PHAI tach biet ro: Linh kien BKM khac voi Linh kien tu stats
- PHAI dung so lieu cu the (pass rate %, so lan)
- PHAI viet tieng Viet CO DAU
- KHONG dung emoji, icon
- KHONG gop chung BKM va Stats
"""
|
|
@app.get("/")
async def root():
    """Report service identity, model label, backing store and stored case count."""
    case_count = collection.count()
    return dict(
        status="HDMT RAG API",
        model="Qwen2.5-72B-AWQ",
        db="ChromaDB",
        cases=case_count,
    )
|
|
# Client-supplied history may only replay conversational turns; the server
# builds the instruction prompt itself.
_HISTORY_ROLES = ("user", "assistant")
# Number of trailing history entries considered per request.
_HISTORY_TURNS = 4


def _retrieve_cases(query_text: str, top_k: int) -> List[Dict[str, Any]]:
    """Embed *query_text* and return the *top_k* nearest cases from Chroma.

    Each returned dict has "id", "similarity" (1 - distance, rounded to 3
    places), "metadata" (empty dict when Chroma returns None) and "document"
    (first 500 characters; empty when Chroma returns None).
    """
    # e5 models expect the "query: " prefix on search queries.
    q_emb = embed_model.encode([f"query: {query_text}"], normalize_embeddings=True).tolist()[0]
    results = collection.query(
        query_embeddings=[q_emb],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )
    sources = []
    for case_id, distance, meta, doc in zip(
        results["ids"][0],
        results["distances"][0],
        results["metadatas"][0],
        results["documents"][0],
    ):
        sources.append({
            "id": case_id,
            "similarity": round(1.0 - distance, 3),
            "metadata": meta or {},
            "document": (doc or "")[:500],
        })
    return sources


def _build_messages(history: List[Dict[str, str]], prompt: str) -> List[Dict[str, str]]:
    """Return chat messages: recent well-formed history turns, then *prompt*."""
    messages = []
    for turn in history[-_HISTORY_TURNS:]:
        role, content = turn.get("role"), turn.get("content")
        # Skip malformed or unexpected entries instead of failing with KeyError
        # or forwarding a role the LLM server may reject.
        if role in _HISTORY_ROLES and isinstance(content, str):
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": prompt})
    return messages


def _fallback_answer(error: Exception, sources: List[Dict[str, Any]]) -> str:
    """Describe the LLM failure and dump raw data for the top three cases."""
    # Include the exception type: some errors (e.g. timeouts) have empty text.
    parts = [f"Loi LLM: {type(error).__name__}: {error}\n\nDu lieu tho:\n"]
    for i, s in enumerate(sources[:3], 1):
        m = s["metadata"]
        parts.append(f"\nCASE {i} (tuong dong: {s['similarity']}):\n")
        parts.append(f"- Loi: {m.get('failure_desc', 'N/A')}\n")
        parts.append(f"- Hanh dong: {m.get('action_taken', 'N/A')}\n")
        parts.append(f"- BKM: {m.get('bkm_procedure', 'N/A')}\n")
    return "".join(parts)


@app.post("/query", response_model=QueryResponse)
async def query(req: QueryRequest):
    """Answer a debugging question with retrieval-augmented generation.

    Retrieves ``req.top_k`` similar cases, builds the prompt, and asks the
    vLLM server for an answer. If the LLM call fails for any reason, the
    answer is replaced by the error plus raw data from the top three cases, so
    the endpoint still responds.

    Returns:
        QueryResponse with the answer, the retrieved sources, and a confidence
        equal to the mean similarity as a percentage, clamped to [0, 99.9].
    """
    # NOTE(review): embedding and the Chroma lookup are synchronous and run on
    # the event loop, blocking other requests while they execute.
    sources = _retrieve_cases(req.query, req.top_k)
    prompt = build_prompt(req.query, sources)
    messages = _build_messages(req.history, prompt)
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            r = await client.post(VLLM_URL, json={
                "model": "Qwen/Qwen2.5-72B-Instruct-AWQ",
                "messages": messages, "max_tokens": 4000, "temperature": 0.3, "stream": False
            })
            r.raise_for_status()
            # Coerce a null content so QueryResponse.answer (str) still validates.
            answer = r.json()["choices"][0]["message"]["content"] or ""
    except Exception as e:  # top-level boundary: degrade to raw case data
        answer = _fallback_answer(e, sources)
    conf = (sum(s["similarity"] for s in sources) / len(sources) * 100) if sources else 0.0
    # 1 - distance can be negative depending on the collection's distance
    # metric, so clamp the reported confidence into [0, 99.9].
    conf = max(0.0, min(conf, 99.9))
    return QueryResponse(answer=answer, sources=sources, confidence=round(conf, 1))
|
|
@app.post("/feedback")
async def feedback(case_id: str, component: str, result: str):
    """Append one feedback record to ``feedback_log.jsonl`` and echo it back.

    The scalar parameters are read by FastAPI from the query string. The
    record's "ts" field is a Unix timestamp in seconds.
    """
    # Wall-clock time: the event loop's clock is monotonic with an arbitrary
    # origin, so its values cannot be compared across process restarts.
    fb = {"case_id": case_id, "component": component, "result": result, "ts": time.time()}
    with open("feedback_log.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(fb) + "\n")
    return {"status": "recorded", "data": fb}
|
|
@app.get("/stats")
async def stats():
    """Return the stored case count plus ``stats_summary.json`` when readable."""
    total = collection.count()
    try:
        with open("stats_summary.json", "r", encoding="utf-8") as f:
            summary = json.load(f)
    except (OSError, ValueError):
        # Missing, unreadable or malformed summary: report the count alone.
        return {"total_cases": total}
    return {"total_cases": total, "stats": summary}
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port may be overridden; the defaults match the
    # original hard-coded values (all interfaces, port 8000).
    host = os.environ.get("HDMT_API_HOST", "0.0.0.0")
    port = int(os.environ.get("HDMT_API_PORT", "8000"))
    print(f"HDMT RAG API - http://localhost:{port}")
    uvicorn.run(app, host=host, port=port)
|
|