import os

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaConfig

# Tokenizer carrying the extended Sinhala vocabulary; the model is rebuilt below at this size.
TOKENIZER_NAME = "polyglots/Extended-Sinhala-LLaMA"
VOCAB_SIZE = 139336


class EndpointHandler:
    """Rebuild LLaMA at the extended Sinhala vocab size, then load the sharded
    safetensors checkpoint and serve text generation."""

    def __init__(self, path=""):
| print("Loading tokenizer...") |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| TOKENIZER_NAME, |
| trust_remote_code=True, |
| ) |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
| print("Loading config...") |
| config = LlamaConfig.from_pretrained(path) |
| config.vocab_size = VOCAB_SIZE |
|
|
| print("Building model at correct vocab size...") |
| self.model = LlamaForCausalLM(config) |
|
|
| print("Loading checkpoint shards...") |
| weights = {} |
| for f in sorted(os.listdir(path)): |
| if f.endswith(".safetensors"): |
| print(f" Shard: {f}") |
| weights.update(load_file(os.path.join(path, f))) |
|
|
| missing, unexpected = self.model.load_state_dict(weights, strict=False) |
| print(f" Missing: {len(missing)} | Unexpected: {len(unexpected)}") |
|
|
| self.model = self.model.to(torch.float16).to("cuda") |
| self.model.config.pad_token_id = self.tokenizer.eos_token_id |
| self.model.eval() |
| print(f"Ready! Vocab: {self.model.config.vocab_size:,}") |
|
|
| def __call__(self, data: dict) -> dict: |
        inputs = data.get("inputs", "")
        params = data.get("parameters", {})
        max_tokens = params.get("max_new_tokens", 400)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)
        rep_penalty = params.get("repetition_penalty", 1.1)

        if not inputs:
            return {"error": "No input provided. Use the 'inputs' key."}

        # Truncate long prompts to 1024 tokens and move tensors to the model's device.
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=rep_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Return only the newly generated continuation, not the echoed prompt.
        new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": decoded.strip()}
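# A minimal local smoke test, a sketch rather than part of the endpoint contract.
# It assumes the sharded checkpoint lives in ./model and a CUDA GPU is available
# (the handler moves the model to "cuda"); the prompt is only a placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler(
        {
            "inputs": "ශ්‍රී ලංකාව",  # hypothetical Sinhala prompt ("Sri Lanka")
            "parameters": {"max_new_tokens": 64, "temperature": 0.7},
        }
    )
    print(result.get("generated_text", result))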