|
|
import numpy as np
|
|
|
import faiss
|
|
|
from pathlib import Path
|
|
|
from fastapi import FastAPI
|
|
|
from pydantic import BaseModel
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
from data_io import read_jsonl
|
|
|
|
|
|
def detect_lang(s: str):
    """Guess whether *s* is Kazakh or Russian.

    Any occurrence of a Kazakh-specific Cyrillic letter marks the text as
    Kazakh; everything else defaults to Russian.
    """
    # Letters that exist in the Kazakh alphabet but not in Russian.
    kazakh_only = frozenset("әғқңөұүһі")
    if not kazakh_only.isdisjoint(s.lower()):
        return "kz"
    return "ru"
|
|
|
|
|
|
def load_index(lang: str):
    """Load the FAISS index and its metadata for one language.

    Reads ``artifacts/index/{lang}.faiss`` and the sibling
    ``{lang}_meta.jsonl`` file, and returns ``(index, meta_by_pos)`` where
    ``meta_by_pos`` maps a vector's position in the index to its metadata
    record.
    """
    base_dir = Path("artifacts/index")
    index = faiss.read_index(str(base_dir / f"{lang}.faiss"))
    records = read_jsonl(str(base_dir / f"{lang}_meta.jsonl"))
    # Key each record by its integer position so search hits resolve in O(1).
    return index, {int(rec["pos"]): rec for rec in records}
|
|
|
|
|
|
# Prefer the locally fine-tuned checkpoint when it exists; otherwise fall back
# to the public multilingual baseline model.
model_path = "artifacts/models/finetuned_mpnet" if Path("artifacts/models/finetuned_mpnet").exists() else "paraphrase-multilingual-mpnet-base-v2"

model = SentenceTransformer(model_path)

# Load both per-language indexes once at import time so each request only pays
# for the embedding + search, not for file I/O.
ru_index, ru_meta = load_index("ru")

kz_index, kz_meta = load_index("kz")


app = FastAPI()
|
|
|
|
|
|
class SearchReq(BaseModel):
    """Request body for POST /search."""

    # Free-text query; the handler strips surrounding whitespace and rejects
    # an empty result.
    query: str

    # Requested number of results; the handler clamps this to the range [1, 50].
    top_k: int = 5

    # Optional explicit language ("ru" or "kz"). Any other value (including
    # None) triggers automatic language detection on the query text.
    lang: str | None = None
|
|
|
|
|
|
@app.post("/search")
|
|
|
def search(req: SearchReq):
|
|
|
q = req.query.strip()
|
|
|
if not q:
|
|
|
return {"error": "empty_query"}
|
|
|
lang = req.lang if req.lang in {"ru", "kz"} else detect_lang(q)
|
|
|
index, meta_by_pos = (kz_index, kz_meta) if lang == "kz" else (ru_index, ru_meta)
|
|
|
|
|
|
emb = model.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
|
|
|
scores, idxs = index.search(emb, max(1, min(req.top_k, 50)))
|
|
|
|
|
|
out = []
|
|
|
for r in range(idxs.shape[1]):
|
|
|
pos = int(idxs[0, r])
|
|
|
item = meta_by_pos.get(pos)
|
|
|
if item is None:
|
|
|
continue
|
|
|
out.append({"rank": r+1, "score": float(scores[0, r]), "id": item["id"], "meta": item.get("meta"), "text": item.get("text")})
|
|
|
return {"lang": lang, "results": out}
|
|
|
|