Commit 262acca · Parent(s): 0372169
Commit message: what is life

handler.py CHANGED (+24 -80)
@@ -98,7 +98,7 @@ class EndpointHandler:
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-        Process inference request
+        Process inference request EXACTLY like local BrickGPT does
        """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
@@ -119,70 +119,39 @@ class EndpointHandler:
         # Check if this is a continuation (has assistant message)
         has_assistant = any(msg.get("role") == "assistant" for msg in messages)
 
-        #
-        user_content = ""
-        for msg in messages:
-            if msg.get("role") == "user":
-                content = msg["content"]
-                if "### Input:" in content:
-                    user_content = content.split("### Input:")[-1].strip()
-                else:
-                    user_content = content
-                break
-
-        # Create the proper instruction format (use few_shot for better results)
-        if not has_assistant:
-            instruction = self.create_instruction_few_shot(user_content)
-            messages = [
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": instruction}
-            ]
-
-        # Format input using chat template exactly like BrickGPT
+        # Format prompt EXACTLY like local BrickGPT does
         if has_assistant:
-            # For continuation, use continue_final_message=True
-            formatted_input = self.tokenizer.apply_chat_template(
+            # For continuation, use continue_final_message=True and return tensors
+            prompt = self.tokenizer.apply_chat_template(
                 messages,
-                tokenize=False,
-                continue_final_message=True
+                continue_final_message=True,
+                return_tensors='pt'
             )
         else:
-            # For new generation, add generation prompt
-            formatted_input = self.tokenizer.apply_chat_template(
+            # For new generation, add generation prompt and return tensors
+            prompt = self.tokenizer.apply_chat_template(
                 messages,
-                tokenize=False,
-                add_generation_prompt=True
+                add_generation_prompt=True,
+                return_tensors='pt'
            )
 
-        #
-
-
-        # For continuation (single brick), use minimal tokens for just one brick
-        if has_assistant:
-            # A complete brick like "1x4 (16,14,1)\n" needs ~15-20 tokens
-            # Use exactly enough for one brick to avoid generating multiple bricks
-            actual_tokens = 15  # Reduced from 20 to prevent multiple bricks
-        else:
-            # For initial generation, use the requested amount or reasonable default
-            actual_tokens = max(50, requested_tokens)
+        # Move to device
+        input_ids = prompt.to(self.model.device)
+        attention_mask = torch.ones_like(input_ids)
 
-        #
+        # Generate EXACTLY like local BrickGPT's generate_brick method
+        # Local BrickGPT uses max_new_tokens=10 for single brick generation
         generation_params = {
-            "max_new_tokens": actual_tokens,
+            "max_new_tokens": 10,  # EXACTLY like local BrickGPT
             "temperature": parameters.get("temperature", 0.6),
             "top_k": parameters.get("top_k", 20),
             "top_p": parameters.get("top_p", 1.0),
-            "do_sample": True,  # Always True like local LLM
-            "num_return_sequences": 1,  # Match local LLM
             "pad_token_id": self.tokenizer.pad_token_id,
+            "do_sample": True,  # EXACTLY like local LLM
+            "num_return_sequences": 1,  # EXACTLY like local LLM
             "return_dict_in_generate": True,
-            # Remove stop_strings - local LLM doesn't use them
        }
 
-        # Tokenize input
-        input_ids = self.tokenizer(formatted_input, return_tensors="pt").input_ids.to(self.model.device)
-        attention_mask = torch.ones_like(input_ids)
-
         # Generate exactly like the local LLM class
         with torch.no_grad():
             output_dict = self.model.generate(
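
Note: the key change in this hunk is that `apply_chat_template` now does the tokenization itself; with `return_tensors='pt'` it hands back token IDs directly, so the old render-to-string step plus the separate `self.tokenizer(formatted_input, ...)` call collapse into one. Below is a minimal sketch of what the two template modes produce, using `tokenize=False` purely so the rendered strings can be inspected; the checkpoint name is a placeholder, not necessarily this model's base, and `continue_final_message` needs a reasonably recent transformers release.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint

msgs = [
    {"role": "user", "content": "Build a chair."},
    {"role": "assistant", "content": "2x4 (0,0,0)\n"},  # partial answer to extend
]

# continue_final_message=True leaves the final assistant turn open (no
# end-of-turn token), so generation resumes mid-message:
print(tok.apply_chat_template(msgs, continue_final_message=True, tokenize=False))

# add_generation_prompt=True, used when no assistant turn exists yet, instead
# appends an empty assistant header for the model to start filling in:
print(tok.apply_chat_template(msgs[:1], add_generation_prompt=True, tokenize=False))
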
@@ -191,39 +160,14 @@ class EndpointHandler:
             **generation_params
         )
 
-        # Decode
+        # Decode EXACTLY like local BrickGPT does
         input_length = input_ids.shape[1]
         result_ids = output_dict['sequences'][0][input_length:]
-
-
-
-
-
-        # Remove any trailing continuation artifacts
-        if generated_text.endswith("### Output:"):
-            generated_text = generated_text[:-11].strip()
-
-        # CRITICAL FIX: Ensure single brick output for continuation
-        if has_assistant and generated_text:
-            # Split by lines and take only the first valid brick
-            lines = [line.strip() for line in generated_text.split('\n') if line.strip()]
-
-            if lines:
-                first_line = lines[0]
-                # Verify it's a complete brick format
-                if re.match(r'\d+x\d+\s*\(\d+,\d+,\d+\)$', first_line):
-                    generated_text = first_line  # NO trailing newline!
-                else:
-                    # If first line isn't complete, try to find any complete brick
-                    for line in lines:
-                        if re.match(r'\d+x\d+\s*\(\d+,\d+,\d+\)$', line):
-                            generated_text = line  # NO trailing newline!
-                            break
-                    else:
-                        # No complete brick found
-                        generated_text = ""
-
-        # Extract LEGO instructions
+
+        # Local BrickGPT uses skip_special_tokens=True in generate_brick methods
+        generated_text = self.tokenizer.decode(result_ids, skip_special_tokens=True)
+
+        # Extract LEGO instructions (same as before)
         lego_instructions = self.extract_lego_instructions(generated_text)
 
         return [{
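
Note: the deleted block enforced single-brick output for continuations by regex-validating each decoded line; the new code drops that and relies on the small token budget plus `extract_lego_instructions`. For reference, a standalone sketch of the removed filter follows (the names `BRICK_RE` and `first_complete_brick` are mine, not from handler.py).

import re

# "WxH (x,y,z)" brick lines, e.g. "1x4 (16,14,1)"; the trailing $ mirrors the
# removed code's re.match(...), which required the line to end after the tuple.
BRICK_RE = re.compile(r'\d+x\d+\s*\(\d+,\d+,\d+\)$')

def first_complete_brick(text: str) -> str:
    """Return the first complete brick line, or '' if none is found."""
    for line in (l.strip() for l in text.split('\n')):
        if line and BRICK_RE.match(line):
            return line  # no trailing newline, matching the old behavior
    return ""

assert first_complete_brick("1x4 (16,14,1)\n2x2 (0,0,0)") == "1x4 (16,14,1)"
assert first_complete_brick("1x4 (16,14,") == ""
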
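
Note: the lines that parse `messages` out of the request (lines 105-118) are untouched by this commit and not shown, so the exact payload shape is an assumption here. Below is a hedged sketch of exercising the handler locally, following the Hugging Face Inference Endpoints convention that `EndpointHandler` loads its model in `__init__` and `__call__` takes the request dict; the message contents are made up.

from handler import EndpointHandler

handler = EndpointHandler(path=".")  # constructor signature assumed from the HF convention

# Assumed payload shape: "inputs" carrying chat messages, "parameters" carrying
# sampling overrides. Including an assistant turn flips the handler into
# continuation (single-brick) mode via the has_assistant check.
payload = {
    "inputs": {
        "messages": [
            {"role": "user", "content": "Build a small chair."},
            {"role": "assistant", "content": "2x4 (0,0,0)\n"},
        ]
    },
    "parameters": {"temperature": 0.6, "top_k": 20, "top_p": 1.0},
}

result = handler(payload)  # -> [{..., "lego_instructions": ...}]
print(result)
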