# finbertteacher_v1 / inference.py
# Author: aimlresearch2023
# Last commit message: modified token reading
# Commit: e8ced65
"""Model load and single/batch classification (no Gradio)."""
from __future__ import annotations
import json
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from preprocess import preprocess_text
# Hub repo id of the PEFT adapter to serve; override with the MODEL_ID env var.
MODEL_ID = os.environ.get("MODEL_ID", "apps1/wallet_bert_13_gpu")
# Token limit passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512

# Lazily-initialised singletons, filled in by _load_model() on first use.
_tokenizer = None
_model = None
_device: Optional[torch.device] = None
def _ensure_hf_login() -> None:
    """Log in to the Hugging Face Hub with a token taken from the environment.

    ``HF_TOKEN`` is consulted first, then ``HUGGINGFACE_HUB_TOKEN``; the first
    non-empty value is used. The token is not added to git credentials.

    Raises:
        RuntimeError: if neither variable holds a non-empty token.
    """
    for env_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        candidate = os.environ.get(env_name)
        if candidate:
            login(token=candidate, add_to_git_credential=False)
            return
    raise RuntimeError(
        "HF_TOKEN or HUGGINGFACE_HUB_TOKEN is not set. Add HF_TOKEN under Space repository secrets."
    )
def _load_model() -> Tuple[Any, Any, torch.device]:
    """Return the cached ``(tokenizer, model, device)`` triple, loading it on first call.

    On a cache miss this:

    1. logs in to the Hub via ``_ensure_hf_login``;
    2. reads the PEFT adapter config at ``MODEL_ID`` to find the base checkpoint;
    3. loads that base checkpoint's tokenizer and sequence-classification model;
    4. wraps the model with the adapter;
    5. moves it to CUDA when available (CPU otherwise) and puts it in eval mode.

    The triple is then stored in module globals, so later calls return immediately.

    Raises:
        RuntimeError: from ``_ensure_hf_login`` when no Hub token is set.
        Exceptions from the peft / transformers loaders propagate unchanged.
    """
    global _tokenizer, _model, _device
    cached = (_tokenizer, _model, _device)
    if all(part is not None for part in cached):
        return cached

    # NOTE(review): no lock guards this first load; concurrent first calls
    # could each load the model before the globals are assigned.
    _ensure_hf_login()
    print("Loading PEFT adapter from Hub:", MODEL_ID)
    adapter_cfg = PeftConfig.from_pretrained(MODEL_ID)
    base_name = adapter_cfg.base_model_name_or_path
    print("Base model:", base_name)

    # NOTE(review): trust_remote_code=True lets the base repo run its own
    # modeling/tokenizer code; keep the adapter pointed at a trusted base.
    loaded_tokenizer = AutoTokenizer.from_pretrained(base_name, trust_remote_code=True)
    backbone = AutoModelForSequenceClassification.from_pretrained(
        base_name,
        trust_remote_code=True,
    )
    adapted = PeftModel.from_pretrained(backbone, MODEL_ID)

    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    adapted.to(target)
    adapted.eval()

    _tokenizer, _model, _device = loaded_tokenizer, adapted, target
    print("Model ready on device:", target)
    return loaded_tokenizer, adapted, target
def _classification_config(model: Any) -> Any:
    """Return the config that carries the classification head's label metadata.

    The nested ``model.base_model.model.config`` is preferred; this is the
    layout a PeftModel wrapper exposes for the underlying transformers model.
    Otherwise ``model.config`` is returned, or ``None`` if neither exists.
    """
    wrapped = getattr(getattr(model, "base_model", None), "model", None)
    wrapped_cfg = None if wrapped is None else getattr(wrapped, "config", None)
    if wrapped_cfg is None:
        return getattr(model, "config", None)
    return wrapped_cfg
def _id2label_map(model: Any) -> Dict[int, str]:
    """Return the model's class-index -> label-name map with integer keys.

    ``id2label`` is read from the classification config (as found by
    ``_classification_config``). Keys are coerced to ``int`` and values to
    ``str``; keys that are not integer-like are skipped.

    If that yields nothing, the map falls back to ``class_0 .. class_{k-1}``,
    where ``k`` is the config's ``num_labels``. If the label count is also
    unknown, the result is an empty dict.
    """
    cfg = _classification_config(model)
    # Explicit form of: (getattr(cfg, "id2label", None) or {}) if cfg is not None else {}
    if cfg is None:
        source = {}
    else:
        source = getattr(cfg, "id2label", None) or {}

    mapping: Dict[int, str] = {}
    for key, name in source.items():
        try:
            idx = int(key)
        except (TypeError, ValueError):
            continue
        mapping[idx] = str(name)
    if mapping:
        return mapping

    count = int(getattr(cfg, "num_labels", 0) or 0) if cfg is not None else 0
    return {i: f"class_{i}" for i in range(count)}
def _prob_key(label: str) -> str:
    """Build the payload key ``P_<label>`` for a class probability.

    Every character that is neither alphanumeric nor ``_`` is replaced by
    ``_``, so the key is safe to use as a column name.
    """
    def _sanitize(ch: str) -> str:
        return ch if (ch.isalnum() or ch == "_") else "_"

    return "P_" + "".join(map(_sanitize, label))
def _canonical_v3_teacher_label(pred_idx: int, raw_label: str, num_labels: int) -> str:
    """Map a model label to the canonical teacher label ``"bill"`` or ``"non_bill"``.

    The label text is normalised first: lower-cased, stripped, with spaces and
    hyphens turned into underscores. It is then matched by name:

    * negated forms (``non_bill``, ``nonbill``, ``not_bill``, ``no_bill``, ...)
      map to ``"non_bill"``;
    * ``bill``, or any other label ending in ``bill`` that does not contain
      ``non``, maps to ``"bill"``.

    Labels matching neither (e.g. generic ``LABEL_0``) fall back to the
    predicted index: index 0 is ``"non_bill"`` and every other index is ``"bill"``.

    Args:
        pred_idx: Arg-max class index.
        raw_label: Label string for ``pred_idx`` ("" when unknown).
        num_labels: Number of classes. Kept for interface compatibility; the
            index fallback is the same for every class count.

    Returns:
        ``"bill"`` or ``"non_bill"``.
    """
    s = raw_label.lower().strip().replace(" ", "_").replace("-", "_")
    compact = s.replace("_", "")
    # Previously only "non" counted as a negation, so "not_bill"/"no_bill"
    # fell into the endswith("bill") rule and were reported as "bill".
    if any(marker in compact for marker in ("nonbill", "notbill", "nobill")):
        return "non_bill"
    if s == "bill" or (s.endswith("bill") and "non" not in s):
        return "bill"
    # The name gave no signal: treat index 0 as the negative class. (The old
    # code had an identical, redundant branch for num_labels == 2.)
    return "non_bill" if pred_idx == 0 else "bill"
def _build_payload(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
) -> Dict[str, Any]:
    """Assemble the JSON-ready result dict for one successfully scored text.

    Args:
        raw_text: Original input; only its first 500 characters are kept, with
            ``"..."`` appended when it was longer.
        parameterized_text: Preprocessed text that was fed to the model.
        probs_1d: Class probabilities for this text (callers in this module
            pass one softmax row).
        id2label: Index -> label name; missing indices fall back to ``str(idx)``.

    Returns:
        A dict with the preview, preprocessed text, ``preprocessing_status="ok"``,
        the arg-max index/class, and one ``P_<label>`` entry per class rounded
        to 6 decimals.
    """
    scores = probs_1d.detach().float().cpu()
    best = int(torch.argmax(scores).item())

    preview = raw_text[:500]
    if len(raw_text) > 500:
        preview += "..."

    result: Dict[str, Any] = {
        "input_preview": preview,
        "parameterized_text": parameterized_text,
        "preprocessing_status": "ok",
        "predicted_index": best,
        "predicted_class": id2label.get(best, str(best)),
    }
    result.update(
        {
            _prob_key(id2label.get(k, str(k))): round(float(scores[k].item()), 6)
            for k in range(scores.numel())
        }
    )
    return result
def _classify_success_parts(
    raw_text: str,
    parameterized_text: str,
    probs_1d: torch.Tensor,
    id2label: Dict[int, str],
    num_labels: int,
) -> Tuple[Dict[str, Any], str, float]:
    """Turn one row of class probabilities into the pieces callers emit.

    Returns:
        ``(payload, teacher_label, teacher_prob)``, where:

        * ``payload`` comes from ``_build_payload``;
        * ``teacher_label`` is the canonical ``"bill"`` / ``"non_bill"`` for the
          arg-max class;
        * ``teacher_prob`` is that class's probability rounded to 6 decimals.
    """
    payload = _build_payload(raw_text, parameterized_text, probs_1d, id2label)
    winner = int(payload["predicted_index"])
    winner_name = str(id2label.get(winner, ""))
    teacher_label = _canonical_v3_teacher_label(winner, winner_name, num_labels)
    cpu_probs = probs_1d.detach().float().cpu()
    teacher_prob = round(float(cpu_probs[winner].item()), 6)
    return payload, teacher_label, teacher_prob
def classify_one(text: Optional[str]) -> str:
    """Classify a single text and return the result as a JSON string (indent=2).

    The text is stripped, run through ``preprocess_text`` and scored by the
    cached model. The JSON object always has the keys ``v3_teacher_label``,
    ``v3_teacher_prob`` and ``v3_teacher_results``.

    On success, ``v3_teacher_results`` is the payload from ``_build_payload``.
    On any failure (preprocessing, model loading or inference), both top-level
    values are ``None`` and ``v3_teacher_results`` carries
    ``preprocessing_status: "error"`` plus the exception text under ``message``.
    Errors are printed and never raised.

    Args:
        text: Input text; ``None`` is treated as an empty string.
    """
    raw = "" if text is None else str(text)
    stripped = raw.strip()
    # Bound before the try so the error payload can reference it even when
    # preprocess_text itself is what failed.
    parameterized_text = ""
    try:
        # This used to run outside the try, so a preprocessing failure escaped
        # as an exception instead of producing the error JSON below.
        parameterized_text = preprocess_text(stripped)
        tokenizer, model, device = _load_model()
        id2label = _id2label_map(model)
        cfg = _classification_config(model)
        num_labels = int(getattr(cfg, "num_labels", 0) or 0) if cfg is not None else 0
        inputs = tokenizer(
            parameterized_text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding=True,
        ).to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
        # A single text is a batch of one; drop that axis so probs is per class.
        probs = torch.softmax(outputs.logits, dim=-1).squeeze(0)
        payload, v3_teacher_label, v3_teacher_prob = _classify_success_parts(
            stripped, parameterized_text, probs, id2label, num_labels
        )
        body = {
            "v3_teacher_label": v3_teacher_label,
            "v3_teacher_prob": v3_teacher_prob,
            "v3_teacher_results": payload,
        }
    except Exception as exc:
        print("classify_one error:", exc)
        payload = {
            "input_preview": stripped[:500] + ("..." if len(stripped) > 500 else ""),
            "parameterized_text": parameterized_text,
            "preprocessing_status": "error",
            "message": str(exc),
            "predicted_index": None,
            "predicted_class": None,
        }
        body = {
            "v3_teacher_label": None,
            "v3_teacher_prob": None,
            "v3_teacher_results": payload,
        }
    return json.dumps(body, indent=2)
def _batch_error_row(exc: BaseException) -> Tuple[str, str, Optional[float]]:
    """Build the ``(results JSON, label, prob)`` row recorded for a failed input."""
    return (
        json.dumps({"error": str(exc), "preprocessing_status": "error"}, indent=2),
        "",
        None,
    )


def _batch_forward_probs(
    tokenizer: Any, model: Any, device: Any, batch: List[str]
) -> torch.Tensor:
    """Tokenize ``batch``, run the model, and return softmax probabilities (one row per text)."""
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        padding=True,
    ).to(device)
    with torch.inference_mode():
        logits = model(**inputs).logits
    return torch.softmax(logits, dim=-1)


def classify_batch(
    texts: List[str],
    batch_size: int = 32,
    progress_cb: Optional[Callable[[int, int], None]] = None,
) -> List[Tuple[str, str, Optional[float]]]:
    """Return per row: (v3_teacher_results JSON string, v3_teacher_label, v3_teacher_prob) for CSV columns.

    Each text (``None`` is treated as ``""``) is preprocessed, then scored in
    chunks of ``batch_size``. If a chunk fails, its rows are retried one at a time.

    Successful rows carry the ``_build_payload`` JSON. Rows that fail either
    stage carry ``{"error": ..., "preprocessing_status": "error"}`` with an
    empty label and a ``None`` probability.

    ``progress_cb(done, n)``, if given, is called exactly once per row as it
    is finalised, so ``done`` never exceeds ``n``.

    Raises:
        ValueError: if ``batch_size`` < 1. (Previously a negative value silently
            produced blank rows for every input.)
        Exceptions from ``_load_model`` propagate; they are not turned into rows.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    n = len(texts)
    if n == 0:
        return []

    tokenizer, model, device = _load_model()
    id2label = _id2label_map(model)
    cfg = _classification_config(model)
    num_labels = int(getattr(cfg, "num_labels", 0) or 0) if cfg is not None else 0

    out: List[Optional[Tuple[str, str, Optional[float]]]] = [None] * n
    preprocessed: List[str] = [""] * n
    done = 0

    def _tick() -> None:
        # Count one finalised row and report progress.
        nonlocal done
        done += 1
        if progress_cb:
            progress_cb(done, n)

    def _success_row(row_i: int, probs_1d: torch.Tensor) -> Tuple[str, str, Optional[float]]:
        # Build the output tuple for a scored row.
        raw_s = "" if texts[row_i] is None else str(texts[row_i])
        payload, label, prob = _classify_success_parts(
            raw_s, preprocessed[row_i], probs_1d, id2label, num_labels
        )
        return json.dumps(payload, indent=2), label, prob

    # Stage 1: preprocess every row; failures are final and count as done.
    for i, t in enumerate(texts):
        try:
            preprocessed[i] = preprocess_text("" if t is None else str(t))
        except Exception as exc:
            print("preprocess error row", i, exc)
            out[i] = _batch_error_row(exc)
            _tick()

    # Stage 2: batched inference over the rows that preprocessed cleanly.
    model_indices = [i for i in range(n) if out[i] is None]
    for start in range(0, len(model_indices), batch_size):
        chunk = model_indices[start : start + batch_size]
        try:
            probs_b = _batch_forward_probs(
                tokenizer, model, device, [preprocessed[i] for i in chunk]
            )
            # Build every row before committing any. If something fails
            # part-way (e.g. a deferred CUDA error surfacing at .item()), the
            # row-wise fallback redoes the whole chunk, and rows must not
            # already have been counted towards progress.
            chunk_rows = [_success_row(row_i, probs_b[j]) for j, row_i in enumerate(chunk)]
        except Exception as exc:
            print("batch forward error, falling back row-wise:", exc)
            for row_i in chunk:
                try:
                    probs = _batch_forward_probs(
                        tokenizer, model, device, [preprocessed[row_i]]
                    )[0]
                    out[row_i] = _success_row(row_i, probs)
                except Exception as exc2:
                    print("row error", row_i, exc2)
                    out[row_i] = _batch_error_row(exc2)
                _tick()
        else:
            for row_i, row in zip(chunk, chunk_rows):
                out[row_i] = row
                _tick()

    # Defensive: every row is assigned above, but never return None entries.
    return [o if o is not None else ("{}", "", None) for o in out]