g8-cs106 / step5_api.py
gracephamit's picture
Upload 29 files
0dd9600 verified
# step5_api_hybrid.py
import json
import os
import pickle
import re
import unicodedata

import faiss
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from underthesea import word_tokenize
# ===== Configuration =====
# Embedding model used to encode queries (presumably the same model that
# produced kb_vectors.npy and the FAISS index — confirm).
MODEL_NAME = "keepitreal/vietnamese-sbert"
# Full knowledge-base records; search hits are joined back to these by id.
KB_JSON = "dataset/knowledge_base.json"
# All pre-built retrieval artifacts live under a single directory.
_ARTIFACTS_DIR = "artifacts"
VECTORS_NPY = os.path.join(_ARTIFACTS_DIR, "kb_vectors.npy")
META_JSON = os.path.join(_ARTIFACTS_DIR, "kb_meta.json")
FAISS_INDEX_PATH = os.path.join(_ARTIFACTS_DIR, "kb_faiss.index")
BM25_PKL = os.path.join(_ARTIFACTS_DIR, "bm25_index.pkl")
TOKENIZED_PKL = os.path.join(_ARTIFACTS_DIR, "tokenized_corpus.pkl")

app = FastAPI(title="Hybrid Knowledge Search API", version="3.0")

# CORS is left fully open so any browser front-end can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True appears
# to make Starlette echo the caller's Origin on credentialed requests, i.e.
# any site could send cookies/auth headers. If the API needs no credentials,
# consider allow_credentials=False — confirm against the front-end.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ===== Load Resources =====
# Everything below runs once at import time, so model and index loading
# happens at process startup rather than on the first request.
print("🚀 Loading hybrid search resources...")
model = SentenceTransformer(MODEL_NAME)
vectors = np.load(VECTORS_NPY).astype(np.float32)
with open(META_JSON, "r", encoding="utf-8") as f:
    meta = json.load(f)
with open(KB_JSON, "r", encoding="utf-8") as f:
    kb = json.load(f)
# Full knowledge-base records keyed by their "id" field.
id_map = {item.get("id"): item for item in kb}
index = faiss.read_index(FAISS_INDEX_PATH)
# SECURITY: pickle.load can execute arbitrary code while deserializing. These
# files must come from the project's own trusted indexing step, never from
# user-supplied uploads.
with open(BM25_PKL, "rb") as f:
    bm25 = pickle.load(f)
with open(TOKENIZED_PKL, "rb") as f:
    tokenized_corpus = pickle.load(f)

# Fail fast on misaligned artifacts: search maps FAISS row ids directly onto
# positions in `meta`, so a size mismatch would otherwise surface later as
# wrong hits or an IndexError inside a request.
if int(index.ntotal) != len(meta):
    raise RuntimeError(
        f"FAISS index holds {int(index.ntotal)} vectors but {META_JSON} lists "
        f"{len(meta)} items; rebuild the artifacts so they stay aligned."
    )
print(f"✅ Loaded {len(meta)} items (hybrid: semantic + BM25 + keyword)")
class SearchRequest(BaseModel):
    """JSON request body for ``POST /search``.

    Defaults apply when the client omits a field. The three weights are
    forwarded to ``hybrid_search`` as the tuple ``(semantic, bm25, keyword)``.
    """

    query: str  # free-text query; search_api strips it and rejects empty input
    top_k: int = 5  # results requested; search_api clamps it to the range [1, 20]
    semantic_weight: float = 0.4  # multiplier for the FAISS similarity score
    bm25_weight: float = 0.4  # multiplier for the max-normalized BM25 score
    keyword_weight: float = 0.2  # multiplier for the keyword-overlap score
# Compiled once at import time. Python's word-character class is already
# Unicode-aware, so the explicit Vietnamese letters are redundant but harmless.
_NON_WORD_RE = re.compile(r'[^\w\sàáạảãâầấậẩẫăằắặẳẵèéẹẻẽêềếệểễìíịỉĩòóọỏõôồốộổỗơờớợởỡùúụủũưừứựửữỳýỵỷỹđ]')


def preprocess_query(query: str) -> str:
    """Normalize a user query for matching.

    Steps:
      1. Unicode NFC normalization. Decomposed (NFD) Vietnamese input stores
         tone marks as separate combining characters. Those are not word
         characters, so without this step they were replaced by spaces and
         syllables were split apart ("tiếng" -> "tie ng").
      2. Lower-casing.
      3. Replacing every character that is not a word character, whitespace
         or a listed Vietnamese letter with a space.
      4. Collapsing runs of whitespace to single spaces and trimming the ends.

    Args:
        query: Raw user text.

    Returns:
        The cleaned query; an empty string if nothing word-like remains.
    """
    normalized = unicodedata.normalize("NFC", query).lower()
    cleaned = _NON_WORD_RE.sub(' ', normalized)
    return ' '.join(cleaned.split())
def keyword_match_score(query: str, keywords: list) -> float:
    """Score how well a query's words overlap an item's keyword list.

    Both sides are lower-cased. The query is split on whitespace; each keyword
    is compared as a whole string. The score is the mean of two ratios:
    Jaccard similarity (shared / union) and query coverage
    (shared / query words).

    Returns:
        A float in [0, 1]; 0.0 when either side has no terms.
    """
    q_terms = {word for word in query.lower().split()}
    kw_terms = {kw.lower() for kw in keywords}
    if not (q_terms and kw_terms):
        return 0.0
    # Both sets are non-empty here, so neither denominator can be zero.
    shared = len(q_terms.intersection(kw_terms))
    jaccard = shared / len(q_terms.union(kw_terms))
    coverage = shared / len(q_terms)
    return 0.5 * jaccard + 0.5 * coverage
def _max_normalize(scores) -> np.ndarray:
    """Divide *scores* by their maximum so the best score becomes 1.0.

    When the maximum is not positive (empty input, all zeros or all
    negatives), the values are returned unchanged rather than divided by a
    non-positive number.
    """
    arr = np.asarray(scores, dtype=np.float64)
    if arr.size == 0:
        return arr
    peak = arr.max()
    return arr / peak if peak > 0 else arr


def _tokenize_for_bm25(processed_query: str) -> list:
    """Tokenize a preprocessed query for the BM25 scorer.

    Uses underthesea word segmentation. Presumably this is the same scheme
    used to build the BM25 index — TODO confirm against the indexing step.
    Falls back to plain whitespace splitting if segmentation raises.
    """
    try:
        return word_tokenize(processed_query, format="text").split()
    except Exception:
        # Best-effort fallback kept from the original, but narrowed from a bare
        # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        return processed_query.split()


def _combine_scores(semantic_idxs, semantic_scores, bm25_norm, keyword_scores, weights):
    """Blend the per-item signals into a best-first ranking.

    Only FAISS candidates carry a semantic score; every other item gets 0.0
    for that signal. BM25 and keyword scores are used for all items.

    Returns:
        A list of ``(item_index, final_score, breakdown)`` tuples sorted by
        ``final_score`` descending. ``breakdown`` holds the raw
        ``semantic``/``bm25``/``keyword`` components.
    """
    n_items = len(keyword_scores)
    all_scores = {}
    for raw_idx, score in zip(semantic_idxs, semantic_scores):
        item = int(raw_idx)
        # FAISS marks unfilled result slots with a negative id. Ids beyond the
        # metadata list would otherwise crash on the later ``meta[item]`` lookup.
        if 0 <= item < n_items:
            all_scores[item] = {"semantic": float(score), "bm25": 0.0, "keyword": 0.0}
    for item in range(n_items):
        breakdown = all_scores.setdefault(item, {"semantic": 0.0, "bm25": 0.0, "keyword": 0.0})
        breakdown["bm25"] = float(bm25_norm[item])
        breakdown["keyword"] = float(keyword_scores[item])

    w_sem, w_bm25, w_kw = weights
    ranked = [
        (item, w_sem * sd["semantic"] + w_bm25 * sd["bm25"] + w_kw * sd["keyword"], sd)
        for item, sd in all_scores.items()
    ]
    # list.sort is stable (also with reverse=True). Ties therefore keep FAISS
    # candidates first, then the remaining items in index order.
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked


def _format_result(rank: int, idx: int, final_score: float, breakdown: dict) -> dict:
    """Build the API payload for one ranked item by joining ``meta`` with ``id_map``."""
    item_id = meta[idx].get("id")
    full = id_map.get(item_id, {})
    raw_text = ""
    knowledge_type = ""
    details = full.get("metadata")
    if isinstance(details, dict):
        raw_text = details.get("raw_text", "") or ""
        kt_val = details.get("knowledge_type", "")
        if isinstance(kt_val, list):
            knowledge_type = ", ".join(str(v) for v in kt_val)
        elif kt_val:
            knowledge_type = str(kt_val)
    return {
        "rank": rank,
        "score": float(final_score),
        # The blended score is not guaranteed to lie in [0, 1] (e.g. negative
        # similarities or large weights), so clamp the percentage on both ends.
        "relevance": max(0, min(100, int(final_score * 100))),
        "score_breakdown": {
            "semantic": round(breakdown["semantic"], 3),
            "bm25": round(breakdown["bm25"], 3),
            "keyword": round(breakdown["keyword"], 3),
        },
        "id": item_id,
        "topic": full.get("topic", ""),
        "chapter": full.get("chapter", ""),
        "knowledge_type": knowledge_type,
        "raw_text": raw_text,
    }


def hybrid_search(query: str, top_k: int, weights=(0.4, 0.4, 0.2)):
    """Rank knowledge-base items by a weighted blend of three retrieval signals.

    Signals:
      * semantic - FAISS score of the normalized SBERT query embedding,
        computed only for the ``top_k * 5`` best FAISS candidates;
      * bm25 - BM25 score over the whole corpus, divided by its maximum;
      * keyword - ``keyword_match_score`` against each item's ``keywords``.

    NOTE(review): ranking treats a higher FAISS score as more similar. That
    holds for an inner-product index; an L2-distance index would invert this
    signal — confirm how kb_faiss.index was built.

    Args:
        query: Raw user query; normalized with ``preprocess_query``.
        top_k: Number of results to return.
        weights: ``(semantic, bm25, keyword)`` multipliers for the blend.

    Returns:
        Up to ``top_k`` result dicts ordered best-first. The list is empty
        when the knowledge base is empty or ``top_k`` is not positive.
    """
    n_items = len(meta)
    if n_items == 0 or top_k <= 0:
        # Nothing to rank; also avoids asking FAISS for k <= 0 neighbours.
        return []
    processed_query = preprocess_query(query)

    # 1. Semantic: embed the query (L2-normalized) and fetch FAISS candidates.
    qv = model.encode(processed_query, normalize_embeddings=True, convert_to_numpy=True)
    qv = np.asarray(qv, dtype=np.float32).reshape(1, -1)
    retrieve_k = min(top_k * 5, n_items)
    semantic_scores, semantic_idxs = index.search(qv, retrieve_k)

    # 2. BM25 over the whole corpus, scaled so the best document scores 1.0.
    bm25_norm = _max_normalize(bm25.get_scores(_tokenize_for_bm25(processed_query)))

    # 3. Keyword overlap; ``or []`` also tolerates an explicit null keyword list.
    keyword_scores = [
        keyword_match_score(processed_query, m.get("keywords") or []) for m in meta
    ]

    ranked = _combine_scores(
        semantic_idxs[0], semantic_scores[0], bm25_norm, keyword_scores, weights
    )
    return [
        _format_result(rank, idx, final_score, breakdown)
        for rank, (idx, final_score, breakdown) in enumerate(ranked[:top_k], start=1)
    ]
@app.get("/")
def root():
    """Return a static banner confirming the service process is up."""
    banner = "Hybrid Knowledge Search API v3.0"
    return dict(message=banner, status="running")
@app.get("/health")
def health():
    """Report readiness plus basic statistics about the loaded search resources.

    The payload includes the number of metadata items, the embedding
    dimensionality, the FAISS vector count and the configured model name.
    """
    item_count = len(meta)
    dim = int(vectors.shape[1])
    faiss_count = int(index.ntotal)
    return dict(
        status="ok",
        items=item_count,
        vector_dim=dim,
        faiss_total=faiss_count,
        model=MODEL_NAME,
        search_type="hybrid (semantic + BM25 + keyword)",
    )
@app.post("/search")
def search_api(req: SearchRequest):
    """Run a hybrid search for the request body and return ranked results.

    ``top_k`` is clamped to [1, 20] rather than rejected. The response echoes
    the original and preprocessed query, the effective ``top_k``, the weights
    used and the result list from ``hybrid_search``.

    Raises:
        HTTPException(400): if the query is empty or whitespace-only, or if any
            weight is non-finite or negative, or if all weights are zero.
            Such weights cannot produce a meaningful ranking.
    """
    q = (req.query or "").strip()
    if not q:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    weights = (req.semantic_weight, req.bm25_weight, req.keyword_weight)
    # Weights come straight from the client: reject values that would invert
    # the ranking (negative), poison every score (NaN/inf) or zero it out.
    if not all(np.isfinite(w) and w >= 0 for w in weights) or sum(weights) <= 0:
        raise HTTPException(
            status_code=400,
            detail="Weights must be finite, non-negative and not all zero",
        )
    top_k = max(1, min(req.top_k, 20))
    results = hybrid_search(q, top_k, weights)
    return {
        "query": q,
        "processed_query": preprocess_query(q),
        "top_k": top_k,
        "weights": {
            "semantic": req.semantic_weight,
            "bm25": req.bm25_weight,
            "keyword": req.keyword_weight,
        },
        "results": results,
    }