RanjithaRuttala committed
Commit 6d3b84e · verified
1 Parent(s): 4acd125

Upload handler (1).py

Files changed (1): handler (1).py (+120, -0)
handler (1).py ADDED
@@ -0,0 +1,120 @@
+import traceback
+import json
+import sys
+from typing import Dict, Any
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def log(*args):
+    """Send logs to HuggingFace endpoint logs."""
+    print("[DEBUG]", *args)
+    sys.stdout.flush()
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        log("📌 Initializing handler...")
+        log("Model path:", path)
+
+        try:
+            self.model_id = path
+
+            # Load tokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+            log("Tokenizer loaded.")
+
+            # Load model (fp16 on GPU, fp32 on CPU)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                path,
+                trust_remote_code=True,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                device_map="auto",
+            )
+            log("Model loaded on device:", self.model.device)
+
+        except Exception as e:
+            log("❌ Error during initialization:", str(e))
+            log(traceback.format_exc())
+            raise
+
+        log("✅ Initialization complete.")
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        log("----------------------------------------------------")
+        log("📥 Incoming Request:", json.dumps(data, indent=2))
+
+        try:
+            prompt = data.get("prompt") or data.get("inputs") or ""
+            max_tokens = data.get("max_tokens", 200)
+            temperature = data.get("temperature", 0.1)
+            stop_tokens = data.get("stop", None)
+
+            log("Prompt length:", len(prompt))
+            log("Max tokens:", max_tokens)
+            log("Temperature:", temperature)
+            log("Stop tokens:", stop_tokens)
+
+            # Tokenize
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+            log("Tokenized input shape:", {k: v.shape for k, v in inputs.items()})
+
+            # Generate (greedy when temperature == 0, sampling otherwise)
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                do_sample=temperature > 0,
+                temperature=temperature,
+                top_p=0.95,
+                pad_token_id=self.tokenizer.eos_token_id,
+            )
+
+            # Decode only the newly generated tokens (slicing the decoded string by len(prompt) can drift)
+            output_text = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+            log("Raw model output:", repr(output_text[:300]))
+
+            # Apply stop tokens
+            if stop_tokens:
+                for s in stop_tokens:
+                    if s in output_text:
+                        output_text = output_text.split(s)[0]
+                        log(f"Applied stop token: {s}")
+
+            output_text = output_text.strip()
+            log("Final output:", repr(output_text))
+
+            # Return OpenAI-compatible JSON (required by Continue)
+            response = {
+                "id": "cmpl-local",
+                "object": "text_completion",
+                "model": self.model_id,
+                "choices": [
+                    {
+                        "text": output_text,
+                        "index": 0,
+                        "finish_reason": "stop",
+                    }
+                ],
+            }
+
+            log("📤 Response:", json.dumps(response, indent=2))
+            log("----------------------------------------------------")
+            return response
+
+        except Exception as e:
+            log("❌ Exception during inference:", str(e))
+            log(traceback.format_exc())
+
+            return {
+                "id": "cmpl-error",
+                "object": "text_completion",
+                "model": self.model_id,
+                "choices": [
+                    {
+                        "text": f"ERROR: {str(e)}",
+                        "index": 0,
+                        "finish_reason": "error",
+                    }
+                ],
+            }
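
For reference, a minimal smoke test of the handler's request/response contract might look like the sketch below. It is not part of the commit: ENDPOINT_URL and HF_TOKEN are placeholders, and it assumes the handler is deployed as a Hugging Face Inference Endpoint that forwards the raw JSON body to __call__.

    # Minimal smoke test (sketch). ENDPOINT_URL and HF_TOKEN are placeholders,
    # not values from this commit.
    import requests

    ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
    HF_TOKEN = "hf_..."

    payload = {
        "prompt": "def fibonacci(n):",  # the handler also accepts "inputs"
        "max_tokens": 64,
        "temperature": 0.1,
        "stop": ["\n\n"],
    }

    resp = requests.post(
        ENDPOINT_URL,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
        json=payload,
        timeout=60,
    )
    resp.raise_for_status()

    # The handler returns OpenAI-style text_completion JSON, so the generated
    # text lives at choices[0].text — the shape Continue expects.
    print(resp.json()["choices"][0]["text"])

Note that errors are also returned in the same text_completion shape (with finish_reason "error"), so a client reading choices[0].text will see the error message rather than a failed request.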