"""Model load and single/batch classification (no Gradio)."""

from __future__ import annotations

import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from preprocess import preprocess_text

# Hub id of the PEFT adapter; overridable through the MODEL_ID environment variable.
MODEL_ID = os.environ.get(
    "MODEL_ID",
    "apps1/wallet_bert_13_gpu",
)
# Passed to the tokenizer as max_length (with truncation=True).
MAX_LENGTH = 512

# Process-wide cache filled by _load_model() on first use.
_tokenizer = None
_model = None
_device: Optional[torch.device] = None

# One classify_batch() output row: (results JSON, label, probability).
_BatchRow = Tuple[str, str, Optional[float]]


def _ensure_hf_login() -> None:
    """Log in to the Hugging Face Hub with the token from the environment.

    Reads HF_TOKEN first, then HUGGINGFACE_HUB_TOKEN.

    Raises:
        RuntimeError: if neither variable is set (or both are empty).
    """
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        raise RuntimeError(
            "HF_TOKEN or HUGGINGFACE_HUB_TOKEN is not set. Add HF_TOKEN under Space repository secrets."
        )
    login(token=token, add_to_git_credential=False)


def _load_model() -> Tuple[Any, Any, torch.device]:
    """Return ``(tokenizer, model, device)``, loading them on the first call.

    The first call logs in to the Hub, reads the adapter config for MODEL_ID,
    loads the tokenizer and base sequence-classification model named by that
    config, applies the adapter, moves the model to CUDA when available
    (otherwise CPU) and switches it to eval mode. The result is stored in
    module globals and returned as-is on later calls.
    """
    global _tokenizer, _model, _device
    if _model is not None and _tokenizer is not None and _device is not None:
        return _tokenizer, _model, _device

    _ensure_hf_login()
    print("Loading PEFT adapter from Hub:", MODEL_ID)
    peft_config = PeftConfig.from_pretrained(MODEL_ID)
    base_id = peft_config.base_model_name_or_path
    print("Base model:", base_id)

    tokenizer = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)
    base = AutoModelForSequenceClassification.from_pretrained(
        base_id,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, MODEL_ID)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    _tokenizer = tokenizer
    _model = model
    _device = device
    print("Model ready on device:", device)
    return tokenizer, model, device


def _classification_config(model: Any) -> Any:
    """Return the config used for label lookups, or ``None`` if there is none.

    Prefers ``model.base_model.model.config`` when that chain exists and is
    not ``None``; otherwise falls back to ``model.config``.
    """
    inner = getattr(getattr(model, "base_model", None), "model", None)
    if inner is not None and getattr(inner, "config", None) is not None:
        return inner.config
    return getattr(model, "config", None)


def _num_labels(cfg: Any) -> int:
    """Return ``cfg.num_labels`` as an int, or 0 when cfg is falsy or lacks it."""
    return int(getattr(cfg, "num_labels", 0) or 0) if cfg else 0


def _id2label_map(model: Any) -> Dict[int, str]:
    """Build an ``{index: label}`` mapping from the model's config.

    Keys of ``config.id2label`` are coerced to ``int`` (entries whose key
    cannot be converted are dropped) and values to ``str``. When that leaves
    nothing, placeholder names ``class_0 .. class_{num_labels-1}`` are used.
    """
    cfg = _classification_config(model)
    raw = (getattr(cfg, "id2label", None) or {}) if cfg is not None else {}

    out: Dict[int, str] = {}
    for k, v in raw.items():
        try:
            ik = int(k)
        except (TypeError, ValueError):
            continue
        out[ik] = str(v)

    if not out:
        n = int(getattr(cfg, "num_labels", 0) or 0) if cfg is not None else 0
        for i in range(n):
            out[i] = f"class_{i}"
    return out


def _prob_key(label: str) -> str:
    """Return the payload key for a label: ``P_`` + label with every
    character that is neither alphanumeric nor ``_`` replaced by ``_``."""
    safe = "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in label)
    return f"P_{safe}"


def _canonical_v3_teacher_label(pred_idx: int, raw_label: str, num_labels: int) -> str:
    """Map a predicted class to the canonical ``"bill"`` / ``"non_bill"`` label.

    The normalised label name decides first; when it is not recognisable the
    predicted index decides (index 0 -> ``"non_bill"``, anything else ->
    ``"bill"``).

    ``num_labels`` is accepted for interface compatibility only: the previous
    implementation special-cased ``num_labels == 2`` but returned exactly the
    same value as the general fallback, so the branch was redundant.
    """
    s = raw_label.lower().strip().replace(" ", "_").replace("-", "_")
    if "non_bill" in s or s == "nonbill":
        return "non_bill"
    if s == "bill" or (s.endswith("bill") and "non" not in s):
        return "bill"
    return "non_bill" if pred_idx == 0 else "bill"


def _input_preview(text: str) -> str:
    """Return the first 500 characters of *text*, plus ``...`` if it was cut."""
    return text[:500] + ("..." if len(text) > 500 else "")


def _build_payload(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
) -> Dict[str, Any]:
    """Build the detailed result dict for one successfully classified text.

    Contains an input preview, the preprocessed text, the argmax index and
    its label, and one ``P_<label>`` entry per element of ``probs_1d``
    rounded to 6 decimals. Indices missing from ``id2label`` are labelled
    with their string form.
    """
    probs = probs_1d.detach().float().cpu()
    pred_idx = int(torch.argmax(probs).item())
    payload: Dict[str, Any] = {
        "input_preview": _input_preview(raw_text),
        "parameterized_text": parameterized_text,
        "preprocessing_status": "ok",
        "predicted_index": pred_idx,
        "predicted_class": id2label.get(pred_idx, str(pred_idx)),
    }
    for i in range(probs.numel()):
        label = id2label.get(i, str(i))
        payload[_prob_key(label)] = round(float(probs[i].item()), 6)
    return payload


def _classify_success_parts(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
    num_labels: int,
) -> Tuple[Dict[str, Any], str, float]:
    """Return ``(payload, canonical_label, probability)`` for one text.

    ``probability`` is the predicted class's probability rounded to 6
    decimals; ``payload`` comes from :func:`_build_payload`.
    """
    payload = _build_payload(raw_text, parameterized_text, probs_1d, id2label)
    pred_idx = int(payload["predicted_index"])
    v3_teacher_label = _canonical_v3_teacher_label(
        pred_idx, str(id2label.get(pred_idx, "")), num_labels
    )
    probs = probs_1d.detach().float().cpu()
    v3_teacher_prob = round(float(probs[pred_idx].item()), 6)
    return payload, v3_teacher_label, v3_teacher_prob


def _predict_probs(tokenizer: Any, model: Any, device: torch.device, texts: Any) -> torch.Tensor:
    """Tokenize *texts* (one string or a list of strings), run the model and
    return the softmax of its logits over the last dimension."""
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
    ).to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits
        return torch.softmax(logits, dim=-1)


def _success_row(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
    num_labels: int,
) -> _BatchRow:
    """Format one successful classification as a classify_batch() row."""
    payload, label, prob = _classify_success_parts(
        raw_text, parameterized_text, probs_1d, id2label, num_labels
    )
    return json.dumps(payload, indent=2), label, prob


def _error_row(exc: BaseException) -> _BatchRow:
    """Format a per-row failure as a classify_batch() row."""
    return (
        json.dumps(
            {"error": str(exc), "preprocessing_status": "error"},
            indent=2,
        ),
        "",
        None,
    )


def classify_one(text: Optional[str]) -> str:
    """Classify a single text and return the result as a JSON string.

    The returned JSON object has the keys ``v3_teacher_label``,
    ``v3_teacher_prob`` and ``v3_teacher_results``. Any exception raised
    while preprocessing, loading the model or running it is printed and
    reported inside the JSON (label and probability ``null``, the message
    under ``v3_teacher_results.message``) instead of propagating.

    Args:
        text: Input text; ``None`` is treated as the empty string. The text
            is stripped of surrounding whitespace before preprocessing.
    """
    raw = "" if text is None else str(text)
    stripped = raw.strip()
    # Defined before the try so the error payload can always reference it,
    # even when preprocessing itself is what failed.
    parameterized_text = ""
    try:
        # Inside the try so a preprocessing failure yields the error JSON
        # (as classify_batch does per row) rather than an uncaught exception.
        parameterized_text = preprocess_text(stripped)
        tokenizer, model, device = _load_model()
        id2label = _id2label_map(model)
        num_labels = _num_labels(_classification_config(model))

        probs = _predict_probs(tokenizer, model, device, parameterized_text).squeeze(0)
        payload, v3_teacher_label, v3_teacher_prob = _classify_success_parts(
            stripped, parameterized_text, probs, id2label, num_labels
        )
        body = {
            "v3_teacher_label": v3_teacher_label,
            "v3_teacher_prob": v3_teacher_prob,
            "v3_teacher_results": payload,
        }
    except Exception as exc:
        print("classify_one error:", exc)
        payload = {
            "input_preview": _input_preview(stripped),
            "parameterized_text": parameterized_text,
            "preprocessing_status": "error",
            "message": str(exc),
            "predicted_index": None,
            "predicted_class": None,
        }
        body = {
            "v3_teacher_label": None,
            "v3_teacher_prob": None,
            "v3_teacher_results": payload,
        }
    return json.dumps(body, indent=2)


def classify_batch(
    texts: List[str],
    batch_size: int = 32,
    progress_cb: Optional[Callable[[int, int], None]] = None,
) -> List[Tuple[str, str, Optional[float]]]:
    """Return per row: (v3_teacher_results JSON string, v3_teacher_label, v3_teacher_prob) for CSV columns.

    Rows are first preprocessed one by one; a row whose preprocessing raises
    gets an error row and is not sent to the model. The remaining rows are
    run through the model in chunks of ``batch_size``. If a chunk fails as a
    whole, each of its rows is retried individually and only the rows that
    still fail get an error row.

    Args:
        texts: Input texts; ``None`` entries are treated as empty strings.
        batch_size: Number of rows per model call; must be >= 1.
        progress_cb: Optional callback invoked as ``progress_cb(done, total)``
            once for every finished row (successful or not).

    Returns:
        One tuple per input row, in input order. Error rows have an empty
        label and ``None`` probability.

    Raises:
        ValueError: if ``texts`` is non-empty and ``batch_size`` is < 1.
    """
    n = len(texts)
    if n == 0:
        return []
    # A negative step would make range() below empty and silently turn every
    # row into the "{}" placeholder; zero already raised ValueError in range().
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    tokenizer, model, device = _load_model()
    id2label = _id2label_map(model)
    num_labels = _num_labels(_classification_config(model))

    out: List[Optional[_BatchRow]] = [None] * n
    preprocessed: List[str] = [""] * n
    done = 0

    # Pass 1: preprocessing. Failures are final for that row.
    for i, t in enumerate(texts):
        try:
            preprocessed[i] = preprocess_text("" if t is None else str(t))
        except Exception as exc:
            print("preprocess error row", i, exc)
            out[i] = _error_row(exc)
            done += 1
            if progress_cb:
                progress_cb(done, n)

    # Pass 2: model inference for every row that survived preprocessing.
    model_indices = [i for i, row in enumerate(out) if row is None]
    for start in range(0, len(model_indices), batch_size):
        chunk = model_indices[start : start + batch_size]

        # Compute the whole chunk before committing anything, so that a
        # failure part-way through cannot leave rows that are already
        # counted in `done` and then counted again by the fallback.
        chunk_rows: Optional[List[_BatchRow]]
        try:
            probs_b = _predict_probs(
                tokenizer, model, device, [preprocessed[i] for i in chunk]
            )
            chunk_rows = [
                _success_row(
                    "" if texts[row_i] is None else str(texts[row_i]),
                    preprocessed[row_i],
                    probs_b[j],
                    id2label,
                    num_labels,
                )
                for j, row_i in enumerate(chunk)
            ]
        except Exception as exc:
            print("batch forward error, falling back row-wise:", exc)
            chunk_rows = None

        for j, row_i in enumerate(chunk):
            if chunk_rows is not None:
                row = chunk_rows[j]
            else:
                # Row-wise fallback: isolate the failure to individual rows.
                try:
                    probs = _predict_probs(
                        tokenizer, model, device, preprocessed[row_i]
                    ).squeeze(0)
                    row = _success_row(
                        "" if texts[row_i] is None else str(texts[row_i]),
                        preprocessed[row_i],
                        probs,
                        id2label,
                        num_labels,
                    )
                except Exception as exc2:
                    print("row error", row_i, exc2)
                    row = _error_row(exc2)

            # Commit and report exactly once per row, outside any try, so a
            # raising progress callback propagates instead of triggering a
            # re-run of the chunk.
            out[row_i] = row
            done += 1
            if progress_cb:
                progress_cb(done, n)

    return [o if o is not None else ("{}", "", None) for o in out]