File size: 1,882 Bytes
6a02b16
 
 
 
 
 
c6cece9
6a02b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import faiss
from pathlib import Path
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from data_io import read_jsonl

# Letters that appear in Kazakh Cyrillic but not in the Russian alphabet.
_KZ_SPECIFIC_LETTERS = frozenset("әғқңөұүһі")


def detect_lang(s: str):
    """Guess whether *s* is Kazakh ("kz") or Russian ("ru").

    The text is lower-cased and checked for any Kazakh-specific letter.
    If one is present the result is "kz"; otherwise (including for empty
    or non-Cyrillic input) it is "ru".
    """
    lowered = s.lower()
    if _KZ_SPECIFIC_LETTERS.isdisjoint(lowered):
        return "ru"
    return "kz"

def load_index(lang: str):
    """Load the FAISS index and its position-keyed metadata for *lang*.

    Reads ``artifacts/index/{lang}.faiss`` and
    ``artifacts/index/{lang}_meta.jsonl`` (relative paths, resolved against
    the current working directory).

    Returns:
        ``(index, meta_by_pos)``: the FAISS index, and a dict mapping each
        metadata record's ``pos`` field (converted to ``int``) to the record.
        If two records share a ``pos``, the later one wins.
    """
    base_dir = Path("artifacts/index")
    faiss_file = base_dir / f"{lang}.faiss"
    meta_file = base_dir / f"{lang}_meta.jsonl"

    index = faiss.read_index(str(faiss_file))
    records = read_jsonl(str(meta_file))

    meta_by_pos = {}
    for record in records:
        meta_by_pos[int(record["pos"])] = record
    return index, meta_by_pos

# Prefer the locally fine-tuned checkpoint when its directory exists;
# otherwise pass the public multilingual model name to SentenceTransformer.
_FINETUNED_MODEL_DIR = "artifacts/models/finetuned_mpnet"
if Path(_FINETUNED_MODEL_DIR).exists():
    model_path = _FINETUNED_MODEL_DIR
else:
    model_path = "paraphrase-multilingual-mpnet-base-v2"
model = SentenceTransformer(model_path)

# One index + metadata table per supported language, loaded once at import.
ru_index, ru_meta = load_index("ru")
kz_index, kz_meta = load_index("kz")

app = FastAPI()

class SearchReq(BaseModel):
    """Request body for the ``POST /search`` endpoint."""

    query: str  # free-text query; the handler strips it and rejects it if empty
    top_k: int = 5  # requested number of results; the handler clamps it to [1, 50]
    lang: str | None = None  # optional index selector ("ru" or "kz"); otherwise the handler auto-detects from the query

@app.post("/search")
def search(req: SearchReq):
    """Embed ``req.query`` and return its nearest neighbours from the ru or kz index.

    Language selection: ``req.lang`` is used when, after trimming whitespace
    and lower-casing, it is "ru" or "kz". Otherwise the language is detected
    from the query text with ``detect_lang``. ``req.top_k`` is clamped to
    the range [1, 50].

    Returns:
        ``{"error": "empty_query"}`` when the query is blank. Otherwise
        ``{"lang": <lang>, "results": [...]}``, where each result has
        ``rank`` (1-based position in the FAISS result list), ``score``,
        ``id``, ``meta`` and ``text``. Neighbours with no metadata entry are
        skipped, including FAISS's ``-1`` padding for missing results.
    """
    q = req.query.strip()
    if not q:
        return {"error": "empty_query"}

    # Normalize so that e.g. "KZ" or " ru " select an index instead of
    # silently falling back to auto-detection.
    requested = (req.lang or "").strip().lower()
    lang = requested if requested in {"ru", "kz"} else detect_lang(q)
    index, meta_by_pos = (kz_index, kz_meta) if lang == "kz" else (ru_index, ru_meta)

    emb = model.encode([q], convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
    k = max(1, min(req.top_k, 50))
    scores, idxs = index.search(emb, k)

    out = []
    # Single query, so row 0 holds all k (score, position) pairs.
    for rank, (score, raw_pos) in enumerate(zip(scores[0], idxs[0]), start=1):
        pos = int(raw_pos)
        # FAISS returns -1 for slots it could not fill (e.g. fewer than k vectors).
        if pos < 0:
            continue
        item = meta_by_pos.get(pos)
        if item is None:
            continue
        out.append({
            "rank": rank,
            "score": float(score),
            "id": item["id"],
            "meta": item.get("meta"),
            "text": item.get("text"),
        })
    return {"lang": lang, "results": out}