itsjorigo committed · Commit f40c874 · verified · 1 parent: 7801609

Update handler.py

Files changed (1):
  1. handler.py +33 -29
handler.py CHANGED
@@ -21,26 +21,34 @@ class EndpointHandler:
         # the 139336-vocab checkpoint weights
         print(f"Patching config vocab_size to {VOCAB_SIZE:,}...")
         config = LlamaConfig.from_pretrained(path)
+
+        # Force correct vocab size BEFORE model is built
+        # so embeddings are initialized at the right size
         config.vocab_size = VOCAB_SIZE

-        print(f"Loading model from {path}...")
-        self.model = LlamaForCausalLM.from_pretrained(
-            path,
-            config = config,
-            torch_dtype = torch.float16,
-            device_map = "auto",
-            # trust_remote_code = True,
-            ignore_mismatched_sizes = True,
-        )
-        # Resize to match extended vocab (139,336 tokens)
-        # self.model.resize_token_embeddings(len(self.tokenizer))
-
+        print(f"Loading model weights...")
+        self.model = LlamaForCausalLM(config)  # ← build empty model at correct size first
+
+        # Now load the checkpoint weights -- sizes will match
+        import os
+        from safetensors.torch import load_file
+
+        weights = {}
+        for f in sorted(os.listdir(path)):
+            if f.endswith(".safetensors"):
+                print(f" Loading shard: {f}")
+                weights.update(load_file(os.path.join(path, f)))
+
+        missing, unexpected = self.model.load_state_dict(weights, strict=False)
+        print(f" Missing keys: {len(missing)}")
+        print(f" Unexpected keys: {len(unexpected)}")
+
+        self.model = self.model.to(torch.float16).to("cuda")
         self.model.config.pad_token_id = self.tokenizer.eos_token_id
         self.model.eval()
         print(f"Ready! Vocab: {self.model.config.vocab_size:,}")

-    def __call__(self, data: dict) -> dict:
-        # ── unpack request ───────────────────────────────────────────────────
+    def __call__(self, data: dict) -> dict:
         inputs = data.get("inputs", "")
         params = data.get("parameters", {})
         max_tokens = params.get("max_new_tokens", 400)
@@ -49,30 +57,26 @@ class EndpointHandler:
         rep_penalty = params.get("repetition_penalty", 1.1)

         if not inputs:
-            return {"error": "No input text provided. Use the 'inputs' key."}
+            return {"error": "No input provided. Use the 'inputs' key."}

-        # ── tokenise ─────────────────────────────────────────────────────────
         tokenized = self.tokenizer(
             inputs,
-            return_tensors = "pt",
-            truncation = True,
-            max_length = 1024,
+            return_tensors = "pt",
+            truncation = True,
+            max_length = 1024,
         ).to(self.model.device)

-        # ── generate ─────────────────────────────────────────────────────────
         with torch.no_grad():
             output_ids = self.model.generate(
                 **tokenized,
-                max_new_tokens = max_tokens,
-                temperature = temperature,
-                top_p = top_p,
-                repetition_penalty = rep_penalty,
-                do_sample = True,
-                pad_token_id = self.tokenizer.eos_token_id,
+                max_new_tokens = max_tokens,
+                temperature = temperature,
+                top_p = top_p,
+                repetition_penalty = rep_penalty,
+                do_sample = True,
+                pad_token_id = self.tokenizer.eos_token_id,
             )

-        # ── decode (strip prompt, return only new tokens) ─────────────────────
         new_tokens = output_ids[0][tokenized.input_ids.shape[1]:]
         decoded = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-        return {"generated_text": decoded.strip()}
+        return {"generated_text": decoded.strip()}