test1 / retriever.py
vydrking's picture
Upload 18 files
5071500 verified
import os
import json
import numpy as np
from typing import List, Dict
from sentence_transformers import SentenceTransformer
import faiss
class Retriever:
def __init__(self):
self.model = None
self.index = None
self.metadata = []
self.max_text_length = 220
# Ленивая загрузка модели
self._load_model()
def _load_model(self):
try:
print('Загрузка модели эмбеддингов...')
self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
print('Модель эмбеддингов загружена успешно')
except Exception as e:
print(f'Ошибка загрузки модели эмбеддингов: {e}')
self.model = None
def build_or_load_index(self, courses: List[Dict]):
index_path = 'data/index/index.faiss'
meta_path = 'data/index/meta.json'
if os.path.exists(index_path) and os.path.exists(meta_path):
print('Загрузка существующего индекса...')
self._load_index()
else:
print('Создание нового индекса...')
self._build_index(courses)
def _build_index(self, courses: List[Dict]):
if not self.model:
print('Модель эмбеддингов недоступна')
return
# Подготовка текстов для эмбеддингов
texts = []
self.metadata = []
for course in courses:
text = f"{course['name']} {course.get('short_desc', '')}"
text = text.strip()[:self.max_text_length]
if text:
texts.append(text)
self.metadata.append({
'course_id': course['id'],
'program_id': course['program_id'],
'semester': course['semester'],
'name': course['name'],
'credits': course['credits'],
'short_desc': course.get('short_desc', ''),
'tags': course.get('tags', [])
})
if not texts:
print('Нет текстов для индексации')
return
# Создание эмбеддингов
print(f'Создание эмбеддингов для {len(texts)} курсов...')
embeddings = self.model.encode(texts, convert_to_tensor=False)
# Нормализация эмбеддингов
faiss.normalize_L2(embeddings)
# Создание FAISS индекса
dimension = embeddings.shape[1]
self.index = faiss.IndexFlatIP(dimension)
self.index.add(embeddings.astype('float32'))
# Сохранение индекса и метаданных
self._save_index()
print(f'Индекс создан: {len(texts)} курсов, размерность {dimension}')
def _save_index(self):
os.makedirs('data/index', exist_ok=True)
# Сохранение FAISS индекса
faiss.write_index(self.index, 'data/index/index.faiss')
# Сохранение метаданных
with open('data/index/meta.json', 'w', encoding='utf-8') as f:
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
def _load_index(self):
try:
# Загрузка FAISS индекса
self.index = faiss.read_index('data/index/index.faiss')
# Загрузка метаданных
with open('data/index/meta.json', 'r', encoding='utf-8') as f:
self.metadata = json.load(f)
print(f'Индекс загружен: {len(self.metadata)} курсов')
except Exception as e:
print(f'Ошибка загрузки индекса: {e}')
self.index = None
self.metadata = []
def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
if not self.model or not self.index:
print('Модель или индекс недоступны')
return []
try:
# Создание эмбеддинга запроса
query_embedding = self.model.encode([query], convert_to_tensor=False)
faiss.normalize_L2(query_embedding)
# Поиск в индексе
scores, indices = self.index.search(query_embedding.astype('float32'), k)
# Формирование результатов
results = []
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
if idx < len(self.metadata) and score >= threshold:
result = self.metadata[idx].copy()
result['score'] = float(score)
results.append(result)
return results
except Exception as e:
print(f'Ошибка поиска: {e}')
return []
def get_index_size(self) -> int:
if self.index:
return self.index.ntotal
return 0
def get_embedding_dim(self) -> int:
if self.index:
return self.index.d
return 0