Spaces:
Sleeping
Sleeping
| """Model load and single/batch classification (no Gradio).""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Any, Callable, Dict, List, Optional, Tuple | |
| import torch | |
| from huggingface_hub import login | |
| from peft import PeftConfig, PeftModel | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from preprocess import preprocess_text | |
# Hub repo id of the PEFT adapter to serve; can be overridden via the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "apps1/wallet_bert_13_gpu")

# Maximum token count passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512

# Module-level cache, populated by _load_model() on first use.
_tokenizer = None
_model = None
_device: Optional[torch.device] = None
| def _ensure_hf_login() -> None: | |
| token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") | |
| if not token: | |
| raise RuntimeError( | |
| "HF_TOKEN or HUGGINGFACE_HUB_TOKEN is not set. Add HF_TOKEN under Space repository secrets." | |
| ) | |
| login(token=token, add_to_git_credential=False) | |
def _load_model() -> Tuple[Any, Any, torch.device]:
    """Return the cached ``(tokenizer, model, device)`` triple, loading it on first call.

    On a cache miss this logs in to the Hub, reads the PEFT adapter config at
    ``MODEL_ID`` to find its base model, loads the base tokenizer and
    sequence-classification model, wraps the base with the adapter, moves it
    to CUDA when available (CPU otherwise), switches it to eval mode and stores
    all three objects in the module-level cache.
    """
    global _tokenizer, _model, _device

    if all(cached is not None for cached in (_model, _tokenizer, _device)):
        return _tokenizer, _model, _device

    _ensure_hf_login()

    print("Loading PEFT adapter from Hub:", MODEL_ID)
    adapter_cfg = PeftConfig.from_pretrained(MODEL_ID)
    base_model_id = adapter_cfg.base_model_name_or_path
    print("Base model:", base_model_id)

    # NOTE: trust_remote_code=True allows code shipped in the base repo to run locally.
    tok = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
    backbone = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(backbone, MODEL_ID)

    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    peft_model.to(target)
    peft_model.eval()

    _tokenizer, _model, _device = tok, peft_model, target
    print("Model ready on device:", target)
    return tok, peft_model, target
| def _classification_config(model: Any) -> Any: | |
| inner = getattr(getattr(model, "base_model", None), "model", None) | |
| if inner is not None and getattr(inner, "config", None) is not None: | |
| return inner.config | |
| return getattr(model, "config", None) | |
def _id2label_map(model: Any) -> Dict[int, str]:
    """Build an ``{int index: label string}`` mapping from the classifier config.

    Entries of ``config.id2label`` whose keys cannot be converted with
    ``int()`` are skipped. If no usable entry remains, placeholder names
    ``class_0`` ... ``class_{num_labels - 1}`` are generated from
    ``config.num_labels``. The result is an empty dict when that is missing or
    zero, or when there is no config at all.
    """
    cfg = _classification_config(model)
    if cfg is None:
        source = {}
    else:
        source = getattr(cfg, "id2label", None) or {}

    mapping: Dict[int, str] = {}
    for key, name in source.items():
        try:
            index = int(key)
        except (TypeError, ValueError):
            continue
        mapping[index] = str(name)

    if mapping:
        return mapping

    count = 0 if cfg is None else int(getattr(cfg, "num_labels", 0) or 0)
    return {i: f"class_{i}" for i in range(count)}
| def _prob_key(label: str) -> str: | |
| safe = "".join(ch if ch.isalnum() or ch == "_" else "_" for ch in label) | |
| return f"P_{safe}" | |
| def _canonical_v3_teacher_label(pred_idx: int, raw_label: str, num_labels: int) -> str: | |
| s = raw_label.lower().strip().replace(" ", "_").replace("-", "_") | |
| if "non_bill" in s or s == "nonbill": | |
| return "non_bill" | |
| if s == "bill" or (s.endswith("bill") and "non" not in s): | |
| return "bill" | |
| if num_labels == 2: | |
| return "non_bill" if pred_idx == 0 else "bill" | |
| return "non_bill" if pred_idx == 0 else "bill" | |
def _build_payload(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
) -> Dict[str, Any]:
    """Assemble the success payload for one classified text.

    Args:
        raw_text: original input. Only a preview (the first 500 characters,
            plus ``"..."`` when longer) is stored.
        parameterized_text: preprocessed text that was fed to the tokenizer.
        probs_1d: class probabilities for this row. Assumed 1-D with one entry
            per class — TODO confirm at call sites.
        id2label: index -> label name. Missing indices fall back to ``str(index)``.

    Returns:
        Dict with the preview, the preprocessed text, status ``"ok"``, the
        argmax index and class name, plus one ``P_<label>`` entry per class
        holding its probability rounded to 6 decimals.
    """
    cpu_probs = probs_1d.detach().float().cpu()
    best = int(torch.argmax(cpu_probs).item())

    preview = raw_text[:500]
    if len(raw_text) > 500:
        preview += "..."

    result: Dict[str, Any] = {
        "input_preview": preview,
        "parameterized_text": parameterized_text,
        "preprocessing_status": "ok",
        "predicted_index": best,
        "predicted_class": id2label.get(best, str(best)),
    }
    for idx, value in enumerate(cpu_probs.tolist()):
        result[_prob_key(id2label.get(idx, str(idx)))] = round(float(value), 6)
    return result
def _classify_success_parts(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
    num_labels: int,
) -> Tuple[Dict[str, Any], str, float]:
    """Build the payload plus the canonical teacher label and probability for one row.

    Returns:
        ``(payload, v3_teacher_label, v3_teacher_prob)``. The label is the
        ``_canonical_v3_teacher_label`` result for the argmax class. The
        probability is that class's score rounded to 6 decimals.
    """
    payload = _build_payload(raw_text, parameterized_text, probs_1d, id2label)
    top = int(payload["predicted_index"])
    canonical = _canonical_v3_teacher_label(top, str(id2label.get(top, "")), num_labels)
    prob_cpu = probs_1d.detach().float().cpu()
    return payload, canonical, round(float(prob_cpu[top].item()), 6)
def classify_one(text: Optional[str]) -> str:
    """Classify a single text and return the result as an indented JSON string.

    The JSON object has the keys ``v3_teacher_label``, ``v3_teacher_prob`` and
    ``v3_teacher_results``. On success, ``v3_teacher_results`` holds the full
    payload: preview, preprocessed text, predicted class and per-class ``P_*``
    probabilities.

    Failures during preprocessing, model loading or inference are not raised.
    Instead, label and probability are ``null`` and ``v3_teacher_results``
    carries ``preprocessing_status: "error"`` plus the exception message.
    ``parameterized_text`` is ``null`` when preprocessing itself failed.

    Args:
        text: input text. ``None`` is treated as an empty string; surrounding
            whitespace is stripped before preprocessing.
    """
    raw = "" if text is None else str(text)
    stripped = raw.strip()
    parameterized_text: Optional[str] = None
    try:
        # Preprocessing lives inside the try so its failures produce the error
        # payload below (as classify_batch does per row) instead of escaping.
        parameterized_text = preprocess_text(stripped)
        tokenizer, model, device = _load_model()
        id2label = _id2label_map(model)
        cfg = _classification_config(model)
        num_labels = int(getattr(cfg, "num_labels", 0) or 0) if cfg is not None else 0
        inputs = tokenizer(
            parameterized_text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding=True,
        ).to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
            # One input -> drop the leading batch dimension of size 1.
            probs = torch.softmax(outputs.logits, dim=-1).squeeze(0)
        payload, v3_teacher_label, v3_teacher_prob = _classify_success_parts(
            stripped, parameterized_text, probs, id2label, num_labels
        )
        body = {
            "v3_teacher_label": v3_teacher_label,
            "v3_teacher_prob": v3_teacher_prob,
            "v3_teacher_results": payload,
        }
    except Exception as exc:
        print("classify_one error:", exc)
        payload = {
            "input_preview": stripped[:500] + ("..." if len(stripped) > 500 else ""),
            "parameterized_text": parameterized_text,
            "preprocessing_status": "error",
            "message": str(exc),
            "predicted_index": None,
            "predicted_class": None,
        }
        body = {
            "v3_teacher_label": None,
            "v3_teacher_prob": None,
            "v3_teacher_results": payload,
        }
    return json.dumps(body, indent=2)
| def classify_batch( | |
| texts: List[str], | |
| batch_size: int = 32, | |
| progress_cb: Optional[Callable[[int, int], None]] = None, | |
| ) -> List[Tuple[str, str, Optional[float]]]: | |
| """Return per row: (v3_teacher_results JSON string, v3_teacher_label, v3_teacher_prob) for CSV columns.""" | |
| n = len(texts) | |
| out: List[Optional[Tuple[str, str, Optional[float]]]] = [None] * n | |
| if n == 0: | |
| return [] | |
| tokenizer, model, device = _load_model() | |
| id2label = _id2label_map(model) | |
| cfg = _classification_config(model) | |
| num_labels = int(getattr(cfg, "num_labels", 0) or 0) if cfg else 0 | |
| preprocessed: List[str] = [""] * n | |
| done = 0 | |
| for i, t in enumerate(texts): | |
| try: | |
| preprocessed[i] = preprocess_text("" if t is None else str(t)) | |
| except Exception as exc: | |
| print("preprocess error row", i, exc) | |
| out[i] = ( | |
| json.dumps( | |
| {"error": str(exc), "preprocessing_status": "error"}, | |
| indent=2, | |
| ), | |
| "", | |
| None, | |
| ) | |
| done += 1 | |
| if progress_cb: | |
| progress_cb(done, n) | |
| model_indices = [i for i in range(n) if out[i] is None] | |
| for start in range(0, len(model_indices), batch_size): | |
| chunk = model_indices[start : start + batch_size] | |
| batch_param = [preprocessed[i] for i in chunk] | |
| batch_raw = ["" if texts[i] is None else str(texts[i]) for i in chunk] | |
| try: | |
| inputs = tokenizer( | |
| batch_param, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=MAX_LENGTH, | |
| padding=True, | |
| ).to(device) | |
| with torch.inference_mode(): | |
| logits = model(**inputs).logits | |
| probs_b = torch.softmax(logits, dim=-1) | |
| for j, row_i in enumerate(chunk): | |
| payload, v3_teacher_label, v3_teacher_prob = _classify_success_parts( | |
| batch_raw[j], | |
| batch_param[j], | |
| probs_b[j], | |
| id2label, | |
| num_labels, | |
| ) | |
| out[row_i] = (json.dumps(payload, indent=2), v3_teacher_label, v3_teacher_prob) | |
| done += 1 | |
| if progress_cb: | |
| progress_cb(done, n) | |
| except Exception as exc: | |
| print("batch forward error, falling back row-wise:", exc) | |
| for row_i in chunk: | |
| try: | |
| single = tokenizer( | |
| preprocessed[row_i], | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=MAX_LENGTH, | |
| padding=True, | |
| ).to(device) | |
| with torch.inference_mode(): | |
| logits = model(**single).logits | |
| probs = torch.softmax(logits, dim=-1).squeeze(0) | |
| raw_s = "" if texts[row_i] is None else str(texts[row_i]) | |
| payload, v3_teacher_label, v3_teacher_prob = _classify_success_parts( | |
| raw_s, | |
| preprocessed[row_i], | |
| probs, | |
| id2label, | |
| num_labels, | |
| ) | |
| out[row_i] = (json.dumps(payload, indent=2), v3_teacher_label, v3_teacher_prob) | |
| except Exception as exc2: | |
| print("row error", row_i, exc2) | |
| out[row_i] = ( | |
| json.dumps( | |
| {"error": str(exc2), "preprocessing_status": "error"}, | |
| indent=2, | |
| ), | |
| "", | |
| None, | |
| ) | |
| done += 1 | |
| if progress_cb: | |
| progress_cb(done, n) | |
| return [ | |
| o if o is not None else ("{}", "", None) for o in out | |
| ] | |