"""Custom Hugging Face Inference Endpoint handler for a Sinhala-extended LLaMA.

The checkpoint on disk has a vocabulary larger than the stock config reports,
so the model is rebuilt at VOCAB_SIZE before the safetensors shards are loaded.
"""

import os

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

# Tokenizer repo with the extended Sinhala vocabulary.
TOKENIZER_NAME = "polyglots/Extended-Sinhala-LLaMA"
# True embedding-table size of the checkpoint (extended vocab), which may
# differ from the vocab_size stored in the checkpoint's config.json.
VOCAB_SIZE = 139336


class EndpointHandler:
    """Loads the model once at startup and serves text-generation requests."""

    def __init__(self, path: str = ""):
        """Build tokenizer and model from the endpoint's model directory.

        Args:
            path: Directory containing config.json and *.safetensors shards.
        """
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            TOKENIZER_NAME,
            trust_remote_code=True,
        )
        # LLaMA tokenizers often ship without a pad token; reuse EOS so
        # batched/padded calls don't fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading config...")
        config = LlamaConfig.from_pretrained(path)
        # Override: the shards contain embeddings for the extended vocab,
        # so the model skeleton must be built at that size or loading fails.
        config.vocab_size = VOCAB_SIZE

        print("Building model at correct vocab size...")
        self.model = LlamaForCausalLM(config)

        print("Loading checkpoint shards...")
        weights: dict = {}
        for f in sorted(os.listdir(path)):
            if f.endswith(".safetensors"):
                print(f"  Shard: {f}")
                weights.update(load_file(os.path.join(path, f)))
        # strict=False: tied/derived weights may be absent from the shards.
        missing, unexpected = self.model.load_state_dict(weights, strict=False)
        print(f"  Missing: {len(missing)} | Unexpected: {len(unexpected)}")

        # Fix: original hard-coded "cuda" and crashed on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(torch.float16).to(device)
        self.model.config.pad_token_id = self.tokenizer.eos_token_id
        self.model.eval()
        print(f"Ready! Vocab: {self.model.config.vocab_size:,}")

    def __call__(self, data: dict) -> dict:
        """Handle one inference request.

        Args:
            data: {"inputs": <prompt str>, "parameters": {optional overrides}}.

        Returns:
            {"generated_text": str} on success, {"error": str} on empty input.
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {})
        max_tokens = params.get("max_new_tokens", 400)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)
        rep_penalty = params.get("repetition_penalty", 1.1)

        if not inputs:
            # Fix: this literal was broken across a line break in the
            # original source (a SyntaxError); reconstructed here.
            return {"error": "No input provided. Use the inputs key."}

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=rep_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated continuation, not the prompt echo.
        new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": decoded.strip()}