YFolla committed on
Commit
9fd337d
·
verified ·
1 Parent(s): 5ab33b7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -41
handler.py CHANGED
@@ -25,66 +25,71 @@ class EndpointHandler():
25
  # Optional: Explicitly set pad token if needed
26
  # if self.tokenizer.pad_token is None:
27
  # self.tokenizer.pad_token = self.tokenizer.eos_token
28
-
29
- # Create a text-generation pipeline for easier handling
30
- self.pipeline = pipeline(
31
- "text-generation",
32
- model=self.model,
33
- tokenizer=self.tokenizer,
34
- # device_map="auto" # device_map should be handled by model loading
35
- )
36
  print("Handler initialized: Model and tokenizer loaded.")
37
 
38
 
39
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
40
  """
41
- Handles the inference request.
42
- 'data' is a dictionary containing the request payload.
43
- We expect 'inputs' to hold the prompt text.
44
- Optional 'parameters' can control generation settings.
45
  """
46
  try:
47
  # Extract inputs and parameters
48
- inputs = data.pop("inputs", None)
49
  parameters = data.pop("parameters", {})
50
 
51
- if inputs is None:
52
  return [{"error": "Missing 'inputs' key in request data."}]
53
 
54
- # --- Handle different input types ---
55
- if isinstance(inputs, str):
56
- processed_inputs = [inputs]
57
- elif isinstance(inputs, list) and all(isinstance(i, str) for i in inputs):
58
- processed_inputs = inputs # Already a list of strings
59
- else:
60
- return [{"error": "Invalid 'inputs' format. Must be a string or a list of strings."}]
61
- # --- End input handling ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # Set generation parameters (use pipeline defaults + overrides)
64
- # Sensible defaults if not provided by user
65
- parameters.setdefault("max_new_tokens", 64)
66
- parameters.setdefault("temperature", 1.0)
67
- parameters.setdefault("top_p", 0.95)
68
- parameters.setdefault("top_k", 64)
69
- # Ensure pipeline doesn't add EOS if user controls max_new_tokens precisely
70
- # parameters.setdefault("return_full_text", False) # Often useful
71
 
72
- # --- ADD THIS LINE ---
73
- parameters["return_full_text"] = False
74
- # ---------------------
 
75
 
76
- print(f"Received inputs: {processed_inputs}")
77
- print(f"Using parameters: {parameters}")
 
 
 
 
78
 
79
- # Run inference through the pipeline
80
- results = self.pipeline(processed_inputs, **parameters)
81
- print(f"Pipeline results: {results}")
82
 
83
- # Return the results directly (pipeline usually formats correctly)
84
- return results
85
 
86
  except Exception as e:
87
- # More detailed error logging
88
  import traceback
89
  print(f"Error during inference: {e}")
90
  print(traceback.format_exc())
 
25
  # Optional: Explicitly set pad token if needed
26
  # if self.tokenizer.pad_token is None:
27
  # self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
 
 
 
 
 
 
28
  print("Handler initialized: Model and tokenizer loaded.")
29
 
30
 
31
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
32
  """
33
+ Handles the inference request using manual generation.
 
 
 
34
  """
35
  try:
36
  # Extract inputs and parameters
37
+ inputs_text = data.pop("inputs", None)
38
  parameters = data.pop("parameters", {})
39
 
40
+ if inputs_text is None:
41
  return [{"error": "Missing 'inputs' key in request data."}]
42
 
43
+ # Basic input validation
44
+ if not isinstance(inputs_text, str):
45
+ return [{"error": "Invalid 'inputs' format. Must be a single string for this handler."}]
46
+
47
+ # Set generation parameters
48
+ params = {
49
+ "max_new_tokens": 64,
50
+ "temperature": 1.0,
51
+ "top_p": 0.95,
52
+ "top_k": 64,
53
+ "do_sample": True, # Explicitly enable sampling
54
+ "pad_token_id": self.tokenizer.eos_token_id # Use EOS for padding
55
+ }
56
+ # Update with user-provided parameters
57
+ params.update(parameters)
58
+
59
+ print(f"Received input: '{inputs_text}'")
60
+ print(f"Using parameters: {params}")
61
+
62
+ # Manually tokenize
63
+ # Important: Add generation prompt structure if needed by the model/tokenizer chat template!
64
+ # Assuming the tokenizer's chat template handles adding the prompt correctly when needed.
65
+ # If not, you might need manual formatting here before tokenizing.
66
+ # Let's try applying the chat template explicitly for robustness:
67
+ messages = [{"role": "user", "content": inputs_text}]
68
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
69
+
70
+ print(f"Formatted prompt: '{prompt}'")
71
 
72
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
 
 
 
 
 
 
 
73
 
74
+ # Manually generate
75
+ # Use torch.no_grad() for efficiency during inference
76
+ with torch.no_grad():
77
+ outputs = self.model.generate(**inputs, **params)
78
 
79
+ # Decode the output
80
+ # outputs[0] contains the full sequence (prompt + generation)
81
+ # We need to decode only the generated part
82
+ input_length = inputs.input_ids.shape[1]
83
+ generated_ids = outputs[0][input_length:]
84
+ generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
85
 
86
+ print(f"Generated IDs length: {len(generated_ids)}")
87
+ print(f"Decoded generated text: '{generated_text}'")
 
88
 
89
+ # Return the results
90
+ return [{"generated_text": generated_text}]
91
 
92
  except Exception as e:
 
93
  import traceback
94
  print(f"Error during inference: {e}")
95
  print(traceback.format_exc())