from __future__ import annotations import re from functools import lru_cache from pathlib import Path from types import SimpleNamespace from typing import Iterable, Optional import pandas as pd import torch from constants import ( LEGACY_MAX_TOKENS_PER_SEQ_TEXT, LEGACY_MAX_TOKENS_PER_SEQ_TITLE, MAX_TOKENS_PER_SEQ_TEXT, MAX_TOKENS_PER_SEQ_TITLE, NUMERICAL_FEATURES, get_legacy_tokenizer, get_tokenizers, ) from reddit_model import RedditModel from reddit_pipeline import pipeline_row DEFAULT_MODEL_PATH = Path(__file__).resolve().parent / "artifacts" / "model_reddit_final.pth" _N_HEAD_CANDIDATES = (12, 8, 16, 6, 4, 3, 2, 1) def _safe_torch_load(model_path: Path, device: torch.device) -> dict: try: state = torch.load(model_path, map_location=device, weights_only=True) except TypeError: state = torch.load(model_path, map_location=device) if isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict): state = state["state_dict"] if not isinstance(state, dict): raise ValueError(f"Unsupported checkpoint format in {model_path}") return state def _extract_layer_ids(state: dict, prefix: str) -> list[int]: pattern = re.compile(rf"^{re.escape(prefix)}\.layers\.(\d+)\.") layer_ids = [] for key in state.keys(): match = pattern.match(key) if match: layer_ids.append(int(match.group(1))) return sorted(set(layer_ids)) def _infer_nb_hidden_layers(state: dict) -> int: pattern = re.compile(r"^regression_head\.(\d+)\.weight$") linear_indices = [] for key, value in state.items(): match = pattern.match(key) if match and isinstance(value, torch.Tensor) and value.ndim == 2: linear_indices.append(int(match.group(1))) if not linear_indices: raise ValueError("Could not infer regression head layers from checkpoint.") linear_indices.sort() return max(len(linear_indices) - 1, 0) def _infer_n_head(d_model: int) -> int: for candidate in _N_HEAD_CANDIDATES: if d_model % candidate == 0: return candidate return 1 def _infer_model_spec(state: dict, device: torch.device) -> tuple[SimpleNamespace, float | None, bool]: if "embedding_text.weight" in state and "embedding_title.weight" in state: if "encoder_text.layers.0.linear1.weight" not in state: raise ValueError("Checkpoint is missing encoder_text layer weights.") d_model = int(state["embedding_text.weight"].shape[1]) dim_feedforward = int(state["encoder_text.layers.0.linear1.weight"].shape[0]) nb_encoder_layers = len(_extract_layer_ids(state, "encoder_text")) if nb_encoder_layers == 0: raise ValueError("Could not infer encoder_text layer count from checkpoint.") params = SimpleNamespace( DROPOUT_RATE=0.0, VOCAB_SIZE_TEXT=int(state["embedding_text.weight"].shape[0]), VOCAB_SIZE_TITLE=int(state["embedding_title.weight"].shape[0]), NB_HIDDEN_LAYERS=_infer_nb_hidden_layers(state), HIDDEN_SIZE=int(state["regression_head.0.weight"].shape[0]), D_MODEL=d_model, N_HEAD=_infer_n_head(d_model), DIM_FEEDFORWARD=dim_feedforward, NB_ENCODER_LAYERS=nb_encoder_layers, DEVICE=device, ) target_mean = None if "target_mean" in state and isinstance(state["target_mean"], torch.Tensor): target_mean = float(state["target_mean"].item()) # New checkpoints were trained in log-space (Tweedie), so inference must exponentiate outputs. return params, target_mean, True if "embedding.weight" in state: if "encoder.layers.0.linear1.weight" not in state: raise ValueError("Checkpoint is missing encoder layer weights.") d_model = int(state["embedding.weight"].shape[1]) dim_feedforward = int(state["encoder.layers.0.linear1.weight"].shape[0]) nb_encoder_layers = len(_extract_layer_ids(state, "encoder")) if nb_encoder_layers == 0: raise ValueError("Could not infer encoder layer count from checkpoint.") params = SimpleNamespace( DROPOUT_RATE=0.0, VOCAB_SIZE=int(state["embedding.weight"].shape[0]), NB_HIDDEN_LAYERS=_infer_nb_hidden_layers(state), HIDDEN_SIZE=int(state["regression_head.0.weight"].shape[0]), D_MODEL=d_model, N_HEAD=_infer_n_head(d_model), DIM_FEEDFORWARD=dim_feedforward, NB_ENCODER_LAYERS=nb_encoder_layers, DEVICE=device, ) return params, None, False raise ValueError( "Unsupported checkpoint architecture. Expected either legacy keys " "(embedding.* / encoder.*) or dual-encoder keys " "(embedding_text.* / encoder_text.* / encoder_title.*)." ) @lru_cache(maxsize=8) def _load_model_cached(model_path_str: str, device_str: str) -> tuple[RedditModel, torch.device, bool]: device = torch.device(device_str) model_path = Path(model_path_str) if not model_path.exists(): raise FileNotFoundError(f"Model file not found: {model_path}") state = _safe_torch_load(model_path, device) params, target_mean, outputs_log_space = _infer_model_spec(state, device) model = RedditModel(params, target_mean=target_mean).to(device) model.load_state_dict(state) model.eval() return model, device, outputs_log_space def _load_model( model_path: Path, device: Optional[str] = None, ) -> tuple[RedditModel, torch.device, bool]: resolved_device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu") return _load_model_cached(str(model_path.resolve()), str(resolved_device)) def _get_pipeline_tokenization(model: RedditModel): if getattr(model, "is_dual_encoder", False): tokenizer_title, tokenizer_text = get_tokenizers() return ( tokenizer_title, tokenizer_text, MAX_TOKENS_PER_SEQ_TITLE, MAX_TOKENS_PER_SEQ_TEXT, ) legacy_tokenizer = get_legacy_tokenizer() return ( legacy_tokenizer, legacy_tokenizer, LEGACY_MAX_TOKENS_PER_SEQ_TITLE, LEGACY_MAX_TOKENS_PER_SEQ_TEXT, ) def predict_single( *, title: str, text: str, subreddit: str, hours_ago: float, post_utc: float, model_path: Path | str = DEFAULT_MODEL_PATH, device: Optional[str] = None, ) -> float: """ Predict upvotes for one post context. Both hours_ago and post_utc must be provided. """ fetched_utc = post_utc + (hours_ago * 3600.0) model, model_device, outputs_log_space = _load_model(Path(model_path), device=device) tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text = _get_pipeline_tokenization(model) row = pipeline_row( title=title, text=text, subreddit=subreddit, fetched_utc=fetched_utc, hours_ago=float(hours_ago), post_utc=float(post_utc), scale=True, tokenizer_title=tokenizer_title, tokenizer_text=tokenizer_text, max_tokens_title=max_tokens_title, max_tokens_text=max_tokens_text, ) numerical = torch.tensor( row[NUMERICAL_FEATURES].to_numpy(), dtype=torch.float32, device=model_device, ) with torch.no_grad(): pred = model(row, numerical) if outputs_log_space: pred = torch.exp(pred) return float(pred.detach().cpu().item()) def predict_batch( *, title: str, text: str, subreddit: str, hours_ago_list: Iterable[float], post_utc_list: Iterable[float], model_path: Path | str = DEFAULT_MODEL_PATH, device: Optional[str] = None, ) -> list[dict]: """ Predict upvotes for multiple candidate posting times. Both hours_ago_list and post_utc_list must be provided and have the same length. Returns a list of dictionaries containing: - hours_ago - post_utc - predicted_upvotes """ hours_values = [float(h) for h in hours_ago_list] post_values = [float(p) for p in post_utc_list] if len(hours_values) != len(post_values): raise ValueError("hours_ago_list and post_utc_list must have the same length.") if not hours_values: raise ValueError("time lists cannot be empty.") model, model_device, outputs_log_space = _load_model(Path(model_path), device=device) tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text = _get_pipeline_tokenization(model) rows = [] for h, p in zip(hours_values, post_values): fetched_utc = p + (h * 3600.0) row = pipeline_row( title=title, text=text, subreddit=subreddit, fetched_utc=fetched_utc, hours_ago=h, post_utc=p, scale=True, tokenizer_title=tokenizer_title, tokenizer_text=tokenizer_text, max_tokens_title=max_tokens_title, max_tokens_text=max_tokens_text, ) rows.append(row) batch_df = pd.concat(rows, ignore_index=True) numerical = torch.tensor( batch_df[NUMERICAL_FEATURES].to_numpy(), dtype=torch.float32, device=model_device, ) with torch.no_grad(): preds = model(batch_df, numerical) if outputs_log_space: preds = torch.exp(preds) preds = preds.detach().cpu().numpy() return [ { "hours_ago": h, "post_utc": p, "prediction": float(pred), } for h, p, pred in zip(hours_values, post_values, preds) ]