Text-to-3D
Transformers
Safetensors
English
jjohnson5253 commited on
Commit
262acca
·
1 Parent(s): 0372169

what is life

Browse files
Files changed (1) hide show
  1. handler.py +24 -80
handler.py CHANGED
@@ -98,7 +98,7 @@ class EndpointHandler:
98
 
99
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
100
  """
101
- Process inference request exactly like BrickGPT does
102
  """
103
  inputs = data.pop("inputs", data)
104
  parameters = data.pop("parameters", {})
@@ -119,70 +119,39 @@ class EndpointHandler:
119
  # Check if this is a continuation (has assistant message)
120
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
121
 
122
- # Extract the actual user instruction
123
- user_content = ""
124
- for msg in messages:
125
- if msg.get("role") == "user":
126
- content = msg["content"]
127
- if "### Input:" in content:
128
- user_content = content.split("### Input:")[-1].strip()
129
- else:
130
- user_content = content
131
- break
132
-
133
- # Create the proper instruction format (use few_shot for better results)
134
- if not has_assistant:
135
- instruction = self.create_instruction_few_shot(user_content)
136
- messages = [
137
- {"role": "system", "content": "You are a helpful assistant."},
138
- {"role": "user", "content": instruction}
139
- ]
140
-
141
- # Format input using chat template exactly like BrickGPT
142
  if has_assistant:
143
- # For continuation, use continue_final_message=True
144
- formatted_input = self.tokenizer.apply_chat_template(
145
  messages,
146
- tokenize=False,
147
- continue_final_message=True
148
  )
149
  else:
150
- # For new generation, add generation prompt
151
- formatted_input = self.tokenizer.apply_chat_template(
152
  messages,
153
- tokenize=False,
154
- add_generation_prompt=True
155
  )
156
 
157
- # FIXED: Always use sufficient tokens, ignore small requests
158
- requested_tokens = parameters.get("max_new_tokens", 100)
159
-
160
- # For continuation (single brick), use minimal tokens for just one brick
161
- if has_assistant:
162
- # A complete brick like "1x4 (16,14,1)\n" needs ~15-20 tokens
163
- # Use exactly enough for one brick to avoid generating multiple bricks
164
- actual_tokens = 15 # Reduced from 20 to prevent multiple bricks
165
- else:
166
- # For initial generation, use the requested amount or reasonable default
167
- actual_tokens = max(50, requested_tokens)
168
 
169
- # MATCH LOCAL LLM BEHAVIOR EXACTLY
 
170
  generation_params = {
171
- "max_new_tokens": actual_tokens,
172
  "temperature": parameters.get("temperature", 0.6),
173
  "top_k": parameters.get("top_k", 20),
174
  "top_p": parameters.get("top_p", 1.0),
175
- "do_sample": True, # Always True like local LLM
176
- "num_return_sequences": 1, # Match local LLM
177
  "pad_token_id": self.tokenizer.pad_token_id,
 
 
178
  "return_dict_in_generate": True,
179
- # Remove stop_strings - local LLM doesn't use them
180
  }
181
 
182
- # Tokenize input
183
- input_ids = self.tokenizer(formatted_input, return_tensors="pt").input_ids.to(self.model.device)
184
- attention_mask = torch.ones_like(input_ids)
185
-
186
  # Generate exactly like the local LLM class
187
  with torch.no_grad():
188
  output_dict = self.model.generate(
@@ -191,39 +160,14 @@ class EndpointHandler:
191
  **generation_params
192
  )
193
 
194
- # Decode exactly like the local LLM class
195
  input_length = input_ids.shape[1]
196
  result_ids = output_dict['sequences'][0][input_length:]
197
- generated_text = self.tokenizer.decode(result_ids) # No skip_special_tokens like local
198
-
199
- # Clean up the generated text
200
- generated_text = generated_text.strip()
201
-
202
- # Remove any trailing continuation artifacts
203
- if generated_text.endswith("### Output:"):
204
- generated_text = generated_text[:-11].strip()
205
-
206
- # CRITICAL FIX: Ensure single brick output for continuation
207
- if has_assistant and generated_text:
208
- # Split by lines and take only the first valid brick
209
- lines = [line.strip() for line in generated_text.split('\n') if line.strip()]
210
-
211
- if lines:
212
- first_line = lines[0]
213
- # Verify it's a complete brick format
214
- if re.match(r'\d+x\d+\s*\(\d+,\d+,\d+\)$', first_line):
215
- generated_text = first_line # NO trailing newline!
216
- else:
217
- # If first line isn't complete, try to find any complete brick
218
- for line in lines:
219
- if re.match(r'\d+x\d+\s*\(\d+,\d+,\d+\)$', line):
220
- generated_text = line # NO trailing newline!
221
- break
222
- else:
223
- # No complete brick found
224
- generated_text = ""
225
-
226
- # Extract LEGO instructions
227
  lego_instructions = self.extract_lego_instructions(generated_text)
228
 
229
  return [{
 
98
 
99
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
100
  """
101
+ Process inference request EXACTLY like local BrickGPT does
102
  """
103
  inputs = data.pop("inputs", data)
104
  parameters = data.pop("parameters", {})
 
119
  # Check if this is a continuation (has assistant message)
120
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
121
 
122
+ # Format prompt EXACTLY like local BrickGPT does
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  if has_assistant:
124
+ # For continuation, use continue_final_message=True and return tensors
125
+ prompt = self.tokenizer.apply_chat_template(
126
  messages,
127
+ continue_final_message=True,
128
+ return_tensors='pt'
129
  )
130
  else:
131
+ # For new generation, add generation prompt and return tensors
132
+ prompt = self.tokenizer.apply_chat_template(
133
  messages,
134
+ add_generation_prompt=True,
135
+ return_tensors='pt'
136
  )
137
 
138
+ # Move to device
139
+ input_ids = prompt.to(self.model.device)
140
+ attention_mask = torch.ones_like(input_ids)
 
 
 
 
 
 
 
 
141
 
142
+ # Generate EXACTLY like local BrickGPT's generate_brick method
143
+ # Local BrickGPT uses max_new_tokens=10 for single brick generation
144
  generation_params = {
145
+ "max_new_tokens": 10, # EXACTLY like local BrickGPT
146
  "temperature": parameters.get("temperature", 0.6),
147
  "top_k": parameters.get("top_k", 20),
148
  "top_p": parameters.get("top_p", 1.0),
 
 
149
  "pad_token_id": self.tokenizer.pad_token_id,
150
+ "do_sample": True, # EXACTLY like local LLM
151
+ "num_return_sequences": 1, # EXACTLY like local LLM
152
  "return_dict_in_generate": True,
 
153
  }
154
 
 
 
 
 
155
  # Generate exactly like the local LLM class
156
  with torch.no_grad():
157
  output_dict = self.model.generate(
 
160
  **generation_params
161
  )
162
 
163
+ # Decode EXACTLY like local BrickGPT does
164
  input_length = input_ids.shape[1]
165
  result_ids = output_dict['sequences'][0][input_length:]
166
+
167
+ # Local BrickGPT uses skip_special_tokens=True in generate_brick methods
168
+ generated_text = self.tokenizer.decode(result_ids, skip_special_tokens=True)
169
+
170
+ # Extract LEGO instructions (same as before)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  lego_instructions = self.extract_lego_instructions(generated_text)
172
 
173
  return [{