data-mining-tp / src /preprocess.py
Kacemath's picture
Deploy gradio movie revenue app with model and preprocessing
b490ee7
from __future__ import annotations

import math
import pickle
from datetime import date, datetime
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
# Integer codes for the ``original_language`` feature; any language not listed
# here is encoded as 0 by ``build_feature_row``.
LANGUAGE_MAPPING = dict(en=1, zh=2, ja=3)

# One-hot column prefix (as it appears in the model's feature names) -> key of
# the matching multi-select value in the submitted form data.
PREFIX_TO_FORM_KEY = {
    "genres": "genres",
    "production_companies": "production_companies",
    "Keywords": "keywords",  # feature-name prefix is capitalized, form key is lowercase
    "cast": "cast",
}
def load_model(model_path: str | Path) -> Any:
    """Load and return a pickled model object from *model_path*.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file, so
    only point this at model files you trust.
    """
    path = Path(model_path)
    with path.open("rb") as handle:
        model = pickle.load(handle)
    return model
def get_model_feature_names(model: Any) -> list[str]:
    """Return the model's ``feature_names_in_`` attribute as a plain list.

    Presumably *model* is a scikit-learn style estimator fitted on a DataFrame
    (which is what exposes ``feature_names_in_``) — confirm against the caller.

    Raises:
        ValueError: If the model has no ``feature_names_in_`` attribute.
    """
    try:
        names = model.feature_names_in_
    except AttributeError:
        raise ValueError("Model does not expose feature_names_in_.") from None
    return list(names)
def count_words(text: str | None) -> int:
    """Return the number of whitespace-separated tokens in *text*.

    ``None`` counts as zero words; non-string values are stringified first.
    """
    # str.split() with no argument already ignores leading/trailing whitespace
    # and yields [] for blank strings, so no separate strip/empty check is needed.
    return 0 if text is None else len(str(text).split())
def runtime_category_code(runtime: float) -> int:
    """Bucket *runtime* into 0 (< 90), 1 (90 up to < 120) or 2 (>= 120).

    Presumably the runtime is in minutes — confirm. Values that compare false
    against both bounds (e.g. NaN) fall into the last bucket.
    """
    for code, upper_bound in enumerate((90, 120)):
        if runtime < upper_bound:
            return code
    return 2
def parse_release_date(value: str | date | None) -> datetime:
    """Parse a release date coming from the form.

    Args:
        value: A ``YYYY-MM-DD`` string (surrounding whitespace is ignored) or a
            ``date``/``datetime`` instance. Missing, empty or whitespace-only
            values fall back to 2010-01-01.

    Returns:
        A ``datetime``: ``datetime`` inputs are returned unchanged, ``date``
        inputs become midnight of that day, strings are parsed.

    Raises:
        ValueError: If the value cannot be parsed as ``YYYY-MM-DD``.
    """
    if not value:
        return datetime(2010, 1, 1)
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(value, datetime):
        return value
    if isinstance(value, date):
        return datetime(value.year, value.month, value.day)
    # Coerce other types to text so they fail with the documented ValueError
    # rather than a TypeError from strptime.
    text = str(value).strip()
    if not text:
        return datetime(2010, 1, 1)
    try:
        return datetime.strptime(text, "%Y-%m-%d")
    except ValueError as exc:
        raise ValueError("release_date must be in YYYY-MM-DD format.") from exc
def parse_feature_options(feature_names: list[str]) -> dict[str, list[str]]:
    """Collect the selectable labels for each one-hot prefix.

    Every feature named ``"{prefix}_{label}"`` (for each prefix in
    ``PREFIX_TO_FORM_KEY``) contributes ``label`` to that prefix's list; the
    ``"{prefix}_other"`` bucket column is skipped. Each list is de-duplicated
    and sorted, and every prefix is present even if it has no labels.
    """
    buckets: dict[str, set[str]] = {prefix: set() for prefix in PREFIX_TO_FORM_KEY}
    for feature in feature_names:
        for prefix, labels in buckets.items():
            head = prefix + "_"
            if not feature.startswith(head) or feature == head + "other":
                continue
            labels.add(feature[len(head):])
    return {prefix: sorted(labels) for prefix, labels in buckets.items()}
def _to_float(value: Any, default: float = 0.0) -> float:
try:
if value is None:
return default
return float(value)
except (TypeError, ValueError):
return default
def _to_int(value: Any, default: int = 0) -> int:
try:
if value is None:
return default
return int(value)
except (TypeError, ValueError):
return default
def _non_negative_float(value: Any) -> float:
    """Convert *value* to a finite float clamped at zero; unusable input gives 0.0."""
    number = _to_float(value, 0.0)
    # max(nan, 0.0) is nan, so non-finite values must be rejected explicitly;
    # otherwise derived features (runtime_category, *_log) would be computed
    # from NaN/inf while the base column is zeroed by the final clean-up.
    if not math.isfinite(number) or number < 0.0:
        return 0.0
    return number


def _selected_columns(prefix: str, raw: Any) -> list[str]:
    """Return the distinct ``{prefix}_{item}`` column names for a multi-select value."""
    if not raw:
        return []
    # Any common collection is a multi-select; a scalar is a single choice.
    items = raw if isinstance(raw, (list, tuple, set, frozenset)) else [raw]
    columns: list[str] = []
    for item in items:
        if item is None or (isinstance(item, str) and not item):
            continue  # empty choices carry no information
        column = f"{prefix}_{item}"
        if column not in columns:  # a repeated choice counts once
            columns.append(column)
    return columns


def build_feature_row(form_data: dict[str, Any], feature_names: list[str]) -> pd.DataFrame:
    """Build a one-row feature frame whose columns match the model's features.

    Args:
        form_data: Raw form values keyed by field name (``budget``, ``runtime``,
            ``release_date``, ``genres``, ...). Missing or unparsable numeric
            values fall back to defaults; multi-select fields may be a list,
            tuple, set or a single value.
        feature_names: The model's expected columns, in order. Any column not
            derived from *form_data* stays at 0.0, and derived values whose
            column is absent are dropped.

    Returns:
        A single-row ``DataFrame`` with exactly ``feature_names`` as columns,
        with +/-inf and NaN replaced by 0.
    """
    row = dict.fromkeys(feature_names, 0.0)

    budget = _non_negative_float(form_data.get("budget"))
    popularity = _non_negative_float(form_data.get("popularity"))
    runtime = _non_negative_float(form_data.get("runtime"))
    release_date = parse_release_date(form_data.get("release_date"))
    # Dec/Jan/Feb -> 1, Mar-May -> 2, Jun-Aug -> 3, Sep-Nov -> 4.
    release_season = ((release_date.month % 12) + 3) // 3

    title_text = str(form_data.get("title") or "")
    tagline_text = str(form_data.get("tagline") or "")
    overview_text = str(form_data.get("overview") or "")
    language = str(form_data.get("original_language") or "").strip().lower()

    values = {
        "belongs_to_collection": _to_int(form_data.get("belongs_to_collection"), 0),
        "homepage": _to_int(form_data.get("homepage"), 0),
        # Without an explicit flag, infer it from whether a tagline was typed.
        "has_tagline": _to_int(form_data.get("has_tagline"), 1 if tagline_text.strip() else 0),
        "original_language": LANGUAGE_MAPPING.get(language, 0),
        "runtime": runtime,
        "num_of_cast": _to_float(form_data.get("num_of_cast"), 0.0),
        "num_of_crew": _to_float(form_data.get("num_of_crew"), 0.0),
        "gender_cast_1": _to_float(form_data.get("gender_cast_1"), 0.0),
        "gender_cast_2": _to_float(form_data.get("gender_cast_2"), 0.0),
        "count_cast_other": _to_float(form_data.get("count_cast_other"), 0.0),
        "title_word_count": _to_float(form_data.get("title_word_count"), count_words(title_text)),
        "tag_word_count": _to_float(form_data.get("tag_word_count"), count_words(tagline_text)),
        "overview_word_count": _to_float(form_data.get("overview_word_count"), count_words(overview_text)),
        "release_year": release_date.year,
        "release_month": release_date.month,
        "release_season": release_season,
        "runtime_category": runtime_category_code(runtime),
        "budget_log": math.log1p(budget),
        "popularity_log": math.log1p(popularity),
    }
    for key, value in values.items():
        if key in row:
            row[key] = value

    for prefix, form_key in PREFIX_TO_FORM_KEY.items():
        columns = _selected_columns(prefix, form_data.get(form_key))
        other_col = f"{prefix}_other"
        has_unknown = False
        for column in columns:
            # "{prefix}_other" is the bucket flag, not a real choice: selecting
            # it, or anything the model has no column for, marks the bucket.
            if column in row and column != other_col:
                row[column] = 1.0
            else:
                has_unknown = True
        num_col = f"num_of_{prefix}"
        if num_col in row:
            # An explicit count in the form (e.g. num_of_cast) takes precedence;
            # otherwise use the number of distinct selections.
            row[num_col] = _to_float(form_data.get(num_col), float(len(columns)))
        if other_col in row:
            row[other_col] = 1.0 if has_unknown else 0.0

    df = pd.DataFrame([[row[name] for name in feature_names]], columns=feature_names)
    return df.replace([np.inf, -np.inf], 0).fillna(0)
def predict_revenue(model: Any, form_data: dict[str, Any]) -> float:
    """Predict the revenue of a single movie described by *form_data*.

    The form is turned into a feature row ordered like the model's
    ``feature_names_in_`` and passed to ``model.predict``. The first prediction
    is returned as a plain ``float``.

    NOTE(review): the value is in whatever scale the model was trained on
    (possibly log-revenue); confirm before presenting it to users.
    """
    expected_columns = get_model_feature_names(model)
    predictions = model.predict(build_feature_row(form_data, expected_columns))
    return float(predictions[0])