Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import math | |
| import pickle | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import pandas as pd | |
# Integer codes used for the "original_language" feature. build_feature_row
# falls back to 0 for any language code that is not listed here.
LANGUAGE_MAPPING = {
    "en": 1,
    "zh": 2,
    "ja": 3,
}

# Feature-column prefix -> key under which the matching selections arrive in
# the submitted form data.
# NOTE(review): "Keywords" is capitalised, presumably to match the column
# names the model was trained with -- confirm before normalising it.
PREFIX_TO_FORM_KEY = {
    "genres": "genres",
    "production_companies": "production_companies",
    "Keywords": "keywords",
    "cast": "cast",
}
def load_model(model_path: str | Path) -> Any:
    """Unpickle and return the object stored at *model_path*.

    Args:
        model_path: Location of the pickled model, as a string or ``Path``.

    Returns:
        Whatever object the pickle file contains.

    NOTE: unpickling can execute arbitrary code, so this must only be
    pointed at model files from a trusted source.
    """
    source = Path(model_path)
    with source.open("rb") as handle:
        model = pickle.load(handle)
    return model
def get_model_feature_names(model: Any) -> list[str]:
    """Return the model's ``feature_names_in_`` as a plain list.

    Args:
        model: Fitted estimator expected to carry ``feature_names_in_``.

    Returns:
        The feature names in the order the model stores them.

    Raises:
        ValueError: If the model has no ``feature_names_in_`` attribute.
    """
    # A private sentinel distinguishes "attribute missing" from any real value.
    missing = object()
    names = getattr(model, "feature_names_in_", missing)
    if names is missing:
        raise ValueError("Model does not expose feature_names_in_.")
    return list(names)
def count_words(text: str | None) -> int:
    """Count the whitespace-separated words in *text*.

    ``None`` and blank or whitespace-only input count as zero words.
    Non-string input is converted with ``str()`` first.
    """
    # split() with no separator ignores leading/trailing whitespace and
    # yields an empty list for blank input, so no explicit strip is needed.
    return 0 if text is None else len(str(text).split())
def runtime_category_code(runtime: float) -> int:
    """Bucket a runtime into a category code.

    Returns:
        0 when ``runtime`` is below 90, 1 when it is below 120, else 2.
    """
    short_limit = 90
    medium_limit = 120

    if runtime < short_limit:
        code = 0
    elif runtime < medium_limit:
        code = 1
    else:
        code = 2
    return code
def parse_release_date(value: str | None) -> datetime:
    """Parse a ``YYYY-MM-DD`` release date.

    Args:
        value: Date string, or a falsy value (``None`` / ``""``) when the
            date was not supplied.

    Returns:
        The parsed date, or 2010-01-01 when *value* is falsy.

    Raises:
        ValueError: If *value* is non-empty but not in ``YYYY-MM-DD`` format.
    """
    if value:
        try:
            parsed = datetime.strptime(value, "%Y-%m-%d")
        except ValueError as exc:
            raise ValueError("release_date must be in YYYY-MM-DD format.") from exc
        return parsed
    # Fallback used when no release date was provided.
    return datetime(2010, 1, 1)
def parse_feature_options(feature_names: list[str]) -> dict[str, list[str]]:
    """Collect the selectable option names for each one-hot feature prefix.

    For every prefix in ``PREFIX_TO_FORM_KEY``, gathers the part after
    ``"<prefix>_"`` from each matching feature name. The ``"<prefix>_other"``
    column is skipped.

    Args:
        feature_names: Column names the model expects.

    Returns:
        Mapping of prefix to its alphabetically sorted, de-duplicated options.
    """

    def _suffixes(prefix: str) -> list[str]:
        marker = f"{prefix}_"
        catch_all = f"{prefix}_other"
        found = {
            name[len(marker):]
            for name in feature_names
            if name.startswith(marker) and name != catch_all
        }
        return sorted(found)

    return {prefix: _suffixes(prefix) for prefix in PREFIX_TO_FORM_KEY}
| def _to_float(value: Any, default: float = 0.0) -> float: | |
| try: | |
| if value is None: | |
| return default | |
| return float(value) | |
| except (TypeError, ValueError): | |
| return default | |
| def _to_int(value: Any, default: int = 0) -> int: | |
| try: | |
| if value is None: | |
| return default | |
| return int(value) | |
| except (TypeError, ValueError): | |
| return default | |
def _as_selection_list(raw: Any) -> list[Any]:
    """Normalise a multi-select form value to a list of selected items."""
    if not raw:
        return []
    if isinstance(raw, (list, tuple, set, frozenset)):
        return list(raw)
    # A single scalar selection (e.g. one string) counts as one item.
    return [raw]


def build_feature_row(form_data: dict[str, Any], feature_names: list[str]) -> pd.DataFrame:
    """Build the one-row feature frame for the model from raw form input.

    Every name in *feature_names* starts at 0.0. Scalar features are then
    filled from *form_data*, followed by the one-hot columns for each prefix
    in ``PREFIX_TO_FORM_KEY``. Computed features that the model does not list
    are ignored.

    Args:
        form_data: Raw form values keyed by field name. Missing or
            unparsable numeric fields fall back to their defaults.
        feature_names: Column names, in the order the model expects.

    Returns:
        A single-row DataFrame with columns exactly *feature_names*, with
        infinities and NaN replaced by 0.

    Raises:
        ValueError: If ``release_date`` is given but not ``YYYY-MM-DD``.
    """
    row = {name: 0.0 for name in feature_names}

    # Negative amounts are clamped to zero before use.
    budget = max(_to_float(form_data.get("budget"), 0.0), 0.0)
    popularity = max(_to_float(form_data.get("popularity"), 0.0), 0.0)
    runtime = max(_to_float(form_data.get("runtime"), 0.0), 0.0)

    release_date = parse_release_date(form_data.get("release_date"))
    # Dec-Feb -> 1, Mar-May -> 2, Jun-Aug -> 3, Sep-Nov -> 4.
    release_season = ((release_date.month % 12) + 3) // 3

    title_text = str(form_data.get("title") or "")
    tagline_text = str(form_data.get("tagline") or "")
    overview_text = str(form_data.get("overview") or "")

    values = {
        "belongs_to_collection": _to_int(form_data.get("belongs_to_collection"), 0),
        "homepage": _to_int(form_data.get("homepage"), 0),
        # An explicit has_tagline wins; otherwise derive it from the text.
        "has_tagline": _to_int(form_data.get("has_tagline"), 1 if tagline_text.strip() else 0),
        "original_language": LANGUAGE_MAPPING.get(str(form_data.get("original_language") or "").lower(), 0),
        "runtime": runtime,
        "num_of_cast": _to_float(form_data.get("num_of_cast"), 0.0),
        "num_of_crew": _to_float(form_data.get("num_of_crew"), 0.0),
        "gender_cast_1": _to_float(form_data.get("gender_cast_1"), 0.0),
        "gender_cast_2": _to_float(form_data.get("gender_cast_2"), 0.0),
        "count_cast_other": _to_float(form_data.get("count_cast_other"), 0.0),
        # Explicit word counts win; otherwise count from the supplied text.
        "title_word_count": _to_float(form_data.get("title_word_count"), count_words(title_text)),
        "tag_word_count": _to_float(form_data.get("tag_word_count"), count_words(tagline_text)),
        "overview_word_count": _to_float(form_data.get("overview_word_count"), count_words(overview_text)),
        "release_year": release_date.year,
        "release_month": release_date.month,
        "release_season": release_season,
        "runtime_category": runtime_category_code(runtime),
        "budget_log": math.log1p(budget),
        "popularity_log": math.log1p(popularity),
    }
    for key, value in values.items():
        if key in row:
            row[key] = value

    for prefix, form_key in PREFIX_TO_FORM_KEY.items():
        # Accept tuples/sets as well as lists, so that a multi-value
        # selection is never mistaken for one single item.
        selected = _as_selection_list(form_data.get(form_key))

        known = 0
        for item in selected:
            col = f"{prefix}_{item}"
            if col in row:
                row[col] = 1.0
                known += 1

        num_col = f"num_of_{prefix}"
        if num_col in row:
            # Keep an explicitly supplied count (e.g. num_of_cast) when it is
            # larger than the number of selections instead of overwriting it.
            # The selection count goes first so a NaN count cannot win max().
            row[num_col] = max(float(len(selected)), row[num_col])

        other_col = f"{prefix}_other"
        if other_col in row:
            # Flag selections that have no dedicated column in the model.
            row[other_col] = 1.0 if len(selected) > known else 0.0

    df = pd.DataFrame([[row[name] for name in feature_names]], columns=feature_names)
    return df.replace([np.inf, -np.inf], 0).fillna(0)
def predict_revenue(model: Any, form_data: dict[str, Any]) -> float:
    """Run the model on one form submission and return its prediction.

    Args:
        model: Fitted estimator exposing ``feature_names_in_`` and
            ``predict``.
        form_data: Raw form values, as accepted by ``build_feature_row``.

    Returns:
        The first element of ``model.predict`` for the built row, as a float.

    Raises:
        ValueError: If the model lacks ``feature_names_in_`` or the release
            date in *form_data* is malformed.
    """
    columns = get_model_feature_names(model)
    features = build_feature_row(form_data, columns)
    predictions = model.predict(features)
    return float(predictions[0])