girish00 commited on
Commit
4c195e0
·
verified ·
1 Parent(s): 007a7c7

add structured endpoint handler

Browse files
Files changed (1) hide show
  1. handler.py +108 -0
handler.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ import time
3
+
4
+ import torch
5
+ from peft import PeftConfig, PeftModel
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+
8
+ from infer_local import (
9
+ build_instruction_prompt,
10
+ build_structured_result,
11
+ has_adapter_weights,
12
+ has_full_model_weights,
13
+ )
14
+
15
+
16
+ class EndpointHandler:
17
+ def __init__(self, path: str = ""):
18
+ self.path = path or "."
19
+ adapter_config_path = f"{self.path}/adapter_config.json"
20
+ adapter_weights_present = has_adapter_weights(self.path)
21
+ full_model_weights_present = has_full_model_weights(self.path)
22
+
23
+ if adapter_weights_present:
24
+ peft_config = PeftConfig.from_pretrained(self.path)
25
+ base_model_name = peft_config.base_model_name_or_path
26
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
27
+ base_model = AutoModelForCausalLM.from_pretrained(
28
+ base_model_name,
29
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
30
+ )
31
+ self.model = PeftModel.from_pretrained(base_model, self.path)
32
+ elif full_model_weights_present:
33
+ self.tokenizer = AutoTokenizer.from_pretrained(self.path)
34
+ self.model = AutoModelForCausalLM.from_pretrained(
35
+ self.path,
36
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
37
+ )
38
+ else:
39
+ raise RuntimeError(
40
+ f"No adapter or full-model weights found at endpoint model path: {self.path}"
41
+ )
42
+
43
+ if self.tokenizer.pad_token is None:
44
+ self.tokenizer.pad_token = self.tokenizer.eos_token
45
+
46
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ self.model.to(self.device)
48
+ self.model.eval()
49
+ self.model.generation_config.do_sample = False
50
+ self.model.generation_config.temperature = 1.0
51
+ self.model.generation_config.top_p = 1.0
52
+ self.model.generation_config.top_k = 50
53
+
54
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
55
+ user_prompt = data.get("inputs", data.get("prompt", ""))
56
+ if isinstance(user_prompt, list):
57
+ user_prompt = user_prompt[0] if user_prompt else ""
58
+ user_prompt = str(user_prompt).strip()
59
+ if not user_prompt:
60
+ return {
61
+ "error": "Missing prompt. Send {'inputs': 'your coding prompt'}."
62
+ }
63
+
64
+ parameters = data.get("parameters", {}) or {}
65
+ max_new_tokens = int(parameters.get("max_new_tokens", 320))
66
+ do_sample = bool(parameters.get("do_sample", False))
67
+
68
+ prompt_text = build_instruction_prompt(user_prompt)
69
+ inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
70
+
71
+ generation_kwargs = {
72
+ "max_new_tokens": max_new_tokens,
73
+ "output_scores": True,
74
+ "return_dict_in_generate": True,
75
+ "do_sample": do_sample,
76
+ "pad_token_id": self.tokenizer.eos_token_id,
77
+ }
78
+ if do_sample:
79
+ generation_kwargs["temperature"] = float(parameters.get("temperature", 0.25))
80
+ generation_kwargs["top_p"] = float(parameters.get("top_p", 0.9))
81
+
82
+ started_at = time.perf_counter()
83
+ with torch.no_grad():
84
+ generated = self.model.generate(**inputs, **generation_kwargs)
85
+ latency_ms = int((time.perf_counter() - started_at) * 1000)
86
+
87
+ output_ids = generated.sequences[0]
88
+ prompt_len = inputs["input_ids"].shape[1]
89
+ generated_ids = output_ids[prompt_len:].tolist()
90
+ generated_text = self.tokenizer.decode(
91
+ generated_ids,
92
+ skip_special_tokens=True,
93
+ ).strip()
94
+
95
+ token_confidences = []
96
+ if generated.scores:
97
+ for token_id, score_tensor in zip(generated_ids, generated.scores):
98
+ probs = torch.softmax(score_tensor[0], dim=-1)
99
+ token_confidences.append(float(probs[token_id].item()))
100
+
101
+ return build_structured_result(
102
+ user_prompt,
103
+ generated_text,
104
+ latency_ms,
105
+ tokenizer=self.tokenizer,
106
+ generated_ids=generated_ids,
107
+ token_confidences=token_confidences,
108
+ )