""" HuggingFace Inference Endpoint custom handler for SinLlama-MCQ. Model stack: Llama base → SinLlama LoRA (merge) → MCQ LoRA → inference Environment variables (set in the HF Endpoint dashboard): SINLLAMA_REPO SinLlama HF repo ID (default: polyglots/SinLlama_v01) BASE_MODEL Llama base HF repo ID (auto-read from SinLlama config if unset) HF_TOKEN HF token for gated/private repos TEMPERATURE Generation temperature (default: 0.7) MAX_NEW_TOKENS Max tokens to generate (default: 300) Request: {"inputs": ""} Response: [{"generated_text": "...", "mcq": "...", "valid": true}] """ import json import os import re import logging import traceback import torch from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel, PeftConfig logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Use print() in addition to logger — HF toolkit may suppress the logger # before the handler module is imported, but print() always reaches stdout. def _log(msg: str) -> None: print(msg, flush=True) logger.info(msg) _PROMPT_TEMPLATE = ( "පහත ඉතිහාස ඡේදය කියවා, ඒ ගැන බහු-විකල්ප ප්‍රශ්නයක් සාදන්න.\n\n" "ඡේදය: {passage}\n\n" "MCQ:" ) def _normalize_mcq(text: str) -> str | None: text = re.sub(r"\b([ABCD])[.:]", r"\1)", text) for tag in ["A)", "B)", "C)", "D)", "නිවැරදි පිළිතුර:"]: text = re.sub(rf"(? None: """Remove unknown tokenizer_class entries that crash AutoTokenizer.""" tc_path = Path(path) / "tokenizer_config.json" if not tc_path.exists(): return with open(tc_path, encoding="utf-8") as f: tc = json.load(f) known = {None, "LlamaTokenizer", "LlamaTokenizerFast", "PreTrainedTokenizer", "PreTrainedTokenizerFast"} if tc.get("tokenizer_class") not in known: logger.info("Patching tokenizer_config.json: dropping tokenizer_class=%r", tc.pop("tokenizer_class")) with open(tc_path, "w", encoding="utf-8") as f: json.dump(tc, f, ensure_ascii=False, indent=2) def _read_adapter_config(path: str, hf_token: str | None) -> dict: """Read adapter_config.json, logging its full contents for debugging.""" try: cfg = PeftConfig.from_pretrained(path, token=hf_token) logger.info("adapter_config from %s: base_model=%r, peft_type=%r", path, cfg.base_model_name_or_path, cfg.peft_type) return cfg except Exception: logger.error("Failed to read adapter config from %s:\n%s", path, traceback.format_exc()) raise class EndpointHandler: def __init__(self, path: str = ""): _log("=" * 60) _log(f"EndpointHandler.__init__ path={path!r}") _log("=" * 60) sinllama_repo = os.environ.get("SINLLAMA_REPO", "polyglots/SinLlama_v01") hf_token = os.environ.get("HF_TOKEN") _log(f"SINLLAMA_REPO={sinllama_repo!r} HF_TOKEN={'set' if hf_token else 'NOT SET'}") # List repo files so we can verify what was actually deployed repo_files = sorted(str(p.name) for p in Path(path).iterdir()) if path else [] _log(f"Repo files: {repo_files}") # ── 1. Tokenizer ────────────────────────────────────────────────────── _log("STEP 1: patch tokenizer_config.json + load tokenizer") _patch_tokenizer_config(path) self.tokenizer = AutoTokenizer.from_pretrained( path, use_fast=False, token=hf_token ) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token sinllama_vocab = len(self.tokenizer) _log(f"STEP 1 OK — vocab size {sinllama_vocab}") # ── 2. Resolve Llama base model ID ──────────────────────────────────── _log("STEP 2: resolve base model ID") base_model_id = os.environ.get("BASE_MODEL") if base_model_id: _log(f"BASE_MODEL from env: {base_model_id!r}") else: _log("BASE_MODEL not set — reading from SinLlama adapter config") sinllama_cfg = _read_adapter_config(sinllama_repo, hf_token) base_model_id = sinllama_cfg.base_model_name_or_path _log(f"Base model from SinLlama config: {base_model_id!r}") _log(f"STEP 2 OK — base_model_id={base_model_id!r}") # ── 3. Load Llama base ──────────────────────────────────────────────── _log(f"STEP 3: load base model {base_model_id!r}") base = AutoModelForCausalLM.from_pretrained( base_model_id, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="sdpa", token=hf_token, ) base.resize_token_embeddings(sinllama_vocab) _log(f"STEP 3 OK — embeddings resized to {sinllama_vocab}") # ── 4. Load + merge SinLlama adapter ───────────────────────────────── _log(f"STEP 4: load SinLlama adapter from {sinllama_repo!r}") sinllama = PeftModel.from_pretrained( base, sinllama_repo, is_trainable=False, token=hf_token ) sinllama = sinllama.merge_and_unload() _log("STEP 4 OK — SinLlama merged") # ── 5. Load MCQ adapter ─────────────────────────────────────────────── _log(f"STEP 5: load MCQ adapter from {path!r}") _read_adapter_config(path, hf_token) # logs base_model_name_or_path self.model = PeftModel.from_pretrained( sinllama, path, is_trainable=False, token=hf_token ) self.model.eval() _log("STEP 5 OK — EndpointHandler ready") self._default_temperature = float(os.environ.get("TEMPERATURE", 0.7)) self._default_max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", 300)) def __call__(self, data: dict) -> list[dict]: passage = (data.get("inputs") or data.get("passage") or "").strip() if not passage: return [{"error": "No passage provided. Send {\"inputs\": \"\"}"}] params = data.get("parameters") or {} temperature = float(params.get("temperature", self._default_temperature)) max_new_tokens = int(params.get("max_new_tokens", self._default_max_new_tokens)) prompt = _PROMPT_TEMPLATE.format(passage=passage) inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) try: with torch.no_grad(): output = self.model.generate( **inputs, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0, repetition_penalty=1.1, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id, ) except Exception as e: logger.exception("Generation failed") return [{"error": f"Generation error: {e}"}] new_ids = output[0][inputs.input_ids.shape[1]:] raw = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() mcq = _normalize_mcq(raw) return [{ "generated_text": mcq if mcq is not None else raw, "mcq": mcq, "valid": mcq is not None, }]