# sinllama-mcq-kapture / handler.py
# Uploaded by itsjorigo via huggingface_hub (revision ac40d43).
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaConfig
import torch
import os
from safetensors.torch import load_file
# Hub repo providing the Sinhala-extended LLaMA tokenizer.
TOKENIZER_NAME = "polyglots/Extended-Sinhala-LLaMA"
# Vocab size of the extended tokenizer; the config is forced to this value
# before the checkpoint shards are loaded (see EndpointHandler.__init__),
# presumably because the saved embedding matrices were resized to match.
VOCAB_SIZE = 139336
class EndpointHandler:
    """Inference Endpoints handler for a Sinhala-extended LLaMA model.

    The checkpoint was produced with an extended tokenizer, so the model is
    rebuilt from its config at the extended vocab size (VOCAB_SIZE) before
    the safetensors shards are loaded into it.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer, rebuild the model, and load checkpoint shards.

        Args:
            path: Directory containing config.json and *.safetensors shards.
        """
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            TOKENIZER_NAME,
            trust_remote_code=True,
        )
        # LLaMA tokenizers typically ship without a pad token; reuse EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading config...")
        config = LlamaConfig.from_pretrained(path)
        # The shards were saved with extended-vocab embedding matrices, so
        # the config must agree before load_state_dict.
        config.vocab_size = VOCAB_SIZE

        print("Building model at correct vocab size...")
        self.model = LlamaForCausalLM(config)

        print("Loading checkpoint shards...")
        weights = {}
        for f in sorted(os.listdir(path)):
            if f.endswith(".safetensors"):
                print(f" Shard: {f}")
                weights.update(load_file(os.path.join(path, f)))
        # strict=False tolerates benign key mismatches; the counts are
        # printed so a genuinely bad checkpoint is still visible in logs.
        missing, unexpected = self.model.load_state_dict(weights, strict=False)
        print(f" Missing: {len(missing)} | Unexpected: {len(unexpected)}")

        # Fix: the original hard-coded "cuda", crashing on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(torch.float16).to(device)
        self.model.config.pad_token_id = self.tokenizer.eos_token_id
        self.model.eval()
        print(f"Ready! Vocab: {self.model.config.vocab_size:,}")

    def __call__(self, data: dict) -> dict:
        """Generate a completion for ``data["inputs"]``.

        Args:
            data: ``{"inputs": str, "parameters": {...}}`` where parameters
                may contain max_new_tokens, temperature, top_p and
                repetition_penalty (all optional).

        Returns:
            ``{"generated_text": str}`` on success, or ``{"error": str}``
            when no input text was supplied.
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {})
        max_tokens = params.get("max_new_tokens", 400)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)
        rep_penalty = params.get("repetition_penalty", 1.1)

        if not inputs:
            return {"error": "No input provided. Use the inputs key."}

        # Fix: do_sample=True with temperature <= 0 makes generate() raise;
        # fall back to greedy decoding for non-positive temperatures.
        do_sample = temperature is not None and temperature > 0

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_tokens,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                repetition_penalty=rep_penalty,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": decoded.strip()}