# sinllama-mcq-4.0 / handler.py
# itsjorigo's picture
# Update handler.py
# 51d5f26 verified
"""
HuggingFace Inference Endpoint custom handler for SinLlama-MCQ.

Model stack: Llama base → SinLlama LoRA (merge) → MCQ LoRA → inference

Environment variables (set in the HF Endpoint dashboard):
    SINLLAMA_REPO   SinLlama HF repo ID (default: polyglots/SinLlama_v01)
    BASE_MODEL      Llama base HF repo ID (auto-read from SinLlama config if unset)
    HF_TOKEN        HF token for gated/private repos
    TEMPERATURE     Generation temperature (default: 0.7)
    MAX_NEW_TOKENS  Max tokens to generate (default: 300)

Request:  {"inputs": "<Sinhala passage>"}
Response: [{"generated_text": "...", "mcq": "...", "valid": true}]
"""
import json
import os
import re
import logging
import traceback
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the helpers and the handler class in this file.
logger = logging.getLogger(__name__)
def _log(msg: str) -> None:
    """Emit *msg* on stdout (flushed immediately) and via the module logger.

    Both channels are used deliberately: per the original author's note, the
    HF toolkit may suppress the logger before the handler module is imported,
    whereas print() output always reaches stdout.
    """
    print(msg, end="\n", flush=True)
    logger.log(logging.INFO, msg)
# Prompt given to the model; ``{passage}`` is the only placeholder and is
# filled with ``str.format``.  The Sinhala instruction reads roughly "Read the
# history passage below and write a multiple-choice question about it", and
# the second line labels the passage itself.
# The zero-width joiner inside the Sinhala word for "question" is written as an
# explicit \u200d escape so it cannot be silently dropped by an editor.
# NOTE(review): this text was restored from a mis-encoded (mojibake) literal -
# confirm it matches the prompt used when the MCQ adapter was fine-tuned.
_PROMPT_TEMPLATE = (
    "පහත ඉතිහාස ඡේදය කියවා, ඒ ගැන බහු-විකල්ප ප්\u200dරශ්නයක් සාදන්න.\n\n"
    "ඡේදය: {passage}\n\n"
    "MCQ:"
)
def _normalize_mcq(text: str) -> str | None:
text = re.sub(r"\b([ABCD])[.:]", r"\1)", text)
for tag in ["A)", "B)", "C)", "D)", "ࢱිවැࢻࢯි ΰΆ΄ΰ·’ΰ·…ΰ·’ΰΆ­ΰ·”ΰΆ»:"]:
text = re.sub(rf"(?<!\n)({re.escape(tag)})", r"\n\1", text)
text = text.strip()
if "ΰΆ΄ΰ·Šβ€ΰΆ»ΰ·ΰ·ŠΰΆ±ΰΆΊ:" not in text:
lines = text.splitlines()
if lines:
text = "ΰΆ΄ΰ·Šβ€ΰΆ»ΰ·ΰ·ŠΰΆ±ΰΆΊ: " + lines[0].strip() + "\n" + "\n".join(lines[1:])
if not all(f"{l})" in text for l in "ABCD"):
return None
if not re.search(r"ࢱිවැࢻࢯි ΰΆ΄ΰ·’ΰ·…ΰ·’ΰΆ­ΰ·”ΰΆ»:\s*[ABCD]", text):
return None
return text.strip()
def _patch_tokenizer_config(path: str) -> None:
    """Strip an unrecognised ``tokenizer_class`` from tokenizer_config.json.

    Looks for ``tokenizer_config.json`` directly under *path*.  Nothing happens
    when the file is missing or when its ``tokenizer_class`` is absent or one
    of the recognised names; otherwise the key is removed, the removal is
    logged, and the file is rewritten in place as UTF-8 JSON.

    Args:
        path: Directory expected to contain ``tokenizer_config.json``.
    """
    config_file = Path(path) / "tokenizer_config.json"
    if not config_file.exists():
        return

    with open(config_file, encoding="utf-8") as reader:
        config = json.load(reader)

    recognised = {
        None,
        "LlamaTokenizer",
        "LlamaTokenizerFast",
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    }
    if config.get("tokenizer_class") in recognised:
        return

    dropped = config.pop("tokenizer_class")
    logger.info("Patching tokenizer_config.json: dropping tokenizer_class=%r", dropped)
    with open(config_file, "w", encoding="utf-8") as writer:
        json.dump(config, writer, ensure_ascii=False, indent=2)
def _read_adapter_config(path: str, hf_token: str | None) -> dict:
    """Load the PEFT adapter config at *path* and log its key fields.

    On success the adapter's ``base_model_name_or_path`` and ``peft_type`` are
    logged at INFO level.  On any failure the traceback is logged at ERROR
    level and the exception is re-raised unchanged.

    Args:
        path: Local directory or repo ID handed to ``PeftConfig.from_pretrained``.
        hf_token: Token forwarded as ``token=``; may be ``None``.

    Returns:
        The object returned by ``PeftConfig.from_pretrained``.
        NOTE(review): the signature is annotated ``dict``, but callers in this
        file read attributes (``base_model_name_or_path``) from the result.

    Raises:
        Exception: whatever loading or logging the config raised.
    """
    try:
        adapter_cfg = PeftConfig.from_pretrained(path, token=hf_token)
        base_name = adapter_cfg.base_model_name_or_path
        kind = adapter_cfg.peft_type
        logger.info(
            "adapter_config from %s: base_model=%r, peft_type=%r",
            path, base_name, kind,
        )
    except Exception:
        failure_trace = traceback.format_exc()
        logger.error("Failed to read adapter config from %s:\n%s", path, failure_trace)
        raise
    return adapter_cfg
class EndpointHandler:
    """Custom Inference Endpoint handler that generates a Sinhala MCQ.

    Construction loads the tokenizer from the deployed repo, the Llama base
    model, merges the SinLlama adapter into it and finally attaches the MCQ
    adapter.  Calling the instance with ``{"inputs": "<passage>"}`` returns a
    one-element list holding either the generation result or an ``error``.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, base model and both adapters.

        Args:
            path: Directory of the deployed repository (tokenizer files and
                the MCQ adapter are read from here).

        Raises:
            ValueError: if TEMPERATURE or MAX_NEW_TOKENS is set to a value
                that is not numeric.
            Exception: anything raised while loading the tokenizer, the base
                model or the adapters.
        """
        _log("=" * 60)
        _log(f"EndpointHandler.__init__ path={path!r}")
        _log("=" * 60)

        # Parse the generation defaults first so a malformed environment value
        # fails immediately instead of after the (slow) model load.
        self._default_temperature = float(os.environ.get("TEMPERATURE", 0.7))
        self._default_max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", 300))

        sinllama_repo = os.environ.get("SINLLAMA_REPO", "polyglots/SinLlama_v01")
        hf_token = os.environ.get("HF_TOKEN")
        _log(f"SINLLAMA_REPO={sinllama_repo!r} HF_TOKEN={'set' if hf_token else 'NOT SET'}")

        # List repo files so we can verify what was actually deployed
        repo_files = sorted(p.name for p in Path(path).iterdir()) if path else []
        _log(f"Repo files: {repo_files}")

        # ── 1. Tokenizer ──────────────────────────────────────────────────────
        _log("STEP 1: patch tokenizer_config.json + load tokenizer")
        _patch_tokenizer_config(path)
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, use_fast=False, token=hf_token
        )
        if self.tokenizer.pad_token is None:
            # No dedicated padding token: reuse EOS so generate() has a pad id.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        sinllama_vocab = len(self.tokenizer)
        _log(f"STEP 1 OK — vocab size {sinllama_vocab}")

        # ── 2. Resolve Llama base model ID ────────────────────────────────────
        _log("STEP 2: resolve base model ID")
        base_model_id = os.environ.get("BASE_MODEL")
        if base_model_id:
            _log(f"BASE_MODEL from env: {base_model_id!r}")
        else:
            _log("BASE_MODEL not set — reading from SinLlama adapter config")
            sinllama_cfg = _read_adapter_config(sinllama_repo, hf_token)
            base_model_id = sinllama_cfg.base_model_name_or_path
            _log(f"Base model from SinLlama config: {base_model_id!r}")
        _log(f"STEP 2 OK — base_model_id={base_model_id!r}")

        # ── 3. Load Llama base ────────────────────────────────────────────────
        _log(f"STEP 3: load base model {base_model_id!r}")
        base = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="sdpa",
            token=hf_token,
        )
        # Match the embedding matrix to the tokenizer loaded in step 1 before
        # any adapter weights are applied.
        base.resize_token_embeddings(sinllama_vocab)
        _log(f"STEP 3 OK — embeddings resized to {sinllama_vocab}")

        # ── 4. Load + merge SinLlama adapter ─────────────────────────────────
        _log(f"STEP 4: load SinLlama adapter from {sinllama_repo!r}")
        sinllama = PeftModel.from_pretrained(
            base, sinllama_repo, is_trainable=False, token=hf_token
        )
        sinllama = sinllama.merge_and_unload()
        _log("STEP 4 OK — SinLlama merged")

        # ── 5. Load MCQ adapter ───────────────────────────────────────────────
        _log(f"STEP 5: load MCQ adapter from {path!r}")
        _read_adapter_config(path, hf_token)  # logs base_model_name_or_path
        self.model = PeftModel.from_pretrained(
            sinllama, path, is_trainable=False, token=hf_token
        )
        self.model.eval()
        _log("STEP 5 OK — EndpointHandler ready")

    def __call__(self, data: dict) -> list[dict]:
        """Generate an MCQ for the passage in *data*.

        Args:
            data: Request payload.  The passage is read from ``"inputs"``
                (or ``"passage"``); optional ``"parameters"`` may carry
                ``temperature`` and ``max_new_tokens`` overrides.

        Returns:
            A one-element list.  On success the dict has ``generated_text``
            (normalised MCQ, or the raw text when normalisation failed),
            ``mcq`` (normalised MCQ or ``None``) and ``valid``.  On bad input
            or a generation failure the dict has a single ``error`` key.
        """
        raw_passage = data.get("inputs") or data.get("passage") or ""
        if not isinstance(raw_passage, str):
            # Previously this crashed on .strip(); answer with an error instead.
            return [{"error": "\"inputs\" must be a string containing the passage."}]
        passage = raw_passage.strip()
        if not passage:
            return [{"error": "No passage provided. Send {\"inputs\": \"<passage>\"}"}]

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return [{"error": "\"parameters\" must be a JSON object."}]
        try:
            temperature = float(params.get("temperature", self._default_temperature))
            max_new_tokens = int(params.get("max_new_tokens", self._default_max_new_tokens))
        except (TypeError, ValueError, OverflowError) as exc:
            return [{"error": f"Invalid generation parameters: {exc}"}]

        prompt = _PROMPT_TEMPLATE.format(passage=passage)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        do_sample = temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "repetition_penalty": 1.1,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if do_sample:
            # Temperature only applies to sampling; a non-positive value is
            # not forwarded when decoding greedily.
            generate_kwargs["temperature"] = temperature

        try:
            with torch.no_grad():
                output = self.model.generate(**inputs, **generate_kwargs)
        except Exception as e:
            logger.exception("Generation failed")
            return [{"error": f"Generation error: {e}"}]

        # Keep only the tokens produced after the prompt.
        new_ids = output[0][inputs.input_ids.shape[1]:]
        raw = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        mcq = _normalize_mcq(raw)
        return [{
            "generated_text": mcq if mcq is not None else raw,
            "mcq": mcq,
            "valid": mcq is not None,
        }]