|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
from typing import List, Dict |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import faiss |
|
|
|
|
|
class Retriever:
    """Semantic course search over a FAISS inner-product index.

    Embeds course ``name + short_desc`` texts with a multilingual
    SentenceTransformer, L2-normalizes the vectors (so inner product equals
    cosine similarity), and answers free-text queries with the top-k courses
    above a score threshold. The index and its per-row metadata are persisted
    to disk and reloaded on subsequent runs.
    """

    # Single source of truth for on-disk locations (previously the same
    # string literals were duplicated across three methods).
    INDEX_DIR = 'data/index'
    INDEX_PATH = 'data/index/index.faiss'
    META_PATH = 'data/index/meta.json'

    def __init__(self):
        # SentenceTransformer instance, or None if loading failed.
        self.model = None
        # faiss.IndexFlatIP, or None until built/loaded.
        self.index = None
        # Course metadata dicts, one per index row (row i <-> metadata[i]).
        self.metadata = []
        # Cap (in characters) on the text embedded per course.
        self.max_text_length = 220

        self._load_model()

    def _load_model(self):
        """Load the embedding model; leave ``self.model`` as None on failure.

        Failure is deliberately non-fatal: downstream methods check
        ``self.model`` and degrade to empty results.
        """
        try:
            print('Загрузка модели эмбеддингов...')
            self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            print('Модель эмбеддингов загружена успешно')
        except Exception as e:
            print(f'Ошибка загрузки модели эмбеддингов: {e}')
            self.model = None

    def build_or_load_index(self, courses: List[Dict]):
        """Load the persisted index if both files exist; otherwise build anew.

        Args:
            courses: Course dicts; used only when a fresh build is needed.
        """
        if os.path.exists(self.INDEX_PATH) and os.path.exists(self.META_PATH):
            print('Загрузка существующего индекса...')
            self._load_index()
        else:
            print('Создание нового индекса...')
            self._build_index(courses)

    def _build_index(self, courses: List[Dict]):
        """Embed course texts and build a cosine-similarity FAISS index.

        Args:
            courses: Dicts with required keys ``id``, ``program_id``,
                ``semester``, ``name``, ``credits`` and optional
                ``short_desc``/``tags``. Courses whose combined text is
                empty are skipped.
        """
        if not self.model:
            print('Модель эмбеддингов недоступна')
            return

        texts = []
        self.metadata = []

        for course in courses:
            text = f"{course['name']} {course.get('short_desc', '')}"
            text = text.strip()[:self.max_text_length]

            if text:
                texts.append(text)
                self.metadata.append({
                    'course_id': course['id'],
                    'program_id': course['program_id'],
                    'semester': course['semester'],
                    'name': course['name'],
                    'credits': course['credits'],
                    'short_desc': course.get('short_desc', ''),
                    'tags': course.get('tags', []),
                })

        if not texts:
            print('Нет текстов для индексации')
            return

        print(f'Создание эмбеддингов для {len(texts)} курсов...')
        embeddings = self.model.encode(texts, convert_to_tensor=False)

        # FAISS requires contiguous float32; cast once up front so both
        # normalize_L2 and add operate on the same buffer.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')
        # After L2 normalization, inner product == cosine similarity.
        faiss.normalize_L2(embeddings)

        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings)

        self._save_index()

        print(f'Индекс создан: {len(texts)} курсов, размерность {dimension}')

    def _save_index(self):
        """Persist the FAISS index and its metadata side file to disk."""
        os.makedirs(self.INDEX_DIR, exist_ok=True)

        faiss.write_index(self.index, self.INDEX_PATH)

        with open(self.META_PATH, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)

    def _load_index(self):
        """Load the FAISS index and metadata from disk; reset both on failure."""
        try:
            self.index = faiss.read_index(self.INDEX_PATH)

            with open(self.META_PATH, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)

            print(f'Индекс загружен: {len(self.metadata)} курсов')

        except Exception as e:
            print(f'Ошибка загрузки индекса: {e}')
            self.index = None
            self.metadata = []

    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to *k* course metadata dicts matching *query*.

        Args:
            query: Free-text search query.
            k: Maximum number of results to consider.
            threshold: Minimum cosine-similarity score for inclusion.

        Returns:
            Metadata dict copies, each with an added float ``score`` key;
            empty list when the model/index is unavailable or search fails.
        """
        if not self.model or self.index is None:
            print('Модель или индекс недоступны')
            return []

        try:
            query_embedding = self.model.encode([query], convert_to_tensor=False)
            # Same float32/contiguity requirement as at build time.
            query_embedding = np.ascontiguousarray(query_embedding, dtype='float32')
            faiss.normalize_L2(query_embedding)

            scores, indices = self.index.search(query_embedding, k)

            results = []
            for score, idx in zip(scores[0], indices[0]):
                # FAISS pads with idx == -1 when the index has fewer than k
                # vectors; the previous `idx < len(...)` check let -1 through,
                # wrongly returning the *last* metadata entry via negative
                # indexing. Guard with a lower bound as well.
                if 0 <= idx < len(self.metadata) and score >= threshold:
                    result = self.metadata[idx].copy()
                    result['score'] = float(score)
                    results.append(result)

            return results

        except Exception as e:
            print(f'Ошибка поиска: {e}')
            return []

    def get_index_size(self) -> int:
        """Return the number of vectors in the index (0 if no index)."""
        if self.index is not None:
            return self.index.ntotal
        return 0

    def get_embedding_dim(self) -> int:
        """Return the dimensionality of indexed vectors (0 if no index)."""
        if self.index is not None:
            return self.index.d
        return 0
|
|
|