| from pathlib import Path | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from data_io import read_jsonl | |
# Directory of the fine-tuned sentence-transformers model that main() loads.
MODEL_PATH = Path("artifacts/models/finetuned_mpnet")
# Directory holding the per-language "<lang>.faiss" and "<lang>_meta.jsonl" files.
INDEX_DIR = Path("artifacts/indexes/finetuned")
# Number of nearest neighbours requested from the index per query.
TOP_K = 5
def detect_lang(s: str):
    """Guess whether a query is written in Kazakh or Russian.

    The text is lower-cased and checked against a fixed set of
    Kazakh-specific Cyrillic letters.

    Args:
        s: Raw query text.

    Returns:
        ``"kz"`` if at least one of the Kazakh-specific letters occurs in
        the text, otherwise ``"ru"`` (this includes empty input and text
        that contains no Cyrillic at all).
    """
    kazakh_only = frozenset("әғқңөұүһі")
    has_kazakh_letter = not kazakh_only.isdisjoint(s.lower())
    return "kz" if has_kazakh_letter else "ru"
def load_index(lang: str):
    """Load the FAISS index and its metadata for one language.

    Reads ``<lang>.faiss`` and ``<lang>_meta.jsonl`` from ``INDEX_DIR``.

    Args:
        lang: Language code used as the file-name stem.

    Returns:
        A ``(index, meta_by_pos)`` tuple: the object returned by
        ``faiss.read_index`` and a dict mapping each metadata record's
        integer ``"pos"`` value to the record itself.
    """
    index_file = INDEX_DIR / f"{lang}.faiss"
    meta_file = INDEX_DIR / f"{lang}_meta.jsonl"

    # Read the index before the metadata so a missing index file is
    # reported first.
    index = faiss.read_index(str(index_file))

    meta_by_pos = {}
    for record in read_jsonl(str(meta_file)):
        meta_by_pos[int(record["pos"])] = record

    return index, meta_by_pos
def search(model, index, meta_by_pos, query, top_k=None):
    """Return the metadata records of the nearest neighbours of ``query``.

    Args:
        model: Encoder exposing ``encode(texts, convert_to_numpy=...,
            normalize_embeddings=...)`` and returning an array.
        index: Object exposing ``search(vectors, k)`` that returns a
            ``(scores, ids)`` pair (a FAISS index in this script).
        meta_by_pos: Mapping from index position to its metadata record.
        query: Query text.
        top_k: Number of neighbours to request. ``None`` (the default)
            falls back to the module-level ``TOP_K``.

    Returns:
        A list of metadata records in the order the index returned them.
        Positions without a usable record are skipped, so the list may be
        shorter than ``top_k``.
    """
    k = TOP_K if top_k is None else top_k

    query_vec = model.encode(
        [query], convert_to_numpy=True, normalize_embeddings=True
    ).astype(np.float32)
    # Only the ids are needed; the similarity scores are not reported.
    _, idxs = index.search(query_vec, k)

    results = []
    for pos in idxs[0]:
        pos = int(pos)
        if pos < 0:
            # FAISS pads the id row with -1 when it finds fewer than k hits.
            continue
        item = meta_by_pos.get(pos)
        if item:
            results.append(item)
    return results
def print_results(query, results):
    """Print the query followed by a numbered list of retrieved records.

    Args:
        query: The query text, echoed back under a header.
        results: Iterable of dict-like records; the ``"meta"`` and
            ``"text"`` entries are printed, and a missing entry is shown
            as an empty string.
    """
    print("\nЗапрос:", query, "\nРелевантные нормы:\n", sep="\n")

    separator = "-" * 80
    for rank, hit in enumerate(results, start=1):
        heading = f"{rank}. {hit.get('meta', '')}"
        # One record per group: heading, body text, then a separator rule.
        print(heading, hit.get("text", ""), separator, sep="\n")
def main():
    """Run the interactive retrieval demo.

    Loads the fine-tuned encoder and the Russian and Kazakh indexes, then
    reads queries from stdin and prints the retrieved records for each.
    The loop ends when the user types ``exit`` or ``quit``, when the input
    stream ends, or when the user interrupts the prompt.

    Raises:
        RuntimeError: If ``MODEL_PATH`` does not exist.
    """
    if not MODEL_PATH.exists():
        raise RuntimeError("finetuned model not found")

    model = SentenceTransformer(str(MODEL_PATH))
    # Load both indexes up front so each query only needs encode + search.
    ru_index, ru_meta = load_index("ru")
    kz_index, kz_meta = load_index("kz")

    print("LexIR demo (fine-tuned model)")
    print("Введите запрос (exit — выход)\n")

    while True:
        try:
            q = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C or exhausted piped input: leave the loop
            # quietly instead of ending the demo with a traceback.
            print()
            break

        if not q:
            continue
        if q.lower() in {"exit", "quit"}:
            break

        # Anything not detected as Kazakh is searched in the Russian index.
        if detect_lang(q) == "kz":
            index, meta_by_pos = kz_index, kz_meta
        else:
            index, meta_by_pos = ru_index, ru_meta
        print_results(q, search(model, index, meta_by_pos, q))
# Start the interactive demo only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()