# step5_api_hybrid.py
# Hybrid (semantic + BM25 + keyword) knowledge-search API built on FastAPI.
import json
import os
import pickle
import re
import unicodedata

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from underthesea import word_tokenize
# ===== Configuration =====
# Sentence-embedding model used to encode queries.
# NOTE(review): presumably the same model that produced the vector/FAISS
# artifacts below — confirm, otherwise semantic scores are meaningless.
MODEL_NAME = "keepitreal/vietnamese-sbert"
# Full knowledge-base records (looked up by "id" when formatting results).
KB_JSON = "dataset/knowledge_base.json"

# All precomputed search artifacts live in one directory.
_ARTIFACTS_DIR = "artifacts"
VECTORS_NPY, META_JSON, FAISS_INDEX_PATH, BM25_PKL, TOKENIZED_PKL = (
    os.path.join(_ARTIFACTS_DIR, artifact_name)
    for artifact_name in (
        "kb_vectors.npy",
        "kb_meta.json",
        "kb_faiss.index",
        "bm25_index.pkl",
        "tokenized_corpus.pkl",
    )
)
app = FastAPI(title="Hybrid Knowledge Search API", version="3.0")

# Fully open CORS: every origin, method and header is accepted.
# NOTE(review): the Starlette docs state that allow_origins/allow_methods/
# allow_headers should not be ["*"] when allow_credentials=True. If the
# front-end sends no cookies/auth, consider allow_credentials=False.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ===== Load Resources =====
def _read_json(path):
    """Return the parsed contents of the UTF-8 JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _read_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): unpickling can execute arbitrary code — only point this at
    artifacts produced by the project's own build step, never user uploads.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


print("🚀 Loading hybrid search resources...")
model = SentenceTransformer(MODEL_NAME)
# Stored embeddings; only their dimensionality is reported by the code shown here.
vectors = np.load(VECTORS_NPY).astype(np.float32)
# Per-row metadata; row i is assumed to line up with FAISS id i and BM25
# document i (TODO confirm against the artifact build step).
meta = _read_json(META_JSON)
kb = _read_json(KB_JSON)
# Full knowledge-base records keyed by their "id" field.
id_map = {record.get("id"): record for record in kb}
index = faiss.read_index(FAISS_INDEX_PATH)
bm25 = _read_pickle(BM25_PKL)
# Not referenced by the code shown here.
tokenized_corpus = _read_pickle(TOKENIZED_PKL)
print(f"✅ Loaded {len(meta)} items (hybrid: semantic + BM25 + keyword)")
class SearchRequest(BaseModel):
    """Request body for a hybrid knowledge search.

    The three weights scale the semantic, BM25 and keyword scores in the
    fused ranking score. Their defaults (0.4 / 0.4 / 0.2) sum to 1.0. No
    range or sum constraint is enforced by this model.
    """

    query: str  # free-text query; blank queries are rejected by the endpoint
    top_k: int = 5  # requested number of results (the endpoint clamps it to 1..20)
    semantic_weight: float = 0.4  # weight of the FAISS similarity score
    bm25_weight: float = 0.4  # weight of the max-normalized BM25 score
    keyword_weight: float = 0.2  # weight of the keyword-overlap score
# Any character that is neither a (Unicode) word character nor whitespace.
_NON_WORD_RE = re.compile(r"[^\w\s]")


def preprocess_query(query: str) -> str:
    """Normalize a raw search query for embedding, BM25 and keyword matching.

    The steps are:

    1. Lower-case the text.
    2. Apply Unicode NFC normalization.
    3. Replace every character that is neither a word character nor
       whitespace with a space.
    4. Collapse runs of whitespace into single spaces.

    NFC matters for Vietnamese. Text entered in decomposed form (base letter
    followed by combining tone/diacritic marks) would otherwise lose its
    combining marks, because they are not word characters, and "tiếng"
    would become "tie ng".

    Args:
        query: The user's raw query text.

    Returns:
        The cleaned, single-space-separated query (possibly empty).
    """
    text = unicodedata.normalize("NFC", query.lower())
    # In Python 3, ``\w`` on str patterns is Unicode-aware, so precomposed
    # Vietnamese letters (à, ệ, đ, ...) are already kept without listing them.
    text = _NON_WORD_RE.sub(" ", text)
    return " ".join(text.split())
def keyword_match_score(query: str, keywords: list) -> float:
    """Score how well *query* overlaps an item's keyword list.

    The score is the mean of two overlap ratios between the query's
    whitespace-separated words Q and the item's keywords K, both lower-cased:

    - Jaccard similarity ``|Q & K| / |Q | K|``.
    - Query coverage ``|Q & K| / |Q|``.

    The result lies in [0, 1].

    Keywords are compared as whole strings. A multi-word keyword therefore
    never equals a single query word.

    Args:
        query: Query text, split on whitespace (pass the preprocessed query).
        keywords: The item's keywords. ``None``, non-string entries and blank
            entries are ignored, so items with missing or dirty keyword data
            score 0.0 instead of raising. Surrounding whitespace is stripped.

    Returns:
        0.0 when either side is empty, otherwise the averaged overlap score.
    """
    query_words = set(query.lower().split())
    item_keywords = {
        kw.strip().lower()
        for kw in (keywords or ())
        if isinstance(kw, str) and kw.strip()
    }
    if not query_words or not item_keywords:
        return 0.0
    shared = query_words & item_keywords
    jaccard = len(shared) / len(query_words | item_keywords)
    coverage = len(shared) / len(query_words)
    return 0.5 * jaccard + 0.5 * coverage
def _normalize_by_max(scores) -> np.ndarray:
    """Divide *scores* by their maximum so the best score becomes 1.0.

    Scores are returned unscaled when the maximum is not positive (e.g. all
    zeros) or when there are no scores at all.
    """
    arr = np.asarray(scores, dtype=np.float64)
    if arr.size == 0:
        return arr
    peak = arr.max()
    return arr / peak if peak > 0 else arr


def _bm25_tokens(text: str) -> list:
    """Split *text* into BM25 query tokens.

    Uses underthesea word segmentation (text format, split on spaces). If
    segmentation fails, falls back to a plain whitespace split so the search
    still runs.
    """
    try:
        return word_tokenize(text, format="text").split()
    except Exception:  # best-effort: a segmenter failure must not break search
        return text.split()


def _format_result(rank: int, idx: int, final_score, semantic, bm25_score, keyword) -> dict:
    """Build the JSON-ready result record for metadata row *idx*."""
    item_id = meta[idx].get("id")
    full = id_map.get(item_id, {})
    raw_text = ""
    knowledge_type = ""
    item_meta = full.get("metadata")
    if isinstance(item_meta, dict):
        raw_text = item_meta.get("raw_text", "") or ""
        kt_val = item_meta.get("knowledge_type", "")
        if isinstance(kt_val, list):
            knowledge_type = ", ".join(map(str, kt_val))
        elif kt_val:
            knowledge_type = str(kt_val)
    final_score = float(final_score)
    return {
        "rank": rank,
        "score": final_score,
        # Percentage view of the fused score, clamped to [0, 100]. A negative
        # cosine or a negative weight can push the raw score below zero.
        "relevance": max(0, min(100, int(final_score * 100))),
        "score_breakdown": {
            "semantic": round(float(semantic), 3),
            "bm25": round(float(bm25_score), 3),
            "keyword": round(float(keyword), 3),
        },
        "id": item_id,
        "topic": full.get("topic", ""),
        "chapter": full.get("chapter", ""),
        "knowledge_type": knowledge_type,
        "raw_text": raw_text,
    }


def hybrid_search(query: str, top_k: int, weights=(0.4, 0.4, 0.2)):
    """Rank knowledge-base items for *query* by a weighted fusion of three signals.

    Each item in ``meta`` gets three scores:

    - semantic: the similarity score FAISS returns for the normalized query
      embedding. Only the ``top_k * 5`` nearest items get one; all other
      items get 0.
    - bm25: the BM25 score of the tokenized query, divided by this query's
      best BM25 score.
    - keyword: ``keyword_match_score`` between the preprocessed query and
      the item's ``keywords``.

    The final score is ``w_sem*semantic + w_bm25*bm25 + w_kw*keyword``, and
    items are returned in descending order of it. Ties keep the order
    "FAISS hits in FAISS order, then all other items by position".

    Args:
        query: Raw user query; it is preprocessed here.
        top_k: Maximum number of results to return.
        weights: ``(semantic, bm25, keyword)`` weights.

    Returns:
        Up to *top_k* result dicts with these keys: rank, score, relevance,
        score_breakdown, id, topic, chapter, knowledge_type, raw_text. The
        list is empty when the knowledge base is empty or ``top_k <= 0``.
    """
    n_items = len(meta)
    if n_items == 0 or top_k <= 0:
        return []
    processed_query = preprocess_query(query)
    w_sem, w_bm25, w_kw = weights

    # 1. Semantic: nearest neighbours of the query embedding.
    qv = model.encode(processed_query, normalize_embeddings=True, convert_to_numpy=True)
    qv = np.asarray(qv, dtype=np.float32).reshape(1, -1)
    retrieve_k = min(top_k * 5, n_items)
    hit_scores, hit_idxs = index.search(qv, retrieve_k)
    hits = {}  # metadata row -> semantic score, in FAISS rank order
    for idx, score in zip(hit_idxs[0], hit_scores[0]):
        # FAISS pads missing neighbours with -1. Ids without a metadata row
        # are skipped too, instead of crashing later on meta[idx].
        if 0 <= idx < n_items:
            hits[int(idx)] = float(score)
    semantic = np.zeros(n_items, dtype=np.float64)
    for idx, score in hits.items():
        semantic[idx] = score

    # 2. BM25 over the whole corpus, scaled so this query's best match is 1.0.
    bm25_norm = _normalize_by_max(bm25.get_scores(_bm25_tokens(processed_query)))[:n_items]

    # 3. Keyword overlap for every item.
    keyword = np.array(
        [keyword_match_score(processed_query, m.get("keywords", [])) for m in meta],
        dtype=np.float64,
    )

    # Fuse the three signals.
    final = w_sem * semantic + w_bm25 * bm25_norm + w_kw * keyword

    # Candidate order fixes tie-breaking: FAISS hits first (in FAISS order),
    # then every remaining item by position. A stable argsort on the negated
    # score matches Python's stable sort(..., reverse=True).
    hit_ids = np.array(list(hits), dtype=np.int64)
    is_hit = np.zeros(n_items, dtype=bool)
    is_hit[hit_ids] = True
    candidates = np.concatenate([hit_ids, np.flatnonzero(~is_hit)])
    ranked = candidates[np.argsort(-final[candidates], kind="stable")][:top_k]

    return [
        _format_result(rank, int(idx), final[idx], semantic[idx], bm25_norm[idx], keyword[idx])
        for rank, idx in enumerate(ranked, start=1)
    ]
def root():
    """Return a small payload identifying this API and confirming it is up.

    NOTE(review): no route decorator is visible on this function in this
    file. Confirm it is registered on ``app`` (e.g. ``@app.get(...)``).
    """
    payload = dict(message="Hybrid Knowledge Search API v3.0", status="running")
    return payload
def health():
    """Report service status together with statistics about the loaded resources.

    The payload contains the metadata row count, the embedding dimension of
    the stored vectors, the number of vectors in the FAISS index, the model
    name and the search type.

    NOTE(review): no route decorator is visible on this function in this
    file. Confirm it is registered on ``app``.
    """
    stats = {"status": "ok"}
    stats["items"] = len(meta)
    stats["vector_dim"] = int(vectors.shape[1])
    stats["faiss_total"] = int(index.ntotal)
    stats["model"] = MODEL_NAME
    stats["search_type"] = "hybrid (semantic + BM25 + keyword)"
    return stats
def search_api(req: SearchRequest):
    """Validate a search request, run the hybrid search and echo the parameters used.

    NOTE(review): no route decorator is visible on this function in this
    file. Confirm it is registered on ``app``.

    Raises:
        HTTPException: 400 when the query is missing or whitespace-only.
    """
    query_text = (req.query or "").strip()
    if not query_text:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # Keep the result count within [1, 20] regardless of what the client sent.
    k = min(max(req.top_k, 1), 20)
    sem_w, bm25_w, kw_w = req.semantic_weight, req.bm25_weight, req.keyword_weight
    hits = hybrid_search(query_text, k, (sem_w, bm25_w, kw_w))

    response = {
        "query": query_text,
        "processed_query": preprocess_query(query_text),
        "top_k": k,
        "weights": {"semantic": sem_w, "bm25": bm25_w, "keyword": kw_w},
        "results": hits,
    }
    return response