yigitcanozdemir
commited on
Commit
·
6b33fac
1
Parent(s):
95f8049
Optimizated search
Browse files- components/gradio_ui.py +76 -6
- components/similarity.py +5 -4
- components/tmdb_api.py +54 -0
- config.py +3 -1
- models/recommendation_engine.py +14 -5
components/gradio_ui.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
|
| 4 |
def create_interface(engine):
|
| 5 |
def get_recommendations_text(query):
|
| 6 |
-
"""Wrapper function to safely get only the text result"""
|
| 7 |
try:
|
| 8 |
result = engine.get_recommendations(query)
|
| 9 |
if isinstance(result, tuple) and len(result) >= 1:
|
|
@@ -13,8 +11,68 @@ def create_interface(engine):
|
|
| 13 |
except Exception as e:
|
| 14 |
return f"❌ Error: {str(e)}"
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
with gr.Blocks(
|
| 17 |
-
theme=gr.themes.Soft(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
) as demo:
|
| 19 |
gr.Markdown("# 🎬 TV-Series and Movie Recommend")
|
| 20 |
|
|
@@ -28,19 +86,31 @@ def create_interface(engine):
|
|
| 28 |
|
| 29 |
search_btn = gr.Button("🔍 Search", variant="primary")
|
| 30 |
|
| 31 |
-
with gr.Column(scale=
|
| 32 |
results_text = gr.Textbox(
|
| 33 |
-
label="
|
| 34 |
lines=20,
|
| 35 |
max_lines=25,
|
| 36 |
show_copy_button=True,
|
| 37 |
interactive=False,
|
| 38 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
search_btn.click(
|
| 41 |
fn=get_recommendations_text,
|
| 42 |
inputs=[query_input],
|
| 43 |
outputs=[results_text],
|
| 44 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
return demo
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
|
|
|
| 3 |
def create_interface(engine):
|
| 4 |
def get_recommendations_text(query):
|
|
|
|
| 5 |
try:
|
| 6 |
result = engine.get_recommendations(query)
|
| 7 |
if isinstance(result, tuple) and len(result) >= 1:
|
|
|
|
| 11 |
except Exception as e:
|
| 12 |
return f"❌ Error: {str(e)}"
|
| 13 |
|
| 14 |
+
def get_thumbnails_html(query):
|
| 15 |
+
try:
|
| 16 |
+
result = engine.get_recommendations(query)
|
| 17 |
+
if isinstance(result, tuple) and len(result) >= 1:
|
| 18 |
+
search_results = engine.get_recommendations(query)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
thumbnails_html = []
|
| 22 |
+
thumbnails_html.append("""
|
| 23 |
+
<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 15px; padding: 20px; max-height: 600px; overflow-y: auto;">
|
| 24 |
+
""")
|
| 25 |
+
|
| 26 |
+
thumbnails_html.append("""
|
| 27 |
+
<div style="grid-column: 1 / -1; text-align: center; padding: 20px; color: #666;">
|
| 28 |
+
Thumbnails will appear here when poster URLs are available
|
| 29 |
+
</div>
|
| 30 |
+
""")
|
| 31 |
+
|
| 32 |
+
thumbnails_html.append("</div>")
|
| 33 |
+
return "".join(thumbnails_html)
|
| 34 |
+
|
| 35 |
+
except Exception as e:
|
| 36 |
+
return f"<div style='color: red; padding: 20px;'>❌ Error: {str(e)}</div>"
|
| 37 |
+
|
| 38 |
+
def get_thumbnails_from_results(query):
|
| 39 |
+
"""Get thumbnails from search results"""
|
| 40 |
+
try:
|
| 41 |
+
formatted_results, df_results = engine.get_recommendations(query)
|
| 42 |
+
|
| 43 |
+
html_parts = []
|
| 44 |
+
html_parts.append("""
|
| 45 |
+
<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(150px, 1fr)); gap: 15px; padding: 20px; max-height: 600px; overflow-y: auto; background: #f8f9fa; border-radius: 8px;">
|
| 46 |
+
""")
|
| 47 |
+
|
| 48 |
+
for i in range(10):
|
| 49 |
+
html_parts.append(f"""
|
| 50 |
+
<div style="position: relative; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); transition: transform 0.2s; cursor: pointer;"
|
| 51 |
+
onmouseover="this.style.transform='scale(1.05)'"
|
| 52 |
+
onmouseout="this.style.transform='scale(1)'">
|
| 53 |
+
<div style="width: 100%; height: 200px; background: #ddd; display: flex; align-items: center; justify-content: center; color: #666; font-size: 12px;">
|
| 54 |
+
Poster {i+1}
|
| 55 |
+
</div>
|
| 56 |
+
<div style="position: absolute; bottom: 0; left: 0; right: 0; background: linear-gradient(transparent, rgba(0,0,0,0.7)); color: white; padding: 8px; font-size: 12px; text-align: center;">
|
| 57 |
+
Movie Title {i+1}
|
| 58 |
+
</div>
|
| 59 |
+
</div>
|
| 60 |
+
""")
|
| 61 |
+
|
| 62 |
+
html_parts.append("</div>")
|
| 63 |
+
return "".join(html_parts)
|
| 64 |
+
|
| 65 |
+
except Exception as e:
|
| 66 |
+
return f"<div style='color: red; padding: 20px;'>❌ Error: {str(e)}</div>"
|
| 67 |
+
|
| 68 |
with gr.Blocks(
|
| 69 |
+
theme=gr.themes.Soft(),
|
| 70 |
+
title="TV-Series and Movie Recommend",
|
| 71 |
+
css="""
|
| 72 |
+
.gradio-container {
|
| 73 |
+
max-width: 1200px !important;
|
| 74 |
+
}
|
| 75 |
+
"""
|
| 76 |
) as demo:
|
| 77 |
gr.Markdown("# 🎬 TV-Series and Movie Recommend")
|
| 78 |
|
|
|
|
| 86 |
|
| 87 |
search_btn = gr.Button("🔍 Search", variant="primary")
|
| 88 |
|
| 89 |
+
with gr.Column(scale=1):
|
| 90 |
results_text = gr.Textbox(
|
| 91 |
+
label="Detailed Results",
|
| 92 |
lines=20,
|
| 93 |
max_lines=25,
|
| 94 |
show_copy_button=True,
|
| 95 |
interactive=False,
|
| 96 |
)
|
| 97 |
+
|
| 98 |
+
with gr.Column(scale=1):
|
| 99 |
+
thumbnails_display = gr.HTML(
|
| 100 |
+
label="Movie Posters",
|
| 101 |
+
value="<div style='text-align: center; padding: 40px; color: #666;'>Movie thumbnails will appear here</div>"
|
| 102 |
+
)
|
| 103 |
|
| 104 |
search_btn.click(
|
| 105 |
fn=get_recommendations_text,
|
| 106 |
inputs=[query_input],
|
| 107 |
outputs=[results_text],
|
| 108 |
)
|
| 109 |
+
|
| 110 |
+
search_btn.click(
|
| 111 |
+
fn=get_thumbnails_from_results,
|
| 112 |
+
inputs=[query_input],
|
| 113 |
+
outputs=[thumbnails_display],
|
| 114 |
+
)
|
| 115 |
|
| 116 |
+
return demo
|
components/similarity.py
CHANGED
|
@@ -22,7 +22,7 @@ class SimilarityCalculator:
|
|
| 22 |
}
|
| 23 |
|
| 24 |
start_time = time.time()
|
| 25 |
-
|
| 26 |
query_embedding = self.model.encode([query])
|
| 27 |
query_embedding = torch.tensor(query_embedding, dtype=torch.float32)
|
| 28 |
|
|
@@ -35,7 +35,7 @@ class SimilarityCalculator:
|
|
| 35 |
similarities = similarities[0]
|
| 36 |
|
| 37 |
hybrid_scores = self._calculate_hybrid_score(
|
| 38 |
-
similarities, filtered_data, similarity_weight=0.
|
| 39 |
)
|
| 40 |
|
| 41 |
top_indices = (
|
|
@@ -50,6 +50,7 @@ class SimilarityCalculator:
|
|
| 50 |
row = filtered_data.iloc[idx]
|
| 51 |
|
| 52 |
result = {
|
|
|
|
| 53 |
"title": row["primaryTitle"],
|
| 54 |
"type": row["titleType"],
|
| 55 |
"year": row["startYear"],
|
|
@@ -82,8 +83,8 @@ class SimilarityCalculator:
|
|
| 82 |
self,
|
| 83 |
similarities: torch.Tensor,
|
| 84 |
data: pd.DataFrame,
|
| 85 |
-
similarity_weight: float = 0.
|
| 86 |
-
rating_weight: float = 0.
|
| 87 |
) -> torch.Tensor:
|
| 88 |
|
| 89 |
sim_normalized = (similarities - similarities.min()) / (
|
|
|
|
| 22 |
}
|
| 23 |
|
| 24 |
start_time = time.time()
|
| 25 |
+
print(f"🔍 Calculating similarity for query: {query}")
|
| 26 |
query_embedding = self.model.encode([query])
|
| 27 |
query_embedding = torch.tensor(query_embedding, dtype=torch.float32)
|
| 28 |
|
|
|
|
| 35 |
similarities = similarities[0]
|
| 36 |
|
| 37 |
hybrid_scores = self._calculate_hybrid_score(
|
| 38 |
+
similarities, filtered_data, similarity_weight=0.8, rating_weight=0.2
|
| 39 |
)
|
| 40 |
|
| 41 |
top_indices = (
|
|
|
|
| 50 |
row = filtered_data.iloc[idx]
|
| 51 |
|
| 52 |
result = {
|
| 53 |
+
"tconst": row["tconst"],
|
| 54 |
"title": row["primaryTitle"],
|
| 55 |
"type": row["titleType"],
|
| 56 |
"year": row["startYear"],
|
|
|
|
| 83 |
self,
|
| 84 |
similarities: torch.Tensor,
|
| 85 |
data: pd.DataFrame,
|
| 86 |
+
similarity_weight: float = 0.8,
|
| 87 |
+
rating_weight: float = 0.2,
|
| 88 |
) -> torch.Tensor:
|
| 89 |
|
| 90 |
sim_normalized = (similarities - similarities.min()) / (
|
components/tmdb_api.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
from config import Config
|
| 3 |
+
|
| 4 |
+
class TMDBApi:
|
| 5 |
+
def __init__(self):
|
| 6 |
+
self.config = Config()
|
| 7 |
+
self.base_url = self.config.TMDB_BASE_URL
|
| 8 |
+
self.api_key = self.config.TMDB_API_KEY
|
| 9 |
+
self.image_base_url = self.config.TMDB_IMAGE_BASE_URL
|
| 10 |
+
|
| 11 |
+
def get_poster_by_imdb_id(self, imdb_id: str):
|
| 12 |
+
try:
|
| 13 |
+
if not imdb_id.startswith('tt'):
|
| 14 |
+
imdb_id = f"tt{imdb_id}"
|
| 15 |
+
|
| 16 |
+
endpoint = f"{self.base_url}/find/{imdb_id}"
|
| 17 |
+
params = {
|
| 18 |
+
"api_key": self.api_key,
|
| 19 |
+
"external_source": "tconst"
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
response = requests.get(endpoint, params=params)
|
| 23 |
+
response.raise_for_status()
|
| 24 |
+
|
| 25 |
+
data = response.json()
|
| 26 |
+
|
| 27 |
+
poster_path = None
|
| 28 |
+
if data.get("movie_results"):
|
| 29 |
+
poster_path = data["movie_results"][0].get("poster_path")
|
| 30 |
+
elif data.get("tv_results"):
|
| 31 |
+
poster_path = data["tv_results"][0].get("poster_path")
|
| 32 |
+
|
| 33 |
+
if poster_path:
|
| 34 |
+
return f"{self.image_base_url}{poster_path}"
|
| 35 |
+
return None
|
| 36 |
+
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"❌ TMDB API Error for IMDB ID {imdb_id}: {str(e)}")
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
def get_multiple_posters_by_imdb(self, items: list):
|
| 42 |
+
results = []
|
| 43 |
+
for item in items:
|
| 44 |
+
imdb_id = item.get('tconst')
|
| 45 |
+
|
| 46 |
+
if imdb_id:
|
| 47 |
+
poster_url = self.get_poster_by_imdb_id(imdb_id)
|
| 48 |
+
item['poster_url'] = poster_url
|
| 49 |
+
else:
|
| 50 |
+
item['poster_url'] = None
|
| 51 |
+
|
| 52 |
+
results.append(item)
|
| 53 |
+
|
| 54 |
+
return results
|
config.py
CHANGED
|
@@ -39,7 +39,9 @@ GENRE_LIST = Literal[
|
|
| 39 |
|
| 40 |
class Config:
|
| 41 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
|
| 44 |
DATA_FILE = "data/demo_data.parquet"
|
| 45 |
|
|
|
|
| 39 |
|
| 40 |
class Config:
|
| 41 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 42 |
+
TMDB_API_KEY = os.getenv("TMDB_API_KEY")
|
| 43 |
+
TMDB_BASE_URL = "https://api.themoviedb.org/3"
|
| 44 |
+
TMDB_IMAGE_BASE_URL = "https://image.tmdb.org/t/p/w500"
|
| 45 |
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
|
| 46 |
DATA_FILE = "data/demo_data.parquet"
|
| 47 |
|
models/recommendation_engine.py
CHANGED
|
@@ -6,7 +6,7 @@ from models.pydantic_schemas import Features
|
|
| 6 |
from components.similarity import SimilarityCalculator
|
| 7 |
from components.filters import MovieFilter
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
-
|
| 10 |
|
| 11 |
class RecommendationEngine:
|
| 12 |
def __init__(self):
|
|
@@ -19,6 +19,7 @@ class RecommendationEngine:
|
|
| 19 |
|
| 20 |
self.similarity_calc = SimilarityCalculator(self.model)
|
| 21 |
self.filter = MovieFilter()
|
|
|
|
| 22 |
|
| 23 |
print(f"✅ Recommendation engine initialized with {len(self.data)} items.")
|
| 24 |
|
|
@@ -33,8 +34,17 @@ class RecommendationEngine:
|
|
| 33 |
filtered_data = self.filter.apply_filters(self.data, features)
|
| 34 |
|
| 35 |
search_results = self.similarity_calc.calculate_similarity(
|
| 36 |
-
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
formatted_results = self._format_results(search_results)
|
| 40 |
|
|
@@ -44,14 +54,13 @@ class RecommendationEngine:
|
|
| 44 |
return f"❌ Error: {str(e)}", None
|
| 45 |
|
| 46 |
def _parse_user_query(self, query: str) -> Features:
|
| 47 |
-
"""GPT ile kullanıcı sorgusu parse et"""
|
| 48 |
try:
|
| 49 |
response = self.client.beta.chat.completions.parse(
|
| 50 |
model="gpt-4o-mini",
|
| 51 |
messages=[
|
| 52 |
{
|
| 53 |
"role": "system",
|
| 54 |
-
"content": "You are an AI that converts user requests into structured movie/TV-series features.
|
| 55 |
},
|
| 56 |
{"role": "user", "content": query},
|
| 57 |
],
|
|
@@ -127,5 +136,5 @@ class RecommendationEngine:
|
|
| 127 |
"Overview": result["overview"],
|
| 128 |
}
|
| 129 |
)
|
| 130 |
-
|
| 131 |
return pd.DataFrame(df_data)
|
|
|
|
| 6 |
from components.similarity import SimilarityCalculator
|
| 7 |
from components.filters import MovieFilter
|
| 8 |
from sentence_transformers import SentenceTransformer
|
| 9 |
+
from components.tmdb_api import TMDBApi
|
| 10 |
|
| 11 |
class RecommendationEngine:
|
| 12 |
def __init__(self):
|
|
|
|
| 19 |
|
| 20 |
self.similarity_calc = SimilarityCalculator(self.model)
|
| 21 |
self.filter = MovieFilter()
|
| 22 |
+
self.tmdb_api = TMDBApi()
|
| 23 |
|
| 24 |
print(f"✅ Recommendation engine initialized with {len(self.data)} items.")
|
| 25 |
|
|
|
|
| 34 |
filtered_data = self.filter.apply_filters(self.data, features)
|
| 35 |
|
| 36 |
search_results = self.similarity_calc.calculate_similarity(
|
| 37 |
+
features.themes, filtered_data, top_k
|
| 38 |
)
|
| 39 |
+
if search_results["results"]:
|
| 40 |
+
print(f"🔍 First result keys: {search_results['results'][0].keys()}")
|
| 41 |
+
|
| 42 |
+
for i, result in enumerate(search_results["results"]):
|
| 43 |
+
print(f"🔍 Result {i}: tconst = {result.get('tconst', 'NOT FOUND')}")
|
| 44 |
+
|
| 45 |
+
search_results["results"] = self.tmdb_api.get_multiple_posters_by_imdb(
|
| 46 |
+
search_results["results"]
|
| 47 |
+
)
|
| 48 |
|
| 49 |
formatted_results = self._format_results(search_results)
|
| 50 |
|
|
|
|
| 54 |
return f"❌ Error: {str(e)}", None
|
| 55 |
|
| 56 |
def _parse_user_query(self, query: str) -> Features:
|
|
|
|
| 57 |
try:
|
| 58 |
response = self.client.beta.chat.completions.parse(
|
| 59 |
model="gpt-4o-mini",
|
| 60 |
messages=[
|
| 61 |
{
|
| 62 |
"role": "system",
|
| 63 |
+
"content": "You are an AI that converts user requests into structured movie/TV-series features. ONLY extract genres that are explicitly mentioned by the user. Do not infer or add additional genres unless clearly stated.",
|
| 64 |
},
|
| 65 |
{"role": "user", "content": query},
|
| 66 |
],
|
|
|
|
| 136 |
"Overview": result["overview"],
|
| 137 |
}
|
| 138 |
)
|
| 139 |
+
print(df_data)
|
| 140 |
return pd.DataFrame(df_data)
|