from transformers import AutoTokenizer, LlamaForCausalLM, LlamaConfig
import torch
import os
from safetensors.torch import load_file

TOKENIZER_NAME = "polyglots/Extended-Sinhala-LLaMA"
VOCAB_SIZE = 139336
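
# Custom handler for Hugging Face Inference Endpoints: the service imports
# handler.py and expects an EndpointHandler class exposing __init__(path) and
# __call__(data), where `path` is the local checkout of the model repository.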
class EndpointHandler:
    def __init__(self, path=""):
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            TOKENIZER_NAME,
            trust_remote_code=True,
        )
        # The extended tokenizer may ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading config...")
        config = LlamaConfig.from_pretrained(path)
        # Force the vocab size to match the extended tokenizer so the embedding
        # and LM-head shapes line up with the checkpoint shards.
        config.vocab_size = VOCAB_SIZE

        print("Building model at correct vocab size...")
        self.model = LlamaForCausalLM(config)

        print("Loading checkpoint shards...")
        weights = {}
        for f in sorted(os.listdir(path)):
            if f.endswith(".safetensors"):
                print(f"  Shard: {f}")
                weights.update(load_file(os.path.join(path, f)))
        missing, unexpected = self.model.load_state_dict(weights, strict=False)
        print(f"  Missing: {len(missing)} | Unexpected: {len(unexpected)}")

        self.model = self.model.to(torch.float16).to("cuda")
        self.model.config.pad_token_id = self.tokenizer.eos_token_id
        self.model.eval()
        print(f"Ready! Vocab: {self.model.config.vocab_size:,}")
    def __call__(self, data: dict) -> dict:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {})
        max_tokens = params.get("max_new_tokens", 400)
        temperature = params.get("temperature", 0.7)
        top_p = params.get("top_p", 0.9)
        rep_penalty = params.get("repetition_penalty", 1.1)

        if not inputs:
            return {"error": 'No input provided. Use the "inputs" key.'}

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=rep_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt tokens and decode only the newly generated ones.
        new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
        decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": decoded.strip()}