""" HuggingFace Inference Endpoints custom handler for SinLlama-MCQ. Upload this file as `handler.py` to itsjorigo/sinllama-mcq-3.0 on HuggingFace. """ import re import torch from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel BASE_MODEL = "meta-llama/Meta-Llama-3-8B" SINLLAMA_ID = "polyglots/SinLlama_v01" PROMPT_TEMPLATE = ( "පහත ඉතිහාස ඡේදය කියවා, ඒ ගැන බහු-විකල්ප ප්‍රශ්නයක් සාදන්න.\n\n" "ඡේදය: {passage}\n\n" "{entity_block}" "MCQ:" ) class EndpointHandler: def __init__(self, path=""): # path = local directory where HF downloaded itsjorigo/sinllama-mcq-3.0 print("Loading tokenizer...") # Load tokenizer from SinLlama repo — it defines the custom TokenizersBackend class self.tokenizer = AutoTokenizer.from_pretrained(SINLLAMA_ID, trust_remote_code=True) vocab_size = len(self.tokenizer) # Load in float16 — required for merge_and_unload (quantized models can't be merged) # A10G has 24 GB VRAM which fits float16 8B model (~16 GB) comfortably print("Loading base model in float16...") base = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.float16, device_map="auto", attn_implementation="sdpa", ) # mean_resizing=False avoids holding 2x embedding matrix in VRAM during resize. # Safe here because SinLlama adapter contains the correct trained embeddings. base.resize_token_embeddings(vocab_size, mean_resizing=False) # Merge SinLlama into base so the MCQ adapter sees a plain model (not stacked PeftModel) print("Loading and merging SinLlama adapter...") sinllama = PeftModel.from_pretrained(base, SINLLAMA_ID, is_trainable=False) merged = sinllama.merge_and_unload() print("Loading MCQ fine-tune adapter...") self.model = PeftModel.from_pretrained(merged, path, is_trainable=False) self.model.eval() print("Model ready.") def __call__(self, data: dict) -> dict: # data["inputs"] = { "passage": "...", "entity_block": "..." } inputs = data.get("inputs", {}) passage = inputs.get("passage", "") entity_block = inputs.get("entity_block", "") prompt = PROMPT_TEMPLATE.format( passage=passage.strip(), entity_block=entity_block, ) enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) with torch.no_grad(): out = self.model.generate( **enc, max_new_tokens=280, temperature=0.7, do_sample=True, repetition_penalty=1.1, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id, ) new_ids = out[0][enc.input_ids.shape[1]:] text = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip() for tag in ["A)", "B)", "C)", "D)", "නිවැරදි පිළිතුර:"]: text = re.sub(rf"(?