testsaha24 / app.py
SaHa24D's picture
Update app.py
8cd0d6b verified
from transformers import BertTokenizerFast, BertForSequenceClassification
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import torch
from datasets import load_dataset
import pandas as pd
import gradio as gr
# Загрузка датасета
def load_data():
dataset = load_dataset("wykonos/movies", split='train[:1400]')
df = pd.DataFrame(dataset)
# Обработка жанров
df['genres'] = df['genres'].apply(
lambda x: x.split('-') if isinstance(x, str) else []
)
# Исправление опечаток
genre_corrections = {
'hystery': 'mystery',
'adventura': 'adventure',
'nur': 'action'
}
df['genres'] = df['genres'].apply(
lambda x: [genre_corrections.get(g.lower(), g).capitalize() for g in x]
)
return df
df = load_data()
# Загрузка модели
model_name = "AventIQ-AI/bert-movie-recommendation-system"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
# Сопоставление меток модели с жанрами
genre_labels = [
"Action", "Adventure", "Animation", "Comedy", "Crime",
"Documentary", "Drama", "Family", "Fantasy", "History",
"Horror", "Music", "Mystery", "Romance", "Science Fiction",
"TV Movie", "Thriller", "War", "Western"
]
model.config.id2label = {i: label for i, label in enumerate(genre_labels)}
# Функция предсказания
def predict_genres(text, threshold=0.3):
inputs = tokenizer(
text,
max_length=128,
padding=True,
truncation=True,
return_tensors="pt"
)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_indices = torch.where(probs > threshold)[1].tolist()
return list(set([model.config.id2label[i] for i in predicted_indices]))
# Модель для текстовых эмбеддингов
embedding_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
# Кодируем описания фильмов
df["overview"] = df["overview"].fillna("") # если есть пропуски
vectors = embedding_model.encode(
df["overview"].tolist(),
convert_to_tensor=True,
show_progress_bar=True
)
df["overview_vector"] = list(vectors)
def recommend_movies(query, top_k=5):
try:
# 1. Предсказание жанров
genres = predict_genres(query)
# 2. Кодируем сам запрос
query_vector = embedding_model.encode([query], convert_to_tensor=True)[0]
# 3. Считаем косинусное сходство с описаниями фильмов
similarities = cosine_similarity(
query_vector.unsqueeze(0),
torch.stack(df["overview_vector"].tolist())
).squeeze(0)
df["similarity"] = similarities.cpu().numpy()
# 4. Объединяем семантику + жанры (опционально)
# Можно фильтровать по жанрам или просто повысить вес фильмам с совпадающими жанрами
def genre_score(row):
return any(g in row["genres"] for g in genres)
df["genre_boost"] = df.apply(genre_score, axis=1).astype(int)
df["total_score"] = df["similarity"] + df["genre_boost"] * 0.15 # вес жанра
# 5. Выбираем top_k
results = df.sort_values(
by="total_score", ascending=False
).head(top_k)
# 6. Форматируем ответ
output = []
for _, row in results.iterrows():
info = [
f"🎬 {row['title']}",
f"⭐ Рейтинг: {row['vote_average']}",
f"🎭 Жанры: {', '.join(row['genres'])}",
f"📅 Год: {row['release_date'][:4] if pd.notna(row['release_date']) else 'N/A'}",
f"📖 {row['overview'][:300]}..."
]
output.append("\n".join(info))
return "\n\n".join(output) if output else "Нет результатов"
except Exception as e:
return f"Ошибка: {str(e)}"
# Создание интерфейса
interface = gr.Interface(
fn=recommend_movies,
inputs=gr.Textbox(label="Опишите желаемый фильм"),
outputs=gr.Textbox(label="Рекомендации"),
examples=[
["Страшный фильм с привидениями"],
["Веселая комедия про студентов"],
["Фантастика с космическими битвами"]
],
title="🍿 AI Киносоветник"
)
if __name__ == "__main__":
interface.launch()