Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from urllib.parse import urlparse
|
|
| 12 |
import logging
|
| 13 |
from sklearn.preprocessing import normalize
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
| 15 |
|
| 16 |
# Настройка логирования
|
| 17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
@@ -31,12 +32,20 @@ db_params = {
|
|
| 31 |
"sslmode": "require"
|
| 32 |
}
|
| 33 |
|
| 34 |
-
# Загружаем модель
|
| 35 |
model_name = "BAAI/bge-m3"
|
| 36 |
logging.info(f"Загрузка модели {model_name}...")
|
| 37 |
model = SentenceTransformer(model_name)
|
| 38 |
logging.info("Модель загружена успешно.")
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Имена таблиц
|
| 41 |
embeddings_table = "movie_embeddings"
|
| 42 |
query_cache_table = "query_cache"
|
|
@@ -207,7 +216,7 @@ def process_batch(batch):
|
|
| 207 |
|
| 208 |
try:
|
| 209 |
for movie in batch:
|
| 210 |
-
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['
|
| 211 |
string_crc32 = calculate_crc32(embedding_string)
|
| 212 |
|
| 213 |
# Проверяем существующий эмбеддинг
|
|
@@ -289,6 +298,24 @@ def get_movie_embeddings(conn):
|
|
| 289 |
logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
|
| 290 |
return movie_embeddings
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
def search_movies(query, top_k=20):
|
| 293 |
"""Выполняет поиск фильмов по запросу."""
|
| 294 |
global search_in_progress
|
|
@@ -332,23 +359,26 @@ def search_movies(query, top_k=20):
|
|
| 332 |
FROM {embeddings_table} m, query_embedding
|
| 333 |
ORDER BY similarity DESC
|
| 334 |
LIMIT %s
|
| 335 |
-
""", (query_crc32, top_k))
|
| 336 |
|
| 337 |
results = cur.fetchall()
|
| 338 |
-
logging.info(f"Найдено {len(results)} результатов поиска.")
|
| 339 |
except Exception as e:
|
| 340 |
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
|
| 341 |
results = []
|
| 342 |
|
|
|
|
|
|
|
|
|
|
| 343 |
output = ""
|
| 344 |
-
for movie_id,
|
| 345 |
# Находим фильм по ID
|
| 346 |
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
| 347 |
if movie:
|
| 348 |
output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
|
| 349 |
-
output += f"<p><strong>Жанры:</strong> {movie['
|
| 350 |
output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
|
| 351 |
-
output += f"<p><strong>Релевантность:</strong> {
|
| 352 |
output += "<hr>\n"
|
| 353 |
|
| 354 |
search_time = time.time() - start_time
|
|
|
|
| 12 |
import logging
|
| 13 |
from sklearn.preprocessing import normalize
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 16 |
|
| 17 |
# Настройка логирования
|
| 18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
| 32 |
"sslmode": "require"
|
| 33 |
}
|
| 34 |
|
| 35 |
+
# Загружаем модель эмбеддингов
|
| 36 |
model_name = "BAAI/bge-m3"
|
| 37 |
logging.info(f"Загрузка модели {model_name}...")
|
| 38 |
model = SentenceTransformer(model_name)
|
| 39 |
logging.info("Модель загружена успешно.")
|
| 40 |
|
| 41 |
+
# Загружаем модель реранкера
|
| 42 |
+
reranker_name = 'BAAI/bge-reranker-v2-m3'
|
| 43 |
+
logging.info(f"Загрузка модели реранкера {reranker_name}...")
|
| 44 |
+
reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
|
| 45 |
+
reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
|
| 46 |
+
reranker_model.eval()
|
| 47 |
+
logging.info("Модель реранкера загружена успешно.")
|
| 48 |
+
|
| 49 |
# Имена таблиц
|
| 50 |
embeddings_table = "movie_embeddings"
|
| 51 |
query_cache_table = "query_cache"
|
|
|
|
| 216 |
|
| 217 |
try:
|
| 218 |
for movie in batch:
|
| 219 |
+
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
|
| 220 |
string_crc32 = calculate_crc32(embedding_string)
|
| 221 |
|
| 222 |
# Проверяем существующий эмбеддинг
|
|
|
|
| 298 |
logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
|
| 299 |
return movie_embeddings
|
| 300 |
|
| 301 |
+
def rerank_results(query, results):
|
| 302 |
+
"""Переранжирует результаты поиска с помощью реранкера."""
|
| 303 |
+
pairs = []
|
| 304 |
+
movie_ids = []
|
| 305 |
+
for movie_id, _ in results:
|
| 306 |
+
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
| 307 |
+
if movie:
|
| 308 |
+
movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
|
| 309 |
+
pairs.append([query, movie_info])
|
| 310 |
+
movie_ids.append(movie_id)
|
| 311 |
+
|
| 312 |
+
with torch.no_grad():
|
| 313 |
+
inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
|
| 314 |
+
scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
|
| 315 |
+
|
| 316 |
+
reranked_results = sorted(zip(movie_ids, scores.tolist()), key=lambda x: x[1], reverse=True)
|
| 317 |
+
return reranked_results
|
| 318 |
+
|
| 319 |
def search_movies(query, top_k=20):
|
| 320 |
"""Выполняет поиск фильмов по запросу."""
|
| 321 |
global search_in_progress
|
|
|
|
| 359 |
FROM {embeddings_table} m, query_embedding
|
| 360 |
ORDER BY similarity DESC
|
| 361 |
LIMIT %s
|
| 362 |
+
""", (query_crc32, top_k * 2)) # Увеличиваем лимит для последующего переранжирования
|
| 363 |
|
| 364 |
results = cur.fetchall()
|
| 365 |
+
logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
|
| 366 |
except Exception as e:
|
| 367 |
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
|
| 368 |
results = []
|
| 369 |
|
| 370 |
+
# Переранжируем результаты
|
| 371 |
+
reranked_results = rerank_results(query, results)
|
| 372 |
+
|
| 373 |
output = ""
|
| 374 |
+
for movie_id, score in reranked_results[:top_k]:
|
| 375 |
# Находим фильм по ID
|
| 376 |
movie = next((m for m in movies_data if m['id'] == movie_id), None)
|
| 377 |
if movie:
|
| 378 |
output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
|
| 379 |
+
output += f"<p><strong>Жанры:</strong> {movie['genresList']}</p>\n"
|
| 380 |
output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
|
| 381 |
+
output += f"<p><strong>Релевантность (reranker score):</strong> {score:.4f}</p>\n"
|
| 382 |
output += "<hr>\n"
|
| 383 |
|
| 384 |
search_time = time.time() - start_time
|