# sinllama-mcq-3.0 / handler.py
# (provenance: HuggingFace repo itsjorigo/sinllama-mcq-3.0, "Update handler.py",
#  commit a0de84d verified — page-header text converted to comments so the
#  file remains valid Python)
"""
HuggingFace Inference Endpoints custom handler for SinLlama-MCQ.
Upload this file as `handler.py` to itsjorigo/sinllama-mcq-3.0 on HuggingFace.
"""
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Frozen base model that the SinLlama language adapter was trained on.
BASE_MODEL = "meta-llama/Meta-Llama-3-8B"
# SinLlama repo: extended Sinhala tokenizer plus continued-pretraining adapter.
SINLLAMA_ID = "polyglots/SinLlama_v01"
# Sinhala generation prompt (roughly: "Read the history passage below and
# create a multiple-choice question about it."). {passage} and {entity_block}
# are filled in by EndpointHandler.__call__; the model continues after "MCQ:".
PROMPT_TEMPLATE = (
"පහත ඉතිහාස ඡේදය කියවා, ඒ ගැන බහු-විකල්ප ප්‍රශ්නයක් සාදන්න.\n\n"
"ඡේදය: {passage}\n\n"
"{entity_block}"
"MCQ:"
)
class EndpointHandler:
    """HF Inference Endpoints handler serving SinLlama-MCQ.

    Loads the Llama-3-8B base model in float16, merges the SinLlama
    language adapter into it, then attaches the MCQ fine-tune adapter.
    ``__call__`` generates a Sinhala multiple-choice question from a
    history passage.
    """

    def __init__(self, path=""):
        # `path` is the local directory where HF downloaded the endpoint
        # repo (itsjorigo/sinllama-mcq-3.0), i.e. the MCQ adapter weights.
        print("Loading tokenizer...")
        # Load tokenizer from the SinLlama repo — it defines the custom
        # TokenizersBackend class and the extended Sinhala vocabulary.
        self.tokenizer = AutoTokenizer.from_pretrained(SINLLAMA_ID, trust_remote_code=True)
        # Llama-3 tokenizers ship without a pad token; generate() would
        # otherwise be handed pad_token_id=None below. Fall back to EOS.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        vocab_size = len(self.tokenizer)

        # float16 (not 4/8-bit) is required: quantized models cannot go
        # through merge_and_unload(). An A10G's 24 GB VRAM comfortably
        # fits the ~16 GB fp16 8B model.
        print("Loading base model in float16...")
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        # mean_resizing=False avoids holding a 2x embedding matrix in VRAM
        # during the resize. Safe here because the SinLlama adapter carries
        # the trained embeddings for the added tokens.
        base.resize_token_embeddings(vocab_size, mean_resizing=False)

        # Merge SinLlama into the base so the MCQ adapter attaches to a
        # plain model rather than a stacked PeftModel.
        print("Loading and merging SinLlama adapter...")
        sinllama = PeftModel.from_pretrained(base, SINLLAMA_ID, is_trainable=False)
        merged = sinllama.merge_and_unload()

        print("Loading MCQ fine-tune adapter...")
        self.model = PeftModel.from_pretrained(merged, path, is_trainable=False)
        self.model.eval()
        print("Model ready.")

    @staticmethod
    def _format_mcq(text):
        """Insert a newline before each answer tag and the answer-key marker
        so the MCQ renders one option per line; then strip outer whitespace.
        A tag already at the start of a line is left untouched."""
        for tag in ["A)", "B)", "C)", "D)", "නිවැරදි පිළිතුර:"]:
            text = re.sub(rf"(?<!\n)({re.escape(tag)})", r"\n\1", text)
        return text.strip()

    def __call__(self, data: dict) -> dict:
        """Generate one MCQ.

        Expects ``data["inputs"] = {"passage": "...", "entity_block": "..."}``
        (both keys optional, defaulting to "") and returns ``{"mcq": text}``.
        """
        inputs = data.get("inputs", {})
        passage = inputs.get("passage", "")
        entity_block = inputs.get("entity_block", "")

        prompt = PROMPT_TEMPLATE.format(
            passage=passage.strip(),
            entity_block=entity_block,
        )
        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=280,
                temperature=0.7,
                do_sample=True,
                repetition_penalty=1.1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens (drop the echoed prompt).
        new_ids = out[0][enc.input_ids.shape[1]:]
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        return {"mcq": self._format_mcq(text)}