|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import time |
|
|
import requests |
|
|
import traceback |
|
|
from typing import Dict, Any, List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Root URL of the TMDb v3 REST API; endpoint paths are appended to it.
TMDB_BASE = "https://api.themoviedb.org/3"

# URL prefix that a title's poster_path is appended to when building image links.
TMDB_IMG_BASE = "https://image.tmdb.org/t/p/w500"

# Watch region preselected in the UI's region dropdown.
DEFAULT_REGION = "KR"
|
|
|
|
|
|
|
|
def _load_models():
    """Construct the models used by the app.

    Returns:
        A ``(sentiment, generator, embedder)`` tuple: a sentiment-analysis
        pipeline, a text2text-generation pipeline built on flan-t5-small, and
        a SentenceTransformer. ``embedder`` is None when constructing the
        SentenceTransformer raises.
    """
    seq2seq_name = "google/flan-t5-small"

    sentiment = pipeline(
        "sentiment-analysis",
        model="nlptown/bert-base-multilingual-uncased-sentiment",
        device=-1,
    )

    # Tokenizer first, then weights, then the pipeline that wraps both.
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(seq2seq_name)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(seq2seq_name)
    generator = pipeline(
        "text2text-generation",
        model=seq2seq_model,
        tokenizer=seq2seq_tokenizer,
        device=-1,
    )

    try:
        embedder = SentenceTransformer(
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        )
    except Exception:
        # Deliberate best effort: callers treat a missing embedder as
        # "semantic ranking unavailable" instead of failing at import time.
        embedder = None

    return sentiment, generator, embedder
|
|
|
|
|
# Models are loaded once, at import time; _emb is None when the embedder could not be built.
_sent, _summer, _emb = _load_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tmdb_get(api_key: str, path: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """GET a TMDb endpoint and return its decoded JSON body, with simple retry/backoff.

    Args:
        api_key: TMDb API key, sent as the ``api_key`` query parameter.
        path: Endpoint path appended to ``TMDB_BASE`` (e.g. "/discover/movie").
        params: Additional query parameters, merged after ``api_key``.

    Returns:
        The parsed JSON of the first attempt that answers with HTTP 200.

    Raises:
        RuntimeError: When no attempt succeeds. The message holds the last
            status code and body excerpt, or the last exception text, with
            the API key masked.
    """
    url = f"{TMDB_BASE}{path}"
    query = {"api_key": api_key, **params}
    max_attempts = 3
    last_err = None
    for attempt in range(max_attempts):
        try:
            resp = requests.get(url, params=query, timeout=25)
            if resp.status_code == 200:
                return resp.json()
            last_err = f"{resp.status_code} {resp.text[:200]}"
            # A client error (bad key, unknown id, bad parameter) gives the
            # same answer on every retry; only 408/429 are worth waiting for.
            if 400 <= resp.status_code < 500 and resp.status_code not in (408, 429):
                break
        except Exception as e:
            last_err = str(e)
        # Back off between attempts only; sleeping after the last one just
        # delays the error.
        if attempt < max_attempts - 1:
            time.sleep(0.7 * (attempt + 1))
    # Connection errors from requests embed the full URL, query string
    # included, so mask the key before the text can reach a user.
    if api_key:
        last_err = str(last_err).replace(api_key, "***")
    raise RuntimeError(f"TMDb request failed: {last_err}")
|
|
|
|
|
def get_provider_id(api_key: str, region: str, provider_name="Netflix") -> int:
    """Return the TMDb provider id whose name equals ``provider_name``.

    The movie watch-provider list for ``region`` is fetched and names are
    compared case-insensitively. When nothing matches, 8 is returned.
    """
    listing = tmdb_get(api_key, "/watch/providers/movie", {"watch_region": region})
    matching_ids = (
        int(entry["provider_id"])
        for entry in listing.get("results", [])
        if str(entry.get("provider_name", "")).lower() == provider_name.lower()
    )
    # First match wins; 8 is the hard-coded fallback id.
    return next(matching_ids, 8)
|
|
|
|
|
def discover_quick(api_key: str, region: str, nfx_id: int, ctype="movie",
                   sort_by="popularity.desc", page_limit=2) -> List[Dict[str, Any]]:
    """List titles through TMDb Discover, filtered by a watch provider.

    Args:
        api_key: TMDb API key.
        region: Watch region code sent as ``watch_region``.
        nfx_id: Provider id sent as ``with_watch_providers``.
        ctype: "movie" or "tv"; selects the ``/discover/{ctype}`` endpoint.
        sort_by: Discover sort key. A ``release_date.*`` key is translated to
            ``first_air_date.*`` for TV, which has no release_date sort.
        page_limit: Maximum number of result pages to request.

    Returns:
        The result dicts of all fetched pages, each extended with a "type"
        key holding ``ctype``.

    Raises:
        RuntimeError: Propagated from ``tmdb_get`` when a request fails.
    """
    if ctype == "tv" and sort_by.startswith("release_date."):
        sort_by = "first_air_date." + sort_by.split(".", 1)[1]

    params = {
        "watch_region": region,
        "with_watch_providers": nfx_id,
        "sort_by": sort_by,
        "include_adult": False,
        "language": "ko-KR",
    }
    rows = []
    for page in range(1, page_limit + 1):
        data = tmdb_get(api_key, f"/discover/{ctype}", {**params, "page": page})
        results = data.get("results", [])
        if not results:
            # An empty page means there is nothing further to fetch.
            break
        rows.extend({"type": ctype, **r} for r in results)
    return rows
|
|
|
|
|
def has_netflix_offer(api_key: str, content_type: str, tmdb_id: int, region: str, nfx_id: int) -> bool:
    """Tell whether a title lists provider ``nfx_id`` in ``region``.

    Fetches the title's watch providers and looks for the id among the
    region's "flatrate", "ads" and "free" offers.
    """
    payload = tmdb_get(api_key, f"/{content_type}/{tmdb_id}/watch/providers", {})
    regional = payload.get("results", {}).get(region, {})
    offers = regional.get("flatrate", []) + regional.get("ads", []) + regional.get("free", [])
    for offer in offers:
        if int(offer.get("provider_id", -1)) == nfx_id:
            return True
    return False
|
|
|
|
|
def search_and_filter(api_key: str, query: str, region: str, nfx_id: int,
                      content_types=("movie","tv"), max_pages_each=2, max_total=60) -> List[Dict[str,Any]]:
    """Search TMDb by text and keep only titles offered by the given provider.

    1) Search each content type by ``query``, page by page.
    2) Check every hit with ``has_netflix_offer`` and keep the ones that pass.

    Args:
        api_key: TMDb API key.
        query: Free-text search string.
        region: Watch region used for the provider check.
        nfx_id: Provider id a title must be offered by.
        content_types: Endpoint kinds to search ("movie", "tv").
        max_pages_each: Maximum search pages requested per content type.
        max_total: Hard cap on the number of returned items.

    Returns:
        At most ``max_total`` search result dicts, each extended with a
        "type" key holding its content type.

    Raises:
        RuntimeError: Propagated from ``tmdb_get`` when a search request fails.
    """
    out = []
    if max_total <= 0:
        return out
    for ctype in content_types:
        for page in range(1, max_pages_each + 1):
            data = tmdb_get(api_key, f"/search/{ctype}", {
                "query": query, "page": page, "include_adult": False, "language": "ko-KR"
            })
            results = data.get("results", [])
            if not results:
                # Nothing on this page, so later pages of this type are empty too.
                break
            for item in results:
                tmdb_id = item.get("id")
                if tmdb_id is None:
                    continue
                try:
                    offered = has_netflix_offer(api_key, ctype, tmdb_id, region, nfx_id)
                except Exception:
                    # Deliberate best effort: one failed provider lookup
                    # skips that title instead of aborting the whole search.
                    continue
                if offered:
                    out.append({"type": ctype, **item})
                    # Stop at the cap right away so the result never
                    # overshoots and no further lookups are issued.
                    if len(out) >= max_total:
                        return out
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _embed_texts(texts: List[str]) -> np.ndarray:
    """Encode ``texts`` with the module-level embedder, normalized.

    When the embedder is unavailable or ``texts`` is empty, a float32 zero
    matrix of shape ``(len(texts), 384)`` is returned instead.
    """
    if _emb is not None and texts:
        return _emb.encode(
            texts,
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
        )
    # 384 is hard-coded here; presumably the embedder's output width - confirm.
    return np.zeros((len(texts), 384), dtype=np.float32)
|
|
|
|
|
def rank_by_query(items: List[Dict[str, Any]], query: str, topk: int = 10) -> List[Dict[str, Any]]:
    """Return up to ``topk`` items, ordered by similarity to ``query``.

    Each item is represented by "<title>. <overview>" and compared with the
    query through normalized embeddings (dot product). When the query is
    blank or no embedder is loaded, the first ``topk`` items are returned in
    their incoming order.

    Args:
        items: TMDb result dicts ("name"/"title" and "overview" are read).
        query: Free-text preference; may be empty or None.
        topk: Number of items to return. Floats are truncated to int and
            negative values are treated as 0.

    Returns:
        A new list holding at most ``topk`` of the given item dicts.
    """
    if not items:
        return []

    # UI sliders can deliver the count as a float, which a slice rejects, and
    # a negative count would silently drop items from the end.
    limit = max(int(topk), 0)

    if not query or not query.strip() or _emb is None:
        return items[:limit]

    texts = [
        f"{it.get('name') or it.get('title') or ''}. {it.get('overview') or ''}"
        for it in items
    ]
    q = _emb.encode([query], normalize_embeddings=True, convert_to_numpy=True)[0].reshape(1, -1)
    X = _emb.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
    sims = (q @ X.T)[0]
    # A stable sort keeps the incoming order for items with equal similarity.
    order = np.argsort(-sims, kind="stable")[:limit]
    return [items[i] for i in order]
|
|
|
|
|
def build_gallery(items: List[Dict[str, Any]]) -> Tuple[list, list]:
    """Turn TMDb result dicts into gallery entries and table rows.

    Returns:
        ``(gallery, rows)``. ``gallery`` holds one ``[image_url, caption]``
        pair per item; ``image_url`` is None when the item has no poster
        path. ``rows`` holds one dict per item with five display fields.
    """
    # NOTE(review): the display literals below look mis-encoded; they are
    # reproduced exactly as found.
    gallery = []
    rows = []
    for entry in items:
        name = entry.get("name") or entry.get("title") or ""
        summary = entry.get("overview") or ""
        released = entry.get("first_air_date") or entry.get("release_date") or ""
        rating = entry.get("vote_average")

        if entry.get("type") == "tv":
            kind_label = "๋๋ผ๋ง"
        else:
            kind_label = "์ํ"

        poster_path = entry.get("poster_path")
        image_url = f"{TMDB_IMG_BASE}{poster_path}" if poster_path else None

        # The caption shows at most 120 characters of the overview.
        tail = "..." if len(summary) > 120 else ""
        caption = f"{name} ({kind_label})\nํ์ : {rating} | ๊ณต๊ฐ: {released}\n{summary[:120]}{tail}"

        gallery.append([image_url, caption])
        rows.append({"์ ๋ชฉ": name, "์ ํ": kind_label, "๊ณต๊ฐ์ผ": released, "TMDbํ์ ": rating, "๊ฐ์": summary})
    return gallery, rows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Display text for each star count (1-5) parsed from the sentiment model's label.
STAR_MAP = {1:"๋งค์ฐ ๋ถ์ ", 2:"๋ถ์ ", 3:"์ค๋ฆฝ", 4:"๊ธ์ ", 5:"๋งค์ฐ ๊ธ์ "}
|
|
|
|
|
def do_recommend(api_key_ui: str, query: str, region: str, mode: str, topk: int,
                 sort_by: str, include_movie: bool, include_tv: bool):
    """Gradio handler: fetch provider-filtered titles, rank them, describe the top one.

    Args:
        api_key_ui: API key typed into the UI; when blank, the
            ``TMDB_API_KEY`` environment variable is used.
        query: Optional free-text preference used for search and ranking.
        region: Watch region code.
        mode: Value of the search-mode radio; the Discover choice uses
            ``discover_quick``, anything else uses ``search_and_filter``.
        topk: Number of titles to show (coerced to int).
        sort_by: Discover sort key.
        include_movie: Include movies.
        include_tv: Include TV shows. When both flags are off, both kinds
            are used.

    Returns:
        ``(markdown, gallery, rows)`` on success, or ``(message, None, None)``
        when the key is missing, nothing is found, or an error occurs. Error
        text has the API key masked.
    """
    # NOTE(review): the user-facing literals below look mis-encoded; they are
    # kept as found. "\x85" spells the U+0085 character that sits at that
    # position in the source text. The mode literal must stay equal to the
    # radio choice defined in the UI.
    api_key = ""
    try:
        api_key = (api_key_ui or "").strip() or os.environ.get("TMDB_API_KEY", "").strip()
        if not api_key:
            return "TMDb API Key๋ฅผ ์\x85๋ ฅํ๊ฑฐ๋ ํ๊ฒฝ๋ณ์ TMDB_API_KEY๋ฅผ ์ค์ ํ์ธ์.", None, None

        nfx_id = get_provider_id(api_key, region, "Netflix")

        types = []
        if include_movie:
            types.append("movie")
        if include_tv:
            types.append("tv")
        if not types:
            types = ["movie", "tv"]

        if mode == "๋น ๋ฅธ ์ถ์ฒ(Discover)":
            items = []
            for t in types:
                items.extend(discover_quick(api_key, region, nfx_id, ctype=t, sort_by=sort_by, page_limit=2))
        else:
            items = search_and_filter(api_key, query or "Netflix", region, nfx_id,
                                      content_types=tuple(types), max_pages_each=2, max_total=80)

        if not items:
            return f"์กฐ๊ฑด์ ๋ง๋ ๋ทํ๋ฆญ์ค({region}) ์ํ์ ์ฐพ์ง ๋ชปํ์ต๋๋ค.", None, None

        # The slider value may arrive as a float; slicing downstream needs an int.
        ranked = rank_by_query(items, query, topk=int(topk))
        gallery, rows = build_gallery(ranked)

        top = ranked[0]
        top_title = (top.get("name") or top.get("title") or "")
        pitch_prompt = (
            "Summarize in Korean (1-2 sentences):\n"
            f"์ฌ์ฉ์ ์ทจํฅ/ํค์๋: {query}\n"
            f"์ํ: {top_title} / ๊ฐ์: {top.get('overview','')}"
        )
        pitch = _summer(pitch_prompt, max_new_tokens=80, do_sample=False)[0]["generated_text"]
        md = f"### โ\x85 ์ถ์ฒ ๊ฒฐ๊ณผ (Region={region}, Provider=Netflix)\n- Top 1: **{top_title}** โ {pitch}"
        return md, gallery, rows
    except Exception as e:
        report = f"[์ค๋ฅ] {e}\n{traceback.format_exc()}"
        # This text is rendered in the browser. HTTP client errors can embed
        # the request URL with its query string, so mask the key.
        if api_key:
            report = report.replace(api_key, "***")
        return report, None, None
|
|
|
|
|
def analyze_review(title: str, review: str):
    """Gradio handler: rate a review's sentiment and produce a one-line summary.

    Args:
        title: Title of the reviewed work; only inserted into the summary prompt.
        review: Review text. Blank input yields a prompt-to-enter message.

    Returns:
        ``(head, summary)``: ``head`` holds the predicted star rating, its
        label from ``STAR_MAP`` and the model score; ``summary`` holds the
        generated one-sentence summary. On blank input or on error the
        second element is "" and the first holds the message.
    """
    # NOTE(review): the user-facing literals below look mis-encoded; they are
    # kept as found. "\x85" spells the U+0085 character that sits at that
    # position in the source text.
    try:
        if not review or not review.strip():
            return "๊ฐ์ํ์ ์\x85๋ ฅํด ์ฃผ์ธ์.", ""
        # Without truncation, a review longer than the classifier's maximum
        # input length raises inside the model instead of being rated.
        res = _sent(review, truncation=True)[0]
        # The star count is read from the first character of the label.
        stars = int(res["label"][0])
        head = f"์์ธก ๋ณ์ : {stars} ({STAR_MAP.get(stars,'์ค๋ฆฝ')}) / ํ์ ๋: {float(res['score']):.3f}"
        summ = _summer(
            f"Summarize in Korean (1 sentence):\n์ ๋ชฉ: {title}\n๊ฐ์ํ: {review}",
            max_new_tokens=60, do_sample=False
        )[0]["generated_text"]
        return head, f"ํ์คํ: {summ}"
    except Exception as e:
        return f"[์ค๋ฅ] {e}\n{traceback.format_exc()}", ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one settings accordion and two tabs (recommendation, review analysis).
# NOTE(review): the source lost its indentation, so the nesting of components
# inside Accordion/Tab/Row is reconstructed - confirm against the intended
# layout. The labels look mis-encoded and are kept as found; "\x85" spells the
# U+0085 character that sits at that position in the source text.
with gr.Blocks() as demo:
    gr.Markdown("## ๐ฟ ์ค์๊ฐ ๋ทํ๋ฆญ์ค(KR) ์ถ์ฒ & ๊ฐ์ํ โ TMDb API + ํฌ์คํฐ ์ด๋ฏธ์ง")

    # Shared settings: API key (masked input) and watch region.
    with gr.Accordion("TMDb API ์ค์ ", open=True):
        api_key = gr.Textbox(label="TMDb API Key (UI ์\x85๋ ฅ์ ์ ํ, ๊ธฐ๋ณธ์ ํ๊ฒฝ๋ณ์ TMDB_API_KEY ์ฌ์ฉ)", type="password")
        region = gr.Dropdown(choices=["KR","US","JP","GB","DE","FR","ES"], value=DEFAULT_REGION, label="์ง์ญ(Watch Region)")

    # Recommendation tab: inputs for do_recommend and its three outputs.
    with gr.Tab("์ถ์ฒ"):
        query = gr.Textbox(label="ํค์๋/๊ธฐ๋ถ(์ ํ)", placeholder="์) ๋ฐ๋ปํ ์ฑ์ฅ ๋๋ผ๋ง, ๋ฌด์์ด ํ๊ตญ ์ค๋ฆด๋ฌ", lines=2)
        with gr.Row():
            # do_recommend compares the selected mode against the first choice string.
            mode = gr.Radio(choices=["๋น ๋ฅธ ์ถ์ฒ(Discover)", "ํค์๋ ๊ฒ์(์ ํ)"], value="๋น ๋ฅธ ์ถ์ฒ(Discover)", label="๊ฒ์ ๋ชจ๋")
            sort_by = gr.Dropdown(choices=["popularity.desc","vote_average.desc","release_date.desc"], value="popularity.desc", label="์ ๋ ฌ(Discover์ฉ)")
            topk = gr.Slider(3, 20, value=9, step=1, label="ํ์ ๊ฐ์")
        with gr.Row():
            include_movie = gr.Checkbox(value=True, label="์ํ ํฌํจ")
            include_tv = gr.Checkbox(value=True, label="๋๋ผ๋ง ํฌํจ")
        btn = gr.Button("์ถ์ฒ ๋ฐ๊ธฐ")

        out_md = gr.Markdown()
        out_gallery = gr.Gallery(label="ํฌ์คํฐ ๊ฐค๋ฌ๋ฆฌ", columns=3, height="auto", allow_preview=True)
        out_table = gr.Dataframe(interactive=False, wrap=True)

        # The inputs list follows do_recommend's parameter order.
        btn.click(
            do_recommend,
            inputs=[api_key, query, region, mode, topk, sort_by, include_movie, include_tv],
            outputs=[out_md, out_gallery, out_table]
        )

    # Review-analysis tab: inputs for analyze_review and its two outputs.
    with gr.Tab("๊ฐ์ํ ๋ถ์"):
        title = gr.Textbox(label="์ ๋ชฉ(์ ํ)", placeholder="์ถ์ฒ ํญ์์ ๋ณต์ฌํด ๋ถ์ฌ๋ฃ๊ธฐ")
        review = gr.Textbox(label="๊ฐ์ํ", lines=5, placeholder="์) ์ด๋ฐ์ ๋์ด์ง์ง๋ง, ๋ฐฐ์ฐ ์ฐ๊ธฐ๊ฐ ์๊ถ์ด์์.")
        b2 = gr.Button("๋ถ์")
        head = gr.Markdown()
        summ = gr.Markdown()
        b2.click(analyze_review, inputs=[title, review], outputs=[head, summ])

# Second module-level name bound to the same Blocks object.
app = demo
|
|
|
|
|
# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link; with a server-side
    # TMDB_API_KEY set, anyone holding the link spends that key - confirm this is intended.
    demo.launch(share=True, debug=True)
|
|
|