"""Turn form input into a model feature row and score it with a pickled model."""

from __future__ import annotations

import math
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

# Codes for the ``original_language`` feature; any other language maps to 0.
LANGUAGE_MAPPING = {"en": 1, "zh": 2, "ja": 3}

# One-hot column prefix (as it appears in the model's feature names) -> the
# form_data key carrying the user's selections for it. Note that "Keywords" is
# capitalised on the feature side but not on the form side.
PREFIX_TO_FORM_KEY = {
    "genres": "genres",
    "production_companies": "production_companies",
    "Keywords": "keywords",
    "cast": "cast",
}

# Container types treated as a multi-value selection; anything else is one item.
_SELECTION_TYPES = (list, tuple, set, frozenset)

# Release date used when form_data does not provide one.
_DEFAULT_RELEASE_DATE = datetime(2010, 1, 1)


def load_model(model_path: str | Path) -> Any:
    """Unpickle and return the model stored at *model_path*.

    SECURITY: ``pickle.load`` can execute arbitrary code while loading. Only
    call this on model files from a trusted source, never on user uploads.
    """
    with Path(model_path).open("rb") as file:
        return pickle.load(file)


def get_model_feature_names(model: Any) -> list[str]:
    """Return the column names exposed by ``model.feature_names_in_``, in order.

    Raises:
        ValueError: if the model has no ``feature_names_in_`` attribute.
    """
    if not hasattr(model, "feature_names_in_"):
        raise ValueError("Model does not expose feature_names_in_.")
    return list(model.feature_names_in_)


def count_words(text: str | None) -> int:
    """Return the number of whitespace-separated words in *text* (0 for None/blank)."""
    if text is None:
        return 0
    normalized = str(text).strip()
    if not normalized:
        return 0
    return len(normalized.split())


def runtime_category_code(runtime: float) -> int:
    """Bucket a runtime (presumably minutes): <90 -> 0, 90-119.x -> 1, >=120 -> 2."""
    if runtime < 90:
        return 0
    if runtime < 120:
        return 1
    return 2


def parse_release_date(value: str | None) -> datetime:
    """Parse a ``YYYY-MM-DD`` date, falling back to 2010-01-01 when it is missing.

    Surrounding whitespace is ignored, and a blank string counts as missing.

    Raises:
        ValueError: if a non-blank value is not in ``YYYY-MM-DD`` format.
    """
    text = str(value).strip() if value else ""
    if not text:
        return _DEFAULT_RELEASE_DATE
    try:
        return datetime.strptime(text, "%Y-%m-%d")
    except ValueError as exc:
        raise ValueError("release_date must be in YYYY-MM-DD format.") from exc


def parse_feature_options(feature_names: list[str]) -> dict[str, list[str]]:
    """List the selectable values the model has a one-hot column for, per prefix.

    For each prefix in ``PREFIX_TO_FORM_KEY``, collect the suffixes of feature
    names shaped ``<prefix>_<value>``, excluding the ``<prefix>_other`` bucket,
    and sort them alphabetically. The result is keyed by prefix (e.g.
    ``"Keywords"``), not by form key.
    """
    options: dict[str, set[str]] = {prefix: set() for prefix in PREFIX_TO_FORM_KEY}
    for name in feature_names:
        for prefix, values in options.items():
            head = f"{prefix}_"
            if name.startswith(head) and name != f"{prefix}_other":
                values.add(name[len(head):])
    return {prefix: sorted(values) for prefix, values in options.items()}


def _to_float(value: Any, default: float = 0.0) -> float:
    """Coerce *value* to a finite float, returning *default* if that is impossible.

    None, unparsable values, overflowing numbers, and NaN/inf (including strings
    like "nan" or "inf") all yield *default*. This keeps them from producing
    inconsistent derived features such as ``runtime_category``.
    """
    if value is None:
        return default
    try:
        result = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return result if math.isfinite(result) else default


def _to_int(value: Any, default: int = 0) -> int:
    """Coerce *value* to int, returning *default* for None or unconvertible values."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) must not crash feature building.
        return default


def _as_selection(value: Any) -> list[Any]:
    """Normalise a form selection: falsy -> [], list/tuple/set -> list, scalar -> [scalar]."""
    if not value:
        return []
    if isinstance(value, _SELECTION_TYPES):
        return list(value)
    return [value]


def _scalar_features(form_data: dict[str, Any]) -> dict[str, Any]:
    """Compute every non-one-hot feature value that can be derived from *form_data*."""
    budget = max(_to_float(form_data.get("budget")), 0.0)
    popularity = max(_to_float(form_data.get("popularity")), 0.0)
    runtime = max(_to_float(form_data.get("runtime")), 0.0)
    release_date = parse_release_date(form_data.get("release_date"))
    # Dec-Feb -> 1, Mar-May -> 2, Jun-Aug -> 3, Sep-Nov -> 4.
    release_season = (release_date.month % 12 + 3) // 3

    title_text = str(form_data.get("title") or "")
    tagline_text = str(form_data.get("tagline") or "")
    overview_text = str(form_data.get("overview") or "")

    return {
        "belongs_to_collection": _to_int(form_data.get("belongs_to_collection"), 0),
        "homepage": _to_int(form_data.get("homepage"), 0),
        "has_tagline": _to_int(form_data.get("has_tagline"), 1 if tagline_text.strip() else 0),
        "original_language": LANGUAGE_MAPPING.get(
            str(form_data.get("original_language") or "").lower(), 0
        ),
        "runtime": runtime,
        # num_of_cast is not set here: it is a num_of_<prefix> column and is
        # filled by _apply_selections (explicit value first, else selection count).
        "num_of_crew": _to_float(form_data.get("num_of_crew"), 0.0),
        "gender_cast_1": _to_float(form_data.get("gender_cast_1"), 0.0),
        "gender_cast_2": _to_float(form_data.get("gender_cast_2"), 0.0),
        "count_cast_other": _to_float(form_data.get("count_cast_other"), 0.0),
        # Word counts may be supplied directly; otherwise they are computed from the text.
        "title_word_count": _to_float(form_data.get("title_word_count"), count_words(title_text)),
        "tag_word_count": _to_float(form_data.get("tag_word_count"), count_words(tagline_text)),
        "overview_word_count": _to_float(
            form_data.get("overview_word_count"), count_words(overview_text)
        ),
        "release_year": release_date.year,
        "release_month": release_date.month,
        "release_season": release_season,
        "runtime_category": runtime_category_code(runtime),
        "budget_log": math.log1p(budget),
        "popularity_log": math.log1p(popularity),
    }


def _apply_selections(row: dict[str, Any], form_data: dict[str, Any]) -> None:
    """Fill the one-hot, ``num_of_<prefix>`` and ``<prefix>_other`` columns in *row* in place."""
    for prefix, form_key in PREFIX_TO_FORM_KEY.items():
        selected = _as_selection(form_data.get(form_key))
        other_col = f"{prefix}_other"

        known = 0
        for item in selected:
            col = f"{prefix}_{item}"
            # A literal "other" selection must not be counted as a known value.
            if col != other_col and col in row:
                row[col] = 1.0
                known += 1

        num_col = f"num_of_{prefix}"
        if num_col in row:
            # An explicit count in form_data (e.g. the total cast size) takes
            # precedence. The selection count only covers values the form could
            # offer, so it is used only when no parseable count was given.
            row[num_col] = _to_float(form_data.get(num_col), float(len(selected)))

        if other_col in row:
            # Flag selections that have no dedicated one-hot column.
            row[other_col] = 1.0 if len(selected) > known else 0.0


def build_feature_row(form_data: dict[str, Any], feature_names: list[str]) -> pd.DataFrame:
    """Build a one-row DataFrame whose columns are exactly *feature_names*, in order.

    Columns that cannot be derived from *form_data* stay 0.0. Numeric inputs
    that are missing, unparsable or non-finite fall back to defaults. Budget,
    popularity and runtime are clamped at 0, and budget/popularity are stored
    as ``log1p`` values.

    Args:
        form_data: Raw form values keyed by field name; numbers may be strings.
        feature_names: The columns the model expects, e.g. from
            ``get_model_feature_names``.

    Raises:
        ValueError: if ``release_date`` is present but not ``YYYY-MM-DD``.
    """
    row: dict[str, Any] = {name: 0.0 for name in feature_names}
    for key, value in _scalar_features(form_data).items():
        if key in row:
            row[key] = value
    _apply_selections(row, form_data)

    frame = pd.DataFrame([[row[name] for name in feature_names]], columns=feature_names)
    # Defensive: no inf/NaN should reach the model even if a future feature produces one.
    return frame.replace([np.inf, -np.inf], 0).fillna(0)


def predict_revenue(model: Any, form_data: dict[str, Any]) -> float:
    """Build the feature row *model* expects from *form_data* and return its prediction.

    The value is returned on whatever scale the model was trained to predict.
    No inverse transform is applied here.
    """
    feature_names = get_model_feature_names(model)
    frame = build_feature_row(form_data, feature_names)
    pred = model.predict(frame)[0]
    return float(pred)