# lexir/src/api.py — FastAPI semantic-search service (ru/kz) over FAISS indexes.
import numpy as np
import faiss
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from data_io import read_jsonl
# Characters unique to the Kazakh Cyrillic alphabet (absent from Russian).
# Hoisted to module level so the set is built once, not on every call.
_KZ_CHARS = frozenset("әғқңөұүһі")


def detect_lang(s: str) -> str:
    """Guess the language of *s*: "kz" if any Kazakh-specific Cyrillic
    letter is present (case-insensitive), otherwise "ru".

    Note: text containing neither alphabet (e.g. pure Latin) falls
    through to "ru" by design.
    """
    if not _KZ_CHARS.isdisjoint(s.lower()):
        return "kz"
    return "ru"
def load_index(lang: str):
    """Load the FAISS index and its row metadata for *lang*.

    Returns a pair ``(index, meta_by_pos)`` where ``meta_by_pos`` maps the
    integer FAISS row position to its metadata record from the JSONL file.
    """
    base = Path("artifacts/index")
    index = faiss.read_index(str(base / f"{lang}.faiss"))
    rows = read_jsonl(str(base / f"{lang}_meta.jsonl"))
    meta_by_pos = {int(row["pos"]): row for row in rows}
    return index, meta_by_pos
# Prefer the locally fine-tuned checkpoint if its artifact directory exists;
# otherwise fall back to the public multilingual MPNet model.
model_path = "artifacts/models/finetuned_mpnet" if Path("artifacts/models/finetuned_mpnet").exists() else "paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(model_path)
# Load both per-language indexes once at import time so request handlers
# pay no disk I/O.
ru_index, ru_meta = load_index("ru")
kz_index, kz_meta = load_index("kz")
app = FastAPI()
class SearchReq(BaseModel):
    """Request body for POST /search."""
    query: str  # free-text query in Russian or Kazakh
    top_k: int = 5  # requested result count (clamped server-side to 1..50)
    lang: str | None = None  # explicit "ru"/"kz" override; None = auto-detect
@app.post("/search")
def search(req: SearchReq):
    """Semantic search over the per-language FAISS index.

    Encodes the query, searches the index for the chosen language, and
    returns up to ``top_k`` (clamped to 1..50) matches with score, id,
    metadata and text.

    Fix: FAISS pads result rows with -1 when fewer than k neighbors exist;
    previously the raw row index was used for ``rank``, so skipped rows left
    gaps in rank numbering. Ranks are now assigned consecutively and negative
    positions are skipped explicitly.
    """
    q = req.query.strip()
    if not q:
        # Legacy error shape (HTTP 200 + error field) kept for client
        # compatibility.
        return {"error": "empty_query"}
    # Explicit, valid language override wins; otherwise auto-detect.
    lang = req.lang if req.lang in {"ru", "kz"} else detect_lang(q)
    index, meta_by_pos = (kz_index, kz_meta) if lang == "kz" else (ru_index, ru_meta)
    # normalize_embeddings=True makes inner-product scores equal cosine
    # similarity.
    emb = model.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
    scores, idxs = index.search(emb, max(1, min(req.top_k, 50)))
    out = []
    for r in range(idxs.shape[1]):
        pos = int(idxs[0, r])
        if pos < 0:
            # FAISS uses -1 as padding when fewer than k results exist.
            continue
        item = meta_by_pos.get(pos)
        if item is None:
            continue
        out.append({
            "rank": len(out) + 1,  # consecutive; no gaps from skipped rows
            "score": float(scores[0, r]),
            "id": item["id"],
            "meta": item.get("meta"),
            "text": item.get("text"),
        })
    return {"lang": lang, "results": out}