| """ |
| HuggingFace Inference Endpoints custom handler for SinLlama-MCQ. |
| Upload this file as `handler.py` to itsjorigo/sinllama-mcq-3.0 on HuggingFace. |
| """ |
|
|
| import re |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
|
|
# Hub IDs: the Llama-3 8B base model and the SinLlama language-adapter repo
# (which also ships the Sinhala-extended tokenizer).
BASE_MODEL = "meta-llama/Meta-Llama-3-8B"
SINLLAMA_ID = "polyglots/SinLlama_v01"


# Sinhala instruction prompt. English gloss: "Read the following history
# passage and create a multiple-choice question about it." — followed by the
# passage, an optional entity block, and the "MCQ:" generation cue.
PROMPT_TEMPLATE = (
    "පහත ඉතිහාස ඡේදය කියවා, ඒ ගැන බහු-විකල්ප ප්රශ්නයක් සාදන්න.\n\n"
    "ඡේදය: {passage}\n\n"
    "{entity_block}"
    "MCQ:"
)
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for SinLlama-MCQ.

    Build order:
      1. Load the SinLlama tokenizer (Sinhala-extended vocabulary).
      2. Load the Llama-3 8B base model in float16 and resize its token
         embeddings to match the extended vocabulary.
      3. Merge the SinLlama language adapter into the base weights.
      4. Attach the MCQ fine-tune adapter shipped in the deployed repo.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, base model, and both adapters.

        Args:
            path: Local path of the deployed repository; must contain the
                MCQ fine-tune adapter weights.
        """
        print("Loading tokenizer...")
        # The SinLlama repo ships a tokenizer extended with Sinhala tokens,
        # so the base model's embedding matrix must be resized to match.
        self.tokenizer = AutoTokenizer.from_pretrained(SINLLAMA_ID, trust_remote_code=True)
        vocab_size = len(self.tokenizer)

        print("Loading base model in float16...")
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="sdpa",
        )
        # mean_resizing=False keeps plain (non-mean) initialization for the
        # new embedding rows — presumably matching how SinLlama was trained.
        base.resize_token_embeddings(vocab_size, mean_resizing=False)

        print("Loading and merging SinLlama adapter...")
        sinllama = PeftModel.from_pretrained(base, SINLLAMA_ID, is_trainable=False)
        merged = sinllama.merge_and_unload()

        print("Loading MCQ fine-tune adapter...")
        self.model = PeftModel.from_pretrained(merged, path, is_trainable=False)
        self.model.eval()

        # Llama tokenizers typically define no pad token, so pad_token_id is
        # None; fall back to EOS so generate() neither warns nor fails.
        self._pad_token_id = (
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        )
        print("Model ready.")

    def __call__(self, data: dict) -> dict:
        """Generate one MCQ from a history passage.

        Args:
            data: Request payload. ``data["inputs"]`` is either a dict with
                keys ``passage`` and (optionally) ``entity_block``, or a
                bare string, which is treated as the passage itself.

        Returns:
            ``{"mcq": <generated question text>}``.
        """
        inputs = data.get("inputs", {})
        # HF Inference Endpoints clients often POST {"inputs": "<text>"};
        # accept that shape as well as the structured dict, and tolerate
        # explicit None values for either field.
        if isinstance(inputs, str):
            passage, entity_block = inputs, ""
        else:
            passage = inputs.get("passage", "") or ""
            entity_block = inputs.get("entity_block", "") or ""

        prompt = PROMPT_TEMPLATE.format(
            passage=passage.strip(),
            entity_block=entity_block,
        )

        enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=280,
                temperature=0.7,
                do_sample=True,
                repetition_penalty=1.1,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self._pad_token_id,
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        new_ids = out[0][enc.input_ids.shape[1]:]
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True).strip()

        # The model sometimes emits options run together on one line; force
        # each option tag and the answer tag onto its own line.
        for tag in ["A)", "B)", "C)", "D)", "නිවැරදි පිළිතුර:"]:
            text = re.sub(rf"(?<!\n)({re.escape(tag)})", r"\n\1", text)

        return {"mcq": text.strip()}
|
|