| from __future__ import annotations
|
|
|
| import re
|
| from functools import lru_cache
|
| from pathlib import Path
|
| from types import SimpleNamespace
|
| from typing import Iterable, Optional
|
|
|
| import pandas as pd
|
| import torch
|
|
|
| from constants import (
|
| LEGACY_MAX_TOKENS_PER_SEQ_TEXT,
|
| LEGACY_MAX_TOKENS_PER_SEQ_TITLE,
|
| MAX_TOKENS_PER_SEQ_TEXT,
|
| MAX_TOKENS_PER_SEQ_TITLE,
|
| NUMERICAL_FEATURES,
|
| get_legacy_tokenizer,
|
| get_tokenizers,
|
| )
|
| from reddit_model import RedditModel
|
| from reddit_pipeline import pipeline_row
|
|
|
|
|
# Checkpoint used when callers do not pass ``model_path``: it lives in the
# ``artifacts`` directory that sits next to this module.
DEFAULT_MODEL_PATH = Path(__file__).resolve().parent.joinpath(
    "artifacts", "model_reddit_final.pth"
)

# Attention-head counts tried, in this priority order, by ``_infer_n_head``;
# the first entry that divides ``d_model`` evenly is chosen.
_N_HEAD_CANDIDATES = (12, 8, 16, 6, 4, 3, 2, 1)
|
|
|
|
|
def _safe_torch_load(model_path: Path, device: torch.device) -> dict:
    """Load a checkpoint file and return its state dict.

    Args:
        model_path: Location of the serialized checkpoint.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded dict, or the dict stored under its ``"state_dict"`` key
        when the checkpoint wraps the weights that way.

    Raises:
        ValueError: If the loaded object is not a dict.
    """
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    except TypeError:
        # NOTE(review): presumably this fallback targets torch builds whose
        # ``torch.load`` has no ``weights_only`` argument. Without it the file
        # is fully unpickled, so only trusted checkpoints should reach here.
        checkpoint = torch.load(model_path, map_location=device)

    if not isinstance(checkpoint, dict):
        raise ValueError(f"Unsupported checkpoint format in {model_path}")

    # Unwrap ``{"state_dict": {...}}`` style checkpoints.
    wrapped = checkpoint.get("state_dict")
    if isinstance(wrapped, dict):
        return wrapped
    return checkpoint
|
|
|
|
|
def _extract_layer_ids(state: dict, prefix: str) -> list[int]:
    """Return the distinct layer indices found under ``prefix``, ascending.

    A key contributes when it starts with ``<prefix>.layers.<index>.``; the
    prefix is matched literally (it is regex-escaped).
    """
    layer_key = re.compile(rf"^{re.escape(prefix)}\.layers\.(\d+)\.")
    found = {int(hit.group(1)) for hit in map(layer_key.match, state) if hit}
    return sorted(found)
|
|
|
|
|
def _infer_nb_hidden_layers(state: dict) -> int:
    """Infer the regression head's hidden-layer count from a state dict.

    Counts the entries named ``regression_head.<index>.weight`` whose value is
    a 2-D tensor and returns that count minus one.

    NOTE(review): presumably every 2-D weight is one linear layer and the last
    one is the output projection — confirm against ``RedditModel``.

    Raises:
        ValueError: If no such 2-D weight exists in ``state``.
    """
    weight_key = re.compile(r"^regression_head\.(\d+)\.weight$")
    nb_linear = sum(
        1
        for name, tensor in state.items()
        if weight_key.match(name)
        and isinstance(tensor, torch.Tensor)
        and tensor.ndim == 2
    )

    if nb_linear == 0:
        raise ValueError("Could not infer regression head layers from checkpoint.")

    return nb_linear - 1
|
|
|
|
|
def _infer_n_head(d_model: int) -> int:
    """Pick an attention-head count that divides ``d_model`` evenly.

    Returns the first entry of ``_N_HEAD_CANDIDATES`` (in its listed order)
    that divides ``d_model``, or 1 when none does.
    """
    return next(
        (heads for heads in _N_HEAD_CANDIDATES if d_model % heads == 0),
        1,
    )
|
|
|
|
|
def _read_encoder_dims(state: dict, embedding_key: str, encoder_prefix: str) -> tuple[int, int, int]:
    """Read ``(d_model, dim_feedforward, layer count)`` for one encoder stack."""
    first_ffn_key = f"{encoder_prefix}.layers.0.linear1.weight"
    if first_ffn_key not in state:
        raise ValueError(f"Checkpoint is missing {encoder_prefix} layer weights.")

    d_model = int(state[embedding_key].shape[1])
    dim_feedforward = int(state[first_ffn_key].shape[0])

    nb_layers = len(_extract_layer_ids(state, encoder_prefix))
    if nb_layers == 0:
        raise ValueError(f"Could not infer {encoder_prefix} layer count from checkpoint.")

    return d_model, dim_feedforward, nb_layers


def _infer_model_spec(state: dict, device: torch.device) -> tuple[SimpleNamespace, float | None, bool]:
    """Derive model hyper-parameters from the tensor shapes in a checkpoint.

    Two key layouts are recognised:

    * dual-encoder: both ``embedding_text.weight`` and
      ``embedding_title.weight`` are present;
    * legacy: ``embedding.weight`` is present.

    Args:
        state: Checkpoint state dict.
        device: Stored on the returned namespace as ``DEVICE``.

    Returns:
        ``(params, target_mean, is_dual)``. ``params`` is a namespace of
        inferred sizes with ``DROPOUT_RATE`` fixed at 0.0 and ``N_HEAD``
        chosen by ``_infer_n_head``. ``target_mean`` is the scalar read from
        a ``"target_mean"`` tensor in a dual-encoder checkpoint, otherwise
        ``None``. ``is_dual`` is ``True`` for the dual-encoder layout and
        ``False`` for the legacy one; the prediction functions in this module
        exponentiate model outputs when it is ``True``.

    Raises:
        ValueError: If neither layout matches or encoder weights are missing.
    """
    is_dual = "embedding_text.weight" in state and "embedding_title.weight" in state

    if is_dual:
        embedding_key, encoder_prefix = "embedding_text.weight", "encoder_text"
    elif "embedding.weight" in state:
        embedding_key, encoder_prefix = "embedding.weight", "encoder"
    else:
        raise ValueError(
            "Unsupported checkpoint architecture. Expected either legacy keys "
            "(embedding.* / encoder.*) or dual-encoder keys "
            "(embedding_text.* / encoder_text.* / encoder_title.*)."
        )

    d_model, dim_feedforward, nb_encoder_layers = _read_encoder_dims(
        state, embedding_key, encoder_prefix
    )

    # Vocabulary sizes are the row counts of the embedding matrices.
    if is_dual:
        vocab_fields = {
            "VOCAB_SIZE_TEXT": int(state["embedding_text.weight"].shape[0]),
            "VOCAB_SIZE_TITLE": int(state["embedding_title.weight"].shape[0]),
        }
    else:
        vocab_fields = {"VOCAB_SIZE": int(state["embedding.weight"].shape[0])}

    params = SimpleNamespace(
        DROPOUT_RATE=0.0,
        **vocab_fields,
        NB_HIDDEN_LAYERS=_infer_nb_hidden_layers(state),
        HIDDEN_SIZE=int(state["regression_head.0.weight"].shape[0]),
        D_MODEL=d_model,
        N_HEAD=_infer_n_head(d_model),
        DIM_FEEDFORWARD=dim_feedforward,
        NB_ENCODER_LAYERS=nb_encoder_layers,
        DEVICE=device,
    )

    if not is_dual:
        return params, None, False

    mean_tensor = state.get("target_mean")
    target_mean = float(mean_tensor.item()) if isinstance(mean_tensor, torch.Tensor) else None
    return params, target_mean, True
|
|
|
|
|
@lru_cache(maxsize=8)
def _load_model_cached(model_path_str: str, device_str: str) -> tuple[RedditModel, torch.device, bool]:
    """Build a ``RedditModel`` from a checkpoint, memoised per (path, device).

    Both arguments are plain strings so they can serve as ``lru_cache`` keys;
    up to 8 distinct combinations are kept.

    Returns:
        ``(model, device, outputs_log_space)`` with the model moved to
        ``device``, loaded with the checkpoint weights and switched to eval
        mode. The flag is the third value produced by ``_infer_model_spec``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    device = torch.device(device_str)
    model_path = Path(model_path_str)

    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    state = _safe_torch_load(model_path, device)
    spec = _infer_model_spec(state, device)
    params, target_mean, outputs_log_space = spec

    # The architecture is rebuilt from the inferred sizes before the weights
    # are copied in.
    model = RedditModel(params, target_mean=target_mean)
    model = model.to(device)
    model.load_state_dict(state)
    model.eval()

    return model, device, outputs_log_space
|
|
|
|
|
def _load_model(
    model_path: Path,
    device: Optional[str] = None,
) -> tuple[RedditModel, torch.device, bool]:
    """Resolve the target device and fetch the (cached) model for a checkpoint.

    Args:
        model_path: Checkpoint location; resolved to an absolute path so the
            cache key does not depend on how the path was spelled.
        device: Device string. When empty or ``None``, ``"cuda"`` is used if
            available and ``"cpu"`` otherwise.

    Returns:
        Whatever ``_load_model_cached`` returns for the resolved arguments.
    """
    if device:
        resolved_device = torch.device(device)
    elif torch.cuda.is_available():
        resolved_device = torch.device("cuda")
    else:
        resolved_device = torch.device("cpu")

    cache_key_path = str(model_path.resolve())
    return _load_model_cached(cache_key_path, str(resolved_device))
|
|
|
|
|
def _get_pipeline_tokenization(model: RedditModel):
    """Select tokenizers and token limits matching the model variant.

    Returns:
        ``(tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text)``.
        A model whose ``is_dual_encoder`` attribute is truthy gets the pair
        from ``get_tokenizers()`` with the current limits; any other model
        gets the single legacy tokenizer in both slots with the legacy limits.
    """
    if not getattr(model, "is_dual_encoder", False):
        # Legacy checkpoints share one tokenizer between title and text.
        shared_tokenizer = get_legacy_tokenizer()
        return (
            shared_tokenizer,
            shared_tokenizer,
            LEGACY_MAX_TOKENS_PER_SEQ_TITLE,
            LEGACY_MAX_TOKENS_PER_SEQ_TEXT,
        )

    tokenizer_title, tokenizer_text = get_tokenizers()
    return (
        tokenizer_title,
        tokenizer_text,
        MAX_TOKENS_PER_SEQ_TITLE,
        MAX_TOKENS_PER_SEQ_TEXT,
    )
|
|
|
|
|
def predict_single(
    *,
    title: str,
    text: str,
    subreddit: str,
    hours_ago: float,
    post_utc: float,
    model_path: Path | str = DEFAULT_MODEL_PATH,
    device: Optional[str] = None,
) -> float:
    """
    Predict upvotes for one post context.

    Both hours_ago and post_utc must be provided.

    Args:
        title: Post title.
        text: Post body.
        subreddit: Subreddit name.
        hours_ago: Hours between posting and observation; the observation
            time handed to the pipeline is ``post_utc + hours_ago * 3600``.
        post_utc: Posting time (NOTE(review): presumably epoch seconds, given
            the 3600 factor — confirm against ``pipeline_row``).
        model_path: Checkpoint to load; defaults to ``DEFAULT_MODEL_PATH``.
        device: Optional device string forwarded to ``_load_model``.

    Returns:
        The model output as a float, exponentiated first when the loaded
        checkpoint is flagged as producing log-space outputs.
    """
    fetched_utc = post_utc + (hours_ago * 3600.0)

    model, model_device, outputs_log_space = _load_model(Path(model_path), device=device)
    tok_title, tok_text, limit_title, limit_text = _get_pipeline_tokenization(model)

    features = pipeline_row(
        title=title,
        text=text,
        subreddit=subreddit,
        fetched_utc=fetched_utc,
        hours_ago=float(hours_ago),
        post_utc=float(post_utc),
        scale=True,
        tokenizer_title=tok_title,
        tokenizer_text=tok_text,
        max_tokens_title=limit_title,
        max_tokens_text=limit_text,
    )

    numerical = torch.tensor(
        features[NUMERICAL_FEATURES].to_numpy(),
        dtype=torch.float32,
        device=model_device,
    )

    with torch.no_grad():
        raw_output = model(features, numerical)
        pred = torch.exp(raw_output) if outputs_log_space else raw_output

    return float(pred.detach().cpu().item())
|
|
|
|
|
def predict_batch(
    *,
    title: str,
    text: str,
    subreddit: str,
    hours_ago_list: Iterable[float],
    post_utc_list: Iterable[float],
    model_path: Path | str = DEFAULT_MODEL_PATH,
    device: Optional[str] = None,
) -> list[dict]:
    """
    Predict upvotes for multiple candidate posting times.

    Both hours_ago_list and post_utc_list must be provided and have the same length.
    The same title, text and subreddit are used for every candidate.

    Args:
        title: Post title.
        text: Post body.
        subreddit: Subreddit name.
        hours_ago_list: Hours between posting and observation, one per candidate.
        post_utc_list: Posting times, one per candidate, paired by position
            with ``hours_ago_list``.
        model_path: Checkpoint to load; defaults to ``DEFAULT_MODEL_PATH``.
        device: Optional device string forwarded to ``_load_model``.

    Returns a list of dictionaries, one per candidate in input order, containing:
    - hours_ago
    - post_utc
    - prediction

    Raises:
        ValueError: If the two lists differ in length, are empty, or the model
            does not return exactly one value per candidate.
    """
    hours_values = [float(h) for h in hours_ago_list]
    post_values = [float(p) for p in post_utc_list]

    if len(hours_values) != len(post_values):
        raise ValueError("hours_ago_list and post_utc_list must have the same length.")
    if not hours_values:
        raise ValueError("time lists cannot be empty.")

    model, model_device, outputs_log_space = _load_model(Path(model_path), device=device)
    tokenizer_title, tokenizer_text, max_tokens_title, max_tokens_text = _get_pipeline_tokenization(model)

    rows = [
        pipeline_row(
            title=title,
            text=text,
            subreddit=subreddit,
            # Observation time = posting time + elapsed hours in seconds.
            fetched_utc=p + (h * 3600.0),
            hours_ago=h,
            post_utc=p,
            scale=True,
            tokenizer_title=tokenizer_title,
            tokenizer_text=tokenizer_text,
            max_tokens_title=max_tokens_title,
            max_tokens_text=max_tokens_text,
        )
        for h, p in zip(hours_values, post_values)
    ]

    batch_df = pd.concat(rows, ignore_index=True)

    numerical = torch.tensor(
        batch_df[NUMERICAL_FEATURES].to_numpy(),
        dtype=torch.float32,
        device=model_device,
    )

    with torch.no_grad():
        preds = model(batch_df, numerical)
        if outputs_log_space:
            preds = torch.exp(preds)

    # Flatten so that (N,), (N, 1) and 0-d outputs all yield plain floats:
    # iterating an (N, 1) array relies on NumPy's deprecated 1-element
    # array-to-scalar conversion, and a 0-d result cannot be iterated at all.
    flat_preds = preds.detach().cpu().reshape(-1).tolist()
    if len(flat_preds) != len(hours_values):
        raise ValueError(
            f"Model returned {len(flat_preds)} predictions for {len(hours_values)} rows."
        )

    return [
        {
            "hours_ago": h,
            "post_utc": p,
            "prediction": float(pred),
        }
        for h, p, pred in zip(hours_values, post_values, flat_preds)
    ]
|
|
|