Update handler.py
Browse files- handler.py +46 -41
handler.py
CHANGED
|
@@ -25,66 +25,71 @@ class EndpointHandler():
|
|
| 25 |
# Optional: Explicitly set pad token if needed
|
| 26 |
# if self.tokenizer.pad_token is None:
|
| 27 |
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 28 |
-
|
| 29 |
-
# Create a text-generation pipeline for easier handling
|
| 30 |
-
self.pipeline = pipeline(
|
| 31 |
-
"text-generation",
|
| 32 |
-
model=self.model,
|
| 33 |
-
tokenizer=self.tokenizer,
|
| 34 |
-
# device_map="auto" # device_map should be handled by model loading
|
| 35 |
-
)
|
| 36 |
print("Handler initialized: Model and tokenizer loaded.")
|
| 37 |
|
| 38 |
|
| 39 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 40 |
"""
|
| 41 |
-
Handles the inference request.
|
| 42 |
-
'data' is a dictionary containing the request payload.
|
| 43 |
-
We expect 'inputs' to hold the prompt text.
|
| 44 |
-
Optional 'parameters' can control generation settings.
|
| 45 |
"""
|
| 46 |
try:
|
| 47 |
# Extract inputs and parameters
|
| 48 |
-
|
| 49 |
parameters = data.pop("parameters", {})
|
| 50 |
|
| 51 |
-
if
|
| 52 |
return [{"error": "Missing 'inputs' key in request data."}]
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
if isinstance(
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
# Sensible defaults if not provided by user
|
| 65 |
-
parameters.setdefault("max_new_tokens", 64)
|
| 66 |
-
parameters.setdefault("temperature", 1.0)
|
| 67 |
-
parameters.setdefault("top_p", 0.95)
|
| 68 |
-
parameters.setdefault("top_k", 64)
|
| 69 |
-
# Ensure pipeline doesn't add EOS if user controls max_new_tokens precisely
|
| 70 |
-
# parameters.setdefault("return_full_text", False) # Often useful
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
print(f"Pipeline results: {results}")
|
| 82 |
|
| 83 |
-
# Return the results
|
| 84 |
-
return
|
| 85 |
|
| 86 |
except Exception as e:
|
| 87 |
-
# More detailed error logging
|
| 88 |
import traceback
|
| 89 |
print(f"Error during inference: {e}")
|
| 90 |
print(traceback.format_exc())
|
|
|
|
| 25 |
# Optional: Explicitly set pad token if needed
|
| 26 |
# if self.tokenizer.pad_token is None:
|
| 27 |
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
print("Handler initialized: Model and tokenizer loaded.")
|
| 29 |
|
| 30 |
|
| 31 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Handle an inference request using manual generation.

    Args:
        data: Request payload. Expects 'inputs' to hold a single prompt
            string; an optional 'parameters' dict overrides the default
            generation settings.

    Returns:
        A one-element list: [{"generated_text": ...}] on success, or
        [{"error": ...}] on invalid input or an internal failure.
    """
    try:
        # Extract inputs and parameters
        inputs_text = data.pop("inputs", None)
        parameters = data.pop("parameters", {})

        if inputs_text is None:
            return [{"error": "Missing 'inputs' key in request data."}]

        # Basic input validation
        if not isinstance(inputs_text, str):
            return [{"error": "Invalid 'inputs' format. Must be a single string for this handler."}]

        # Sensible generation defaults; user-supplied parameters win.
        params = {
            "max_new_tokens": 64,
            "temperature": 1.0,
            "top_p": 0.95,
            "top_k": 64,
            "do_sample": True,  # Explicitly enable sampling
            "pad_token_id": self.tokenizer.eos_token_id,  # Use EOS for padding
        }
        # Update with user-provided parameters
        params.update(parameters)

        print(f"Received input: '{inputs_text}'")
        print(f"Using parameters: {params}")

        # Apply the chat template explicitly so the model sees the prompt
        # structure it was trained with, then tokenize manually.
        messages = [{"role": "user", "content": inputs_text}]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        print(f"Formatted prompt: '{prompt}'")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **params)

        # outputs[0] holds prompt + completion; decode only the newly
        # generated tokens by slicing off the prompt length.
        input_length = inputs.input_ids.shape[1]
        generated_ids = outputs[0][input_length:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        print(f"Generated IDs length: {len(generated_ids)}")
        print(f"Decoded generated text: '{generated_text}'")

        # Return the results
        return [{"generated_text": generated_text}]

    except Exception as e:
        # Log the full traceback, then surface the failure to the caller.
        # (Previously this branch fell through and implicitly returned
        # None, breaking the declared List[Dict[str, Any]] contract.)
        import traceback
        print(f"Error during inference: {e}")
        print(traceback.format_exc())
        return [{"error": str(e)}]