Reddit / models /reddit_predict.py
cyrilfrl's picture
hope it works this time
44748ce verified
from __future__ import annotations
import re
from functools import lru_cache
from pathlib import Path
from types import SimpleNamespace
from typing import Iterable, Optional
import pandas as pd
import torch
from constants import (
LEGACY_MAX_TOKENS_PER_SEQ_TEXT,
LEGACY_MAX_TOKENS_PER_SEQ_TITLE,
MAX_TOKENS_PER_SEQ_TEXT,
MAX_TOKENS_PER_SEQ_TITLE,
NUMERICAL_FEATURES,
get_legacy_tokenizer,
get_tokenizers,
)
from reddit_model import RedditModel
from reddit_pipeline import pipeline_row
# Checkpoint used when callers do not pass ``model_path``: it lives in the
# ``artifacts`` directory next to this module.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.joinpath(
    "artifacts",
    "model_reddit_final.pth",
)

# Attention-head counts tried in this order by ``_infer_n_head``; the first
# one that evenly divides the embedding width is used.
_N_HEAD_CANDIDATES = (12, 8, 16, 6, 4, 3, 2, 1)
def _safe_torch_load(model_path: Path, device: torch.device) -> dict:
    """
    Load a checkpoint file and return its state dict.

    The file is first loaded with ``weights_only=True``. If that call raises
    ``TypeError`` the load is retried without the flag. When the loaded
    object is a dict holding a nested dict under ``"state_dict"``, the nested
    dict is returned; otherwise the loaded dict itself is returned.

    Raises:
        ValueError: if the loaded object is not a dict.
    """
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    except TypeError:
        # Presumably a torch build that does not accept ``weights_only``.
        # NOTE(review): this fallback performs unrestricted unpickling, so it
        # should only ever be pointed at trusted checkpoint files.
        checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict):
        # Unwrap containers of the form {"state_dict": {...}, ...}.
        nested = checkpoint.get("state_dict")
        return nested if isinstance(nested, dict) else checkpoint

    raise ValueError(f"Unsupported checkpoint format in {model_path}")
def _extract_layer_ids(state: dict, prefix: str) -> list[int]:
    """
    Return the sorted, de-duplicated layer indices found under ``prefix``.

    A key contributes index ``N`` when it starts with ``<prefix>.layers.N.``;
    ``prefix`` is matched literally (it is regex-escaped).
    """
    layer_key = re.compile(rf"^{re.escape(prefix)}\.layers\.(\d+)\.")
    hits = (layer_key.match(name) for name in state)
    return sorted({int(hit.group(1)) for hit in hits if hit})
def _infer_nb_hidden_layers(state: dict) -> int:
    """
    Infer the number of hidden layers in the regression head.

    Counts the ``regression_head.<i>.weight`` entries that are 2-D tensors
    and returns that count minus one.

    Raises:
        ValueError: if no such 2-D weight is present in ``state``.
    """
    weight_key = re.compile(r"^regression_head\.(\d+)\.weight$")
    linear_count = sum(
        1
        for name, tensor in state.items()
        if weight_key.match(name)
        and isinstance(tensor, torch.Tensor)
        and tensor.ndim == 2
    )
    if linear_count == 0:
        raise ValueError("Could not infer regression head layers from checkpoint.")
    # One of the counted matrices is not a hidden layer, hence the "- 1".
    return linear_count - 1
def _infer_n_head(d_model: int) -> int:
    """
    Pick an attention-head count for ``d_model``.

    Returns the first entry of ``_N_HEAD_CANDIDATES`` that divides
    ``d_model`` evenly, or 1 if none does.

    NOTE(review): this is a guess based on divisibility only; presumably the
    real head count is not recoverable from the checkpoint tensors. Confirm
    it matches the value used at training time.
    """
    return next(
        (heads for heads in _N_HEAD_CANDIDATES if d_model % heads == 0),
        1,
    )
def _infer_model_spec(state: dict, device: torch.device) -> tuple[SimpleNamespace, float | None, bool]:
    """
    Rebuild the model hyper-parameters from the tensors in a state dict.

    Two checkpoint layouts are recognised:

    - dual-encoder: has ``embedding_text.weight`` and ``embedding_title.weight``
      and uses the ``encoder_text`` prefix for layer discovery;
    - legacy: has ``embedding.weight`` and uses the ``encoder`` prefix.

    Returns:
        ``(params, target_mean, outputs_log_space)`` where ``params`` is a
        ``SimpleNamespace`` of hyper-parameters, ``target_mean`` is the float
        stored under ``"target_mean"`` (dual-encoder only, else ``None``) and
        ``outputs_log_space`` is ``True`` for dual-encoder checkpoints and
        ``False`` for legacy ones.

    Raises:
        ValueError: if neither layout is recognised, or the first encoder
            layer / the layer count cannot be found.
    """
    is_dual = "embedding_text.weight" in state and "embedding_title.weight" in state
    if is_dual:
        embedding_key, encoder_prefix = "embedding_text.weight", "encoder_text"
    elif "embedding.weight" in state:
        embedding_key, encoder_prefix = "embedding.weight", "encoder"
    else:
        raise ValueError(
            "Unsupported checkpoint architecture. Expected either legacy keys "
            "(embedding.* / encoder.*) or dual-encoder keys "
            "(embedding_text.* / encoder_text.* / encoder_title.*)."
        )

    first_ffn_key = f"{encoder_prefix}.layers.0.linear1.weight"
    if first_ffn_key not in state:
        raise ValueError(f"Checkpoint is missing {encoder_prefix} layer weights.")

    # Widths are read straight from tensor shapes: embedding is
    # (vocab, d_model) and linear1 is (dim_feedforward, ...).
    embedding = state[embedding_key]
    d_model = int(embedding.shape[1])
    dim_feedforward = int(state[first_ffn_key].shape[0])
    nb_encoder_layers = len(_extract_layer_ids(state, encoder_prefix))
    if not nb_encoder_layers:
        raise ValueError(f"Could not infer {encoder_prefix} layer count from checkpoint.")

    fields = {"DROPOUT_RATE": 0.0}
    if is_dual:
        fields["VOCAB_SIZE_TEXT"] = int(embedding.shape[0])
        fields["VOCAB_SIZE_TITLE"] = int(state["embedding_title.weight"].shape[0])
    else:
        fields["VOCAB_SIZE"] = int(embedding.shape[0])
    fields.update(
        NB_HIDDEN_LAYERS=_infer_nb_hidden_layers(state),
        HIDDEN_SIZE=int(state["regression_head.0.weight"].shape[0]),
        D_MODEL=d_model,
        N_HEAD=_infer_n_head(d_model),
        DIM_FEEDFORWARD=dim_feedforward,
        NB_ENCODER_LAYERS=nb_encoder_layers,
        DEVICE=device,
    )
    params = SimpleNamespace(**fields)

    if not is_dual:
        return params, None, False

    stored_mean = state.get("target_mean")
    target_mean = float(stored_mean.item()) if isinstance(stored_mean, torch.Tensor) else None
    # New checkpoints were trained in log-space (Tweedie), so inference must exponentiate outputs.
    return params, target_mean, True
@lru_cache(maxsize=8)
def _load_model_cached(model_path_str: str, device_str: str) -> tuple[RedditModel, torch.device, bool]:
    """
    Build, load and cache a model for a ``(path, device)`` string pair.

    Arguments are plain strings so they can serve as ``lru_cache`` keys; up
    to 8 distinct pairs are kept. The returned model is in eval mode.

    Returns:
        ``(model, device, outputs_log_space)``.

    Raises:
        FileNotFoundError: if ``model_path_str`` does not exist.
    """
    target_device = torch.device(device_str)
    checkpoint_path = Path(model_path_str)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Model file not found: {checkpoint_path}")

    weights = _safe_torch_load(checkpoint_path, target_device)
    spec, mean, log_space = _infer_model_spec(weights, target_device)

    network = RedditModel(spec, target_mean=mean).to(target_device)
    network.load_state_dict(weights)
    network.eval()
    return network, target_device, log_space
def _load_model(
    model_path: Path,
    device: Optional[str] = None,
) -> tuple[RedditModel, torch.device, bool]:
    """
    Resolve the device and path, then delegate to the cached loader.

    When ``device`` is falsy, CUDA is used if ``torch.cuda.is_available()``
    and CPU otherwise. The path is resolved before being used as cache key.
    """
    if device:
        chosen = torch.device(device)
    else:
        chosen = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cache_key_path = str(model_path.resolve())
    return _load_model_cached(cache_key_path, str(chosen))
def _get_pipeline_tokenization(model: RedditModel):
    """
    Select tokenizers and token limits matching the model variant.

    Returns a 4-tuple
    ``(tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text)``.
    A model whose ``is_dual_encoder`` attribute is truthy gets the pair from
    ``get_tokenizers()`` with the current limits; any other model gets the
    single legacy tokenizer for both fields with the legacy limits.
    """
    if not getattr(model, "is_dual_encoder", False):
        # Legacy models share one tokenizer between title and text.
        shared_tokenizer = get_legacy_tokenizer()
        return (
            shared_tokenizer,
            shared_tokenizer,
            LEGACY_MAX_TOKENS_PER_SEQ_TITLE,
            LEGACY_MAX_TOKENS_PER_SEQ_TEXT,
        )

    title_tokenizer, text_tokenizer = get_tokenizers()
    return (
        title_tokenizer,
        text_tokenizer,
        MAX_TOKENS_PER_SEQ_TITLE,
        MAX_TOKENS_PER_SEQ_TEXT,
    )
def predict_single(
    *,
    title: str,
    text: str,
    subreddit: str,
    hours_ago: float,
    post_utc: float,
    model_path: Path | str = DEFAULT_MODEL_PATH,
    device: Optional[str] = None,
) -> float:
    """
    Predict upvotes for one post context.

    Both hours_ago and post_utc must be provided; the fetch timestamp passed
    to the feature pipeline is ``post_utc + hours_ago * 3600``.
    The raw model output is exponentiated when the loaded checkpoint is
    flagged as producing log-space outputs.
    """
    fetched_utc = post_utc + (hours_ago * 3600.0)

    network, run_device, in_log_space = _load_model(Path(model_path), device=device)
    tok_title, tok_text, limit_title, limit_text = _get_pipeline_tokenization(network)

    features = pipeline_row(
        title=title,
        text=text,
        subreddit=subreddit,
        fetched_utc=fetched_utc,
        hours_ago=float(hours_ago),
        post_utc=float(post_utc),
        scale=True,
        tokenizer_title=tok_title,
        tokenizer_text=tok_text,
        max_tokens_title=limit_title,
        max_tokens_text=limit_text,
    )

    numeric_values = features[NUMERICAL_FEATURES].to_numpy()
    numeric_inputs = torch.tensor(numeric_values, dtype=torch.float32, device=run_device)

    with torch.no_grad():
        output = network(features, numeric_inputs)
        if in_log_space:
            output = torch.exp(output)

    return float(output.detach().cpu().item())
def predict_batch(
    *,
    title: str,
    text: str,
    subreddit: str,
    hours_ago_list: Iterable[float],
    post_utc_list: Iterable[float],
    model_path: Path | str = DEFAULT_MODEL_PATH,
    device: Optional[str] = None,
) -> list[dict]:
    """
    Predict upvotes for multiple candidate posting times.

    Both hours_ago_list and post_utc_list must be provided and have the same length.
    One feature row is built per ``(hours_ago, post_utc)`` pair and the whole
    batch is scored in a single forward pass.

    Returns a list of dictionaries, one per input pair and in input order, containing:
    - hours_ago
    - post_utc
    - prediction

    Raises:
        ValueError: if the two lists differ in length, are empty, or the
            model does not return exactly one value per input pair.
    """
    hours_values = [float(h) for h in hours_ago_list]
    post_values = [float(p) for p in post_utc_list]
    if len(hours_values) != len(post_values):
        raise ValueError("hours_ago_list and post_utc_list must have the same length.")
    if not hours_values:
        raise ValueError("time lists cannot be empty.")

    model, model_device, outputs_log_space = _load_model(Path(model_path), device=device)
    tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text = _get_pipeline_tokenization(model)

    rows = [
        pipeline_row(
            title=title,
            text=text,
            subreddit=subreddit,
            fetched_utc=p + (h * 3600.0),
            hours_ago=h,
            post_utc=p,
            scale=True,
            tokenizer_title=tokenizer_title,
            tokenizer_text=tokenizer_text,
            max_tokens_title=max_tokens_title,
            max_tokens_text=max_tokens_text,
        )
        for h, p in zip(hours_values, post_values)
    ]
    batch_df = pd.concat(rows, ignore_index=True)
    numerical = torch.tensor(
        batch_df[NUMERICAL_FEATURES].to_numpy(),
        dtype=torch.float32,
        device=model_device,
    )

    with torch.no_grad():
        preds = model(batch_df, numerical)
        if outputs_log_space:
            preds = torch.exp(preds)

    # Flatten so the result is always a 1-D sequence of scalars: a 0-d output
    # (single-row batch squeezed by the model) cannot be iterated, and an
    # (N, 1) output would otherwise rely on NumPy's deprecated conversion of
    # 1-element arrays to float.
    flat_preds = preds.detach().cpu().numpy().reshape(-1)
    if len(flat_preds) != len(hours_values):
        # Without this check ``zip`` would silently drop or misalign results.
        raise ValueError(
            f"Model returned {len(flat_preds)} predictions for {len(hours_values)} inputs."
        )

    return [
        {
            "hours_ago": h,
            "post_utc": p,
            "prediction": float(pred),
        }
        for h, p, pred in zip(hours_values, post_values, flat_preds)
    ]