File size: 2,623 Bytes
7801609
3c5d4d3
ac40d43
 
3c5d4d3
256cc53
65341f8
256cc53
3c5d4d3
 
ac40d43
256cc53
 
 
 
3c5d4d3
 
 
ac40d43
d2362d1
65341f8
 
ac40d43
 
f40c874
ac40d43
f40c874
 
 
ac40d43
f40c874
 
 
ac40d43
f40c874
 
256cc53
3c5d4d3
65341f8
3c5d4d3
ac40d43
3c5d4d3
 
 
 
 
 
 
 
ac40d43
3c5d4d3
 
 
ac40d43
 
 
3c5d4d3
 
 
 
 
ac40d43
 
 
 
 
 
3c5d4d3
 
 
 
ac40d43
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaConfig
import torch
import os
from safetensors.torch import load_file

TOKENIZER_NAME = "polyglots/Extended-Sinhala-LLaMA"
VOCAB_SIZE     = 139336

class EndpointHandler:
    """Inference-endpoint handler for a vocab-extended Sinhala LLaMA model.

    Loads the extended tokenizer from the hub, rebuilds the LLaMA model at
    the extended vocab size (the checkpoint's embedding matrix would not fit
    the stock config otherwise), then loads sharded safetensors weights from
    `path`. Callable with the standard HF endpoint payload:
    ``{"inputs": str, "parameters": {...}}``.
    """

    def __init__(self, path=""):
        """Build tokenizer + model from the checkpoint directory `path`."""
        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            TOKENIZER_NAME,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # LLaMA tokenizers ship without a pad token; reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading config...")
        config = LlamaConfig.from_pretrained(path)
        # Force the extended vocab so the freshly built model's embedding
        # matrix matches the checkpoint produced with the extended tokenizer.
        config.vocab_size = VOCAB_SIZE

        print("Building model at correct vocab size...")
        self.model = LlamaForCausalLM(config)

        print("Loading checkpoint shards...")
        weights = {}
        for f in sorted(os.listdir(path)):
            if f.endswith(".safetensors"):
                print(f"  Shard: {f}")
                weights.update(load_file(os.path.join(path, f)))

        # strict=False: tied/derived tensors may legitimately be absent from
        # the shards; the counts below surface anything unexpected.
        missing, unexpected = self.model.load_state_dict(weights, strict=False)
        print(f"  Missing: {len(missing)} | Unexpected: {len(unexpected)}")

        # FIX: previously hard-coded "cuda", which crashes on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(torch.float16).to(device)
        self.model.config.pad_token_id = self.tokenizer.eos_token_id
        self.model.eval()
        print(f"Ready! Vocab: {self.model.config.vocab_size:,}")

    def __call__(self, data: dict) -> dict:
        """Generate a completion for ``data["inputs"]``.

        Returns ``{"generated_text": str}`` on success, or
        ``{"error": str}`` when no input is provided.
        """
        inputs      = data.get("inputs", "")
        params      = data.get("parameters", {})
        max_tokens  = params.get("max_new_tokens", 400)
        temperature = params.get("temperature",    0.7)
        top_p       = params.get("top_p",          0.9)
        rep_penalty = params.get("repetition_penalty", 1.1)

        if not inputs:
            return {"error": "No input provided. Use the inputs key."}

        # FIX: temperature <= 0 is invalid for sampling and used to raise
        # inside generate(); treat it as a request for greedy decoding.
        do_sample = temperature is not None and temperature > 0

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        ).to(self.model.device)

        with torch.no_grad():
            gen_kwargs = {
                "max_new_tokens": max_tokens,
                "repetition_penalty": rep_penalty,
                "do_sample": do_sample,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if do_sample:
                # Sampling knobs are only meaningful when sampling.
                gen_kwargs["temperature"] = temperature
                gen_kwargs["top_p"] = top_p
            output_ids = self.model.generate(**tokenized, **gen_kwargs)

        # Strip the prompt tokens so only the newly generated text is returned.
        new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
        decoded    = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return {"generated_text": decoded.strip()}