# File-viewer residue kept as a comment — Spaces status: Sleeping; file size: 5,219 bytes; revision b490ee7.
from __future__ import annotations
import math
import pickle
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
# Integer codes written to the "original_language" feature. build_feature_row
# lower-cases the submitted language and falls back to 0 when it is not a key here.
LANGUAGE_MAPPING = {"en": 1, "zh": 2, "ja": 3}
# Maps a feature-column prefix (columns are named "<prefix>_<value>") to the
# form_data key that holds the values selected for that group.
# Note the capitalised "Keywords" prefix versus the lower-case "keywords" form key.
PREFIX_TO_FORM_KEY = {
    "genres": "genres",
    "production_companies": "production_companies",
    "Keywords": "keywords",
    "cast": "cast",
}
def load_model(model_path: str | Path) -> Any:
    """Read the pickle file at ``model_path`` and return the object stored in it.

    Args:
        model_path: Location of the pickled model, as a string or ``Path``.

    Returns:
        Whatever object the pickle file contains.

    NOTE(security): unpickling can execute arbitrary code, so this must only
    be pointed at model files from a trusted source.
    """
    source = Path(model_path)
    with source.open("rb") as handle:
        model = pickle.load(handle)
    return model
def get_model_feature_names(model: Any) -> list[str]:
    """Return the model's ``feature_names_in_`` attribute as a plain list.

    Args:
        model: Any object; it is expected to carry ``feature_names_in_``.

    Returns:
        The feature names, in the order the attribute yields them.

    Raises:
        ValueError: If ``model`` has no ``feature_names_in_`` attribute.
    """
    # A private sentinel distinguishes "attribute absent" from any real value.
    absent = object()
    names = getattr(model, "feature_names_in_", absent)
    if names is absent:
        raise ValueError("Model does not expose feature_names_in_.")
    return list(names)
def count_words(text: str | None) -> int:
    """Count the whitespace-separated words in ``text``.

    Args:
        text: The text to measure; non-string values are converted with ``str``.

    Returns:
        The number of words, or 0 when ``text`` is ``None`` or blank.
    """
    if text is None:
        return 0
    # split() without a separator ignores leading/trailing whitespace and
    # yields an empty list for blank input, so no separate blank check is needed.
    return len(str(text).split())
def runtime_category_code(runtime: float) -> int:
    """Bucket a runtime into a category code.

    Args:
        runtime: The runtime to classify.

    Returns:
        0 when ``runtime`` is below 90, 1 when it is below 120, otherwise 2.
    """
    # Each code is the index of the first upper bound the runtime falls under.
    for code, upper_bound in enumerate((90, 120)):
        if runtime < upper_bound:
            return code
    return 2
def parse_release_date(value: str | None) -> datetime:
    """Parse a ``YYYY-MM-DD`` release date, defaulting when none is given.

    Args:
        value: The submitted date. Surrounding whitespace is ignored, and an
            existing ``datetime`` is returned unchanged.

    Returns:
        The parsed date, or ``datetime(2010, 1, 1)`` when ``value`` is
        ``None``, empty, or otherwise falsy.

    Raises:
        ValueError: If ``value`` is present but is not a ``YYYY-MM-DD`` string.
    """
    if isinstance(value, datetime):
        return value
    if isinstance(value, str):
        # Form fields often arrive padded; a blank field counts as "not provided".
        value = value.strip()
    if not value:
        return datetime(2010, 1, 1)
    try:
        return datetime.strptime(value, "%Y-%m-%d")
    except (TypeError, ValueError) as exc:
        # strptime raises TypeError for non-string input; report it through the
        # same ValueError so callers have a single exception to handle.
        raise ValueError("release_date must be in YYYY-MM-DD format.") from exc
def parse_feature_options(feature_names: list[str]) -> dict[str, list[str]]:
    """Collect the selectable values encoded in the feature names.

    For every prefix in ``PREFIX_TO_FORM_KEY``, each feature named
    ``<prefix>_<value>`` contributes ``<value>``; the ``<prefix>_other``
    column is skipped.

    Args:
        feature_names: The feature names to scan.

    Returns:
        A dict keyed by prefix whose values are the sorted option names.
    """
    markers = {prefix: f"{prefix}_" for prefix in PREFIX_TO_FORM_KEY}
    found: dict[str, set[str]] = {prefix: set() for prefix in markers}
    for column in feature_names:
        for prefix, marker in markers.items():
            if not column.startswith(marker):
                continue
            if column == f"{prefix}_other":
                continue
            # Keep only the part after "<prefix>_".
            found[prefix].add(column[len(marker):])
    return {prefix: sorted(values) for prefix, values in found.items()}
def _to_float(value: Any, default: float = 0.0) -> float:
try:
if value is None:
return default
return float(value)
except (TypeError, ValueError):
return default
def _to_int(value: Any, default: int = 0) -> int:
try:
if value is None:
return default
return int(value)
except (TypeError, ValueError):
return default
def _finite_non_negative(value: Any) -> float:
    """Parse ``value`` as a float; missing, negative, NaN or infinite input becomes 0.0."""
    number = _to_float(value, 0.0)
    if not math.isfinite(number):
        return 0.0
    return max(number, 0.0)


def _as_selection(raw: Any) -> list[Any]:
    """Normalise a multi-select form value to a list (a lone value becomes a one-item list)."""
    if not raw:
        return []
    if isinstance(raw, (list, tuple, set, frozenset)):
        return list(raw)
    return [raw]


def build_feature_row(form_data: dict[str, Any], feature_names: list[str]) -> pd.DataFrame:
    """Build a one-row feature frame for the model from submitted form data.

    Every name in ``feature_names`` starts at 0.0. Scalar features are then
    filled from ``form_data``, and the groups in ``PREFIX_TO_FORM_KEY`` are
    encoded as ``<prefix>_<value>`` indicator columns. Values for names that
    are not in ``feature_names`` are ignored.

    Args:
        form_data: Raw form values keyed by field name.
        feature_names: Column names, in the order the frame should use.

    Returns:
        A single-row ``DataFrame`` with columns ``feature_names``; infinities
        and NaN are replaced by 0.

    Raises:
        ValueError: If ``release_date`` is present but not ``YYYY-MM-DD``.
    """
    row = {name: 0.0 for name in feature_names}

    # Sanitise before deriving features, so that e.g. a NaN runtime yields
    # runtime 0 *and* runtime_category 0 instead of a mismatched pair.
    budget = _finite_non_negative(form_data.get("budget"))
    popularity = _finite_non_negative(form_data.get("popularity"))
    runtime = _finite_non_negative(form_data.get("runtime"))

    release_date = parse_release_date(form_data.get("release_date"))
    # Maps Dec-Feb to 1, Mar-May to 2, Jun-Aug to 3 and Sep-Nov to 4.
    release_season = ((release_date.month % 12) + 3) // 3

    title_text = str(form_data.get("title") or "")
    tagline_text = str(form_data.get("tagline") or "")
    overview_text = str(form_data.get("overview") or "")

    values = {
        "belongs_to_collection": _to_int(form_data.get("belongs_to_collection"), 0),
        "homepage": _to_int(form_data.get("homepage"), 0),
        # An explicit has_tagline wins; otherwise infer it from the tagline text.
        "has_tagline": _to_int(form_data.get("has_tagline"), 1 if tagline_text.strip() else 0),
        "original_language": LANGUAGE_MAPPING.get(str(form_data.get("original_language") or "").lower(), 0),
        "runtime": runtime,
        "num_of_cast": _to_float(form_data.get("num_of_cast"), 0.0),
        "num_of_crew": _to_float(form_data.get("num_of_crew"), 0.0),
        "gender_cast_1": _to_float(form_data.get("gender_cast_1"), 0.0),
        "gender_cast_2": _to_float(form_data.get("gender_cast_2"), 0.0),
        "count_cast_other": _to_float(form_data.get("count_cast_other"), 0.0),
        # Explicit word counts win; otherwise count the words in the text.
        "title_word_count": _to_float(form_data.get("title_word_count"), count_words(title_text)),
        "tag_word_count": _to_float(form_data.get("tag_word_count"), count_words(tagline_text)),
        "overview_word_count": _to_float(form_data.get("overview_word_count"), count_words(overview_text)),
        "release_year": release_date.year,
        "release_month": release_date.month,
        "release_season": release_season,
        "runtime_category": runtime_category_code(runtime),
        "budget_log": math.log1p(budget),
        "popularity_log": math.log1p(popularity),
    }
    for key, value in values.items():
        if key in row:
            row[key] = value

    for prefix, form_key in PREFIX_TO_FORM_KEY.items():
        selected = _as_selection(form_data.get(form_key))
        known = 0
        for item in selected:
            col = f"{prefix}_{item}"
            if col in row:
                row[col] = 1.0
                known += 1

        num_col = f"num_of_{prefix}"
        if num_col in row:
            # "num_of_cast" is also a scalar form field. Treat the selection
            # size as a lower bound rather than overwriting a larger explicit
            # count (the selection can only name values that have columns).
            declared = row[num_col]
            chosen = float(len(selected))
            row[num_col] = declared if math.isfinite(declared) and declared > chosen else chosen

        other_col = f"{prefix}_other"
        if other_col in row:
            # Flag "other" when at least one selected value has no column.
            row[other_col] = 1.0 if len(selected) > known else 0.0

    df = pd.DataFrame([[row[name] for name in feature_names]], columns=feature_names)
    return df.replace([np.inf, -np.inf], 0).fillna(0)
def predict_revenue(model: Any, form_data: dict[str, Any]) -> float:
    """Predict from form data with ``model`` and return the result as a float.

    Args:
        model: Object exposing ``feature_names_in_`` and ``predict``.
        form_data: Raw form values used to build the feature row.

    Returns:
        The first element of ``model.predict`` for the single-row frame.
    """
    columns = get_model_feature_names(model)
    features = build_feature_row(form_data, columns)
    predictions = model.predict(features)
    return float(predictions[0])
|