SaHa24D commited on
Commit
fe797ab
·
verified ·
1 Parent(s): d545633

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +105 -101
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,137 +1,141 @@
1
- import gradio as gr
2
- import pandas as pd
 
3
  import torch
4
- import numpy as np
5
- from transformers import BertForSequenceClassification, BertTokenizerFast
6
  from datasets import load_dataset
7
- import ast # Для преобразования строковых представлений списков
8
-
9
- # Убираем предупреждение о симлинках
10
- import os
11
- os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
12
 
13
- # Загрузка и подготовка данных
14
- def load_movie_data():
15
- try:
16
- dataset = load_dataset("wykonos/movies")
17
- df = pd.DataFrame(dataset['train'])
18
-
19
- # Преобразование строковых представлений в списки
20
- def parse_list(x):
21
- try:
22
- return [item['name'] for item in ast.literal_eval(x)]
23
- except:
24
- return []
25
-
26
- # Обработка колонок
27
- df['genres'] = df['genres'].apply(parse_list)
28
- df['production_companies'] = df['production_companies'].apply(parse_list)
29
- df['keywords'] = df['keywords'].apply(parse_list)
30
-
31
- # Выбор нужных колонок
32
- df = df[[
33
- 'id', 'title', 'genres', 'overview',
34
- 'release_date', 'vote_average', 'poster_path'
35
- ]].drop_duplicates(subset=['id'])
36
-
37
- return df
38
 
39
- except Exception as e:
40
- print(f"Ошибка загрузки данных: {str(e)}")
41
- return pd.DataFrame()
42
-
43
- # Загрузка данных
44
- df = load_movie_data()
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- if df.empty:
47
- raise RuntimeError("Не удалось загрузить данные. Проверьте структуру датасета.")
48
 
49
- # Загрузка модели и токенизатора
50
  model_name = "AventIQ-AI/bert-movie-recommendation-system"
51
- try:
52
- tokenizer = BertTokenizerFast.from_pretrained(model_name)
53
- model = BertForSequenceClassification.from_pretrained(model_name)
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
- model = model.to(device)
56
- except Exception as e:
57
- raise RuntimeError(f"Ошибка загрузки модели: {str(e)}")
58
-
59
- def predict_genres(text, threshold=0.5):
 
 
 
 
 
 
60
  inputs = tokenizer(
61
- text,
62
- padding=True,
63
- truncation=True,
64
- max_length=128,
65
  return_tensors="pt"
66
- ).to(device)
67
 
68
  with torch.no_grad():
69
  outputs = model(**inputs)
70
 
71
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
72
- predicted_indices = np.where(probs.cpu().numpy() > threshold)[1]
73
 
74
- genre_labels = list(model.config.id2label.values())
75
- predicted_genres = [genre_labels[i] for i in predicted_indices]
76
-
77
- return list(set(predicted_genres))
 
 
 
 
 
 
 
 
 
78
 
79
- def recommend_movies(query, top_k=10):
 
80
  try:
 
81
  genres = predict_genres(query)
82
- mask = df['genres'].apply(lambda x: any(g in genres for g in x))
83
- filtered = df[mask]
84
 
85
- # Сортировка по рейтингу и популярности
86
- results = filtered.sort_values(
87
- by=['vote_average', 'release_date'],
88
- ascending=[False, False]
89
- ).head(top_k)
 
 
 
 
 
 
 
 
 
 
90
 
 
 
 
 
 
 
 
 
 
91
  output = []
92
  for _, row in results.iterrows():
93
- movie_info = [
94
- f"🎬 **{row['title']}** ({row['release_date'][:4]})",
95
  f"⭐ Рейтинг: {row['vote_average']}",
96
- f"📀 Жанры: {', '.join(row['genres'][:3])}",
 
 
97
  ]
98
-
99
- if row['overview']:
100
- movie_info.append(f"📖 Описание: {row['overview'][:200]}...")
101
-
102
- if row['poster_path']:
103
- movie_info.append(f"🖼️ Постер: https://image.tmdb.org/t/p/w500{row['poster_path']}")
104
-
105
- output.append("\n".join(movie_info) + "\n" + "-"*50)
106
 
107
- return "\n\n".join(output) if len(output) > 0 else "Нет подходящих фильмов"
108
-
109
  except Exception as e:
110
  return f"Ошибка: {str(e)}"
111
 
112
- # Создание интерфейса Gradio с HTML-стилями
113
- css = """
114
- .gradio-container {background: #f0f2f6}
115
- h1 {text-align: center; color: #2d3436}
116
- """
117
 
 
118
  interface = gr.Interface(
119
  fn=recommend_movies,
120
- inputs=gr.Textbox(
121
- label="🎥 Введите запрос",
122
- placeholder="Пример: Космическая опера с эпическими битвами..."
123
- ),
124
- outputs=gr.Markdown(label="🎬 Результаты поиска"),
125
- title="🍿 AI-Кинотеатр: Персональные рекомендации",
126
- description="Система рекомендаций фильмов на основе глубокого обучения",
127
  examples=[
128
- ["Комедия с неожиданным сюжетом"],
129
- ["Драма о взрослении подростков"],
130
- ["Фантастический боевик с роботами"]
131
  ],
132
- css=css,
133
- allow_flagging="never"
134
  )
135
 
136
  if __name__ == "__main__":
137
- interface.launch()
 
1
+ from transformers import BertTokenizerFast, BertForSequenceClassification
2
+ from sentence_transformers import SentenceTransformer
3
+ from torch.nn.functional import cosine_similarity
4
  import torch
 
 
5
  from datasets import load_dataset
6
+ import pandas as pd
7
+ import gradio as gr
 
 
 
8
 
9
+ # Загрузка датасета
10
+ def load_data():
11
+ dataset = load_dataset("wykonos/movies", split='train[:400]')
12
+ df = pd.DataFrame(dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Обработка жанров
15
+ df['genres'] = df['genres'].apply(
16
+ lambda x: x.split('-') if isinstance(x, str) else []
17
+ )
18
+
19
+ # Исправление опечаток
20
+ genre_corrections = {
21
+ 'hystery': 'mystery',
22
+ 'adventura': 'adventure',
23
+ 'nur': 'action'
24
+ }
25
+
26
+ df['genres'] = df['genres'].apply(
27
+ lambda x: [genre_corrections.get(g.lower(), g).capitalize() for g in x]
28
+ )
29
+
30
+ return df
31
 
32
+ df = load_data()
 
33
 
34
+ # Загрузка модели
35
  model_name = "AventIQ-AI/bert-movie-recommendation-system"
36
+ tokenizer = BertTokenizerFast.from_pretrained(model_name)
37
+ model = BertForSequenceClassification.from_pretrained(model_name)
38
+
39
+ # Сопоставление меток модели с жанрами
40
+ genre_labels = [
41
+ "Action", "Adventure", "Animation", "Comedy", "Crime",
42
+ "Documentary", "Drama", "Family", "Fantasy", "History",
43
+ "Horror", "Music", "Mystery", "Romance", "Science Fiction",
44
+ "TV Movie", "Thriller", "War", "Western"
45
+ ]
46
+
47
+ model.config.id2label = {i: label for i, label in enumerate(genre_labels)}
48
+
49
+ # Функция предсказания
50
+ def predict_genres(text, threshold=0.3):
51
  inputs = tokenizer(
52
+ text,
53
+ max_length=128,
54
+ padding=True,
55
+ truncation=True,
56
  return_tensors="pt"
57
+ )
58
 
59
  with torch.no_grad():
60
  outputs = model(**inputs)
61
 
62
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
63
+ predicted_indices = torch.where(probs > threshold)[1].tolist()
64
 
65
+ return list(set([model.config.id2label[i] for i in predicted_indices]))
66
+ # Модель для текстовых эмбеддингов
67
+ embedding_model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
68
+
69
+ # Кодируем описания фильмов
70
+ df["overview"] = df["overview"].fillna("") # если есть пропуски
71
+ vectors = embedding_model.encode(
72
+ df["overview"].tolist(),
73
+ convert_to_tensor=True,
74
+ show_progress_bar=True
75
+ )
76
+
77
+ df["overview_vector"] = list(vectors)
78
 
79
+
80
+ def recommend_movies(query, top_k=5):
81
  try:
82
+ # 1. Предсказание жанров
83
  genres = predict_genres(query)
 
 
84
 
85
+ # 2. Кодируем сам запрос
86
+ query_vector = embedding_model.encode([query], convert_to_tensor=True)[0]
87
+
88
+ # 3. Считаем косинусное сходство с описаниями фильмов
89
+ similarities = cosine_similarity(
90
+ query_vector.unsqueeze(0),
91
+ torch.stack(df["overview_vector"].tolist())
92
+ ).squeeze(0)
93
+
94
+ df["similarity"] = similarities.cpu().numpy()
95
+
96
+ # 4. Объединяем семантику + жанры (опционально)
97
+ # Можно фильтровать по жанрам или просто повысить вес фильмам с совпадающими жанрами
98
+ def genre_score(row):
99
+ return any(g in row["genres"] for g in genres)
100
 
101
+ df["genre_boost"] = df.apply(genre_score, axis=1).astype(int)
102
+ df["total_score"] = df["similarity"] + df["genre_boost"] * 0.15 # вес жанра
103
+
104
+ # 5. Выбираем top_k
105
+ results = df.sort_values(
106
+ by="total_score", ascending=False
107
+ ).head(top_k)
108
+
109
+ # 6. Форматируем ответ
110
  output = []
111
  for _, row in results.iterrows():
112
+ info = [
113
+ f"🎬 {row['title']}",
114
  f"⭐ Рейтинг: {row['vote_average']}",
115
+ f"🎭 Жанры: {', '.join(row['genres'])}",
116
+ f"📅 Год: {row['release_date'][:4] if pd.notna(row['release_date']) else 'N/A'}",
117
+ f"📖 {row['overview'][:300]}..."
118
  ]
119
+ output.append("\n".join(info))
 
 
 
 
 
 
 
120
 
121
+ return "\n\n" + "\n\n".join(output) if output else "Нет результатов"
122
+
123
  except Exception as e:
124
  return f"Ошибка: {str(e)}"
125
 
 
 
 
 
 
126
 
127
+ # Создание интерфейса
128
  interface = gr.Interface(
129
  fn=recommend_movies,
130
+ inputs=gr.Textbox(label="Опишите желаемый фильм"),
131
+ outputs=gr.Textbox(label="Рекомендации"),
 
 
 
 
 
132
  examples=[
133
+ ["Страшный фильм с привидениями"],
134
+ ["Веселая комедия про студентов"],
135
+ ["Фантастика с космическими битвами"]
136
  ],
137
+ title="🍿 AI Киносоветник"
 
138
  )
139
 
140
  if __name__ == "__main__":
141
+ interface.launch(server_name="localhost", server_port=7860)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  sentence-transformers
2
  faiss-cpu
3
  gradio
4
- pandas
 
 
1
  sentence-transformers
2
  faiss-cpu
3
  gradio
4
+ pandas
5
+ datasets