Text-to-3D
Transformers
Safetensors
English
jjohnson5253 committed on
Commit
4c4a40c
·
1 Parent(s): 0be15f2

add fewshots

Browse files
Files changed (1) hide show
  1. handler.py +111 -46
handler.py CHANGED
@@ -2,6 +2,8 @@ from typing import Dict, List, Any
2
  import torch
3
  import re
4
  import os
 
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
 
@@ -25,6 +27,56 @@ class EndpointHandler:
25
  # Set pad token if not exists
26
  if self.tokenizer.pad_token is None:
27
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def extract_lego_instructions(self, text: str) -> List[str]:
30
  """Extract LEGO brick instructions from generated text"""
@@ -46,89 +98,102 @@ class EndpointHandler:
46
 
47
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
48
  """
49
- Process inference request
50
- data args:
51
- inputs (:obj:`str` or :obj:`Dict`): The input text or messages
52
- parameters (:obj:`Dict`, optional): Parameters for generation
53
  """
54
  inputs = data.pop("inputs", data)
55
  parameters = data.pop("parameters", {})
56
 
57
  # Handle different input formats that BrickGPT sends
58
  if isinstance(inputs, dict) and "messages" in inputs:
59
- # BrickGPT format: {"messages": [{"role": "system", ...}, {"role": "user", ...}]}
60
  messages = inputs["messages"]
61
  elif isinstance(inputs, list):
62
- # Direct messages array: [{"role": "system", ...}, {"role": "user", ...}]
63
  messages = inputs
64
  elif isinstance(inputs, str):
65
- # Plain string input - create default messages
66
  messages = [
67
  {"role": "system", "content": "You are a helpful assistant."},
68
  {"role": "user", "content": inputs}
69
  ]
70
  else:
71
- # Fallback
72
  messages = [{"role": "user", "content": str(inputs)}]
73
 
74
  # Check if this is a continuation (has assistant message)
75
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
76
 
77
- # Format input using chat template
78
- if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
79
- if has_assistant:
80
- # For continuation, use continue_final_message=True
81
- formatted_input = self.tokenizer.apply_chat_template(
82
- messages,
83
- tokenize=False,
84
- continue_final_message=True
85
- )
86
- else:
87
- # For new generation, add generation prompt
88
- formatted_input = self.tokenizer.apply_chat_template(
89
- messages,
90
- tokenize=False,
91
- add_generation_prompt=True
92
- )
 
 
 
 
 
 
 
 
 
 
 
93
  else:
94
- # Fallback to direct input (last user message)
95
- user_messages = [msg["content"] for msg in messages if msg.get("role") == "user"]
96
- formatted_input = user_messages[-1] if user_messages else str(inputs)
 
 
 
97
 
98
- # Default parameters optimized for BrickGPT
99
- # BrickGPT generates SHORT responses (one brick at a time), not long descriptions
100
- default_max_tokens = 20 if has_assistant else 200 # Much shorter for continuation
101
  generation_params = {
102
  "max_new_tokens": parameters.get("max_new_tokens", default_max_tokens),
103
- "temperature": parameters.get("temperature", 0.6),
104
- "top_k": parameters.get("top_k", 20),
105
- "top_p": parameters.get("top_p", 1.0),
106
  "do_sample": parameters.get("do_sample", True),
107
  "pad_token_id": self.tokenizer.pad_token_id,
 
108
  }
109
 
110
- # Add stop tokens if provided
111
- stop_tokens = parameters.get("stop", [])
112
- if stop_tokens:
113
- generation_params["stop_strings"] = stop_tokens
114
-
115
  # Tokenize input
116
  input_ids = self.tokenizer(formatted_input, return_tensors="pt").input_ids.to(self.model.device)
 
117
 
118
- # Generate
119
  with torch.no_grad():
120
- outputs = self.model.generate(input_ids, **generation_params)
 
 
 
 
121
 
122
- # Decode output
123
- generated_text = self.tokenizer.decode(
124
- outputs[0][input_ids.shape[1]:],
125
- skip_special_tokens=True
126
- )
 
 
 
 
 
 
127
 
128
  # Extract LEGO instructions
129
  lego_instructions = self.extract_lego_instructions(generated_text)
130
 
131
  return [{
132
  "generated_text": generated_text,
133
- "lego_instructions": lego_instructions # Fixed the field name
134
  }]
 
2
  import torch
3
  import re
4
  import os
5
+ import json
6
+ from pathlib import Path
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
 
 
27
  # Set pad token if not exists
28
  if self.tokenizer.pad_token is None:
29
  self.tokenizer.pad_token = self.tokenizer.eos_token
30
+
31
+ # Load few-shot examples (same as in BrickGPT)
32
+ self.few_shot_examples = [
33
+ {
34
+ "caption": "Bed with rectangular base and straight headboard.",
35
+ "bricks": "1x2 (13,18,0)\n1x2 (13,2,0)\n2x2 (0,18,0)\n2x2 (0,2,0)\n2x6 (12,14,1)\n2x6 (12,8,1)"
36
+ },
37
+ {
38
+ "caption": "Simple chair with straight backrest and square seat.",
39
+ "bricks": "2x2 (5,18,0)\n2x2 (5,13,0)\n2x2 (0,18,0)\n2x2 (0,13,0)\n2x2 (5,18,1)\n2x2 (5,13,1)"
40
+ },
41
+ {
42
+ "caption": "Square table with four legs and a flat surface.",
43
+ "bricks": "2x2 (16,18,0)\n2x2 (16,8,0)\n1x1 (15,18,0)\n1x1 (15,9,0)\n2x2 (0,18,0)\n2x2 (0,8,0)"
44
+ }
45
+ ]
46
+
47
+ def create_instruction(self, caption: str) -> str:
48
+ """Create instruction exactly like BrickGPT does"""
49
+ instruction = ('Create a LEGO model of the input. Format your response as a list of bricks: '
50
+ '<brick dimensions> <brick position>, where the brick position is (x,y,z).\n'
51
+ 'Allowed brick dimensions are 2x4, 4x2, 2x6, 6x2, 1x2, 2x1, 1x4, 4x1, 1x6, 6x1, 1x8, 8x1, 1x1, 2x2.\n'
52
+ 'All bricks are 1 unit tall.\n\n'
53
+ '### Input:\n'
54
+ f'{caption}')
55
+ return instruction
56
+
57
+ def create_instruction_few_shot(self, caption: str) -> str:
58
+ """Create few-shot instruction exactly like BrickGPT does"""
59
+ base_instruction = self.create_instruction(caption)
60
+ zero_shot_instructions = (
61
+ 'Each line of your output should be a LEGO brick in the format `<brick dimensions> <brick position>`. For example:\n'
62
+ '2x4 (2,1,0)\n'
63
+ 'DO NOT output any other text. Only output LEGO bricks. The first brick should have a z-coordinate of 0.'
64
+ )
65
+
66
+ example_prompt = 'Here are some example LEGO models:'
67
+ example_instructions = '\n\n'.join(self._create_example_instruction(x) for x in self.few_shot_examples)
68
+ few_shot_instructions = (
69
+ 'Do NOT copy the examples, but create your own LEGO model for the following input.\n\n'
70
+ '### Input:\n'
71
+ f'{caption}\n\n'
72
+ '### Output:\n'
73
+ )
74
+
75
+ return '\n\n'.join([base_instruction, zero_shot_instructions, example_prompt,
76
+ example_instructions, few_shot_instructions])
77
+
78
+ def _create_example_instruction(self, x: dict) -> str:
79
+ return f'### Input:\n{x["caption"]}\n\n### Output:\n{x["bricks"]}'
80
 
81
  def extract_lego_instructions(self, text: str) -> List[str]:
82
  """Extract LEGO brick instructions from generated text"""
 
98
 
99
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
100
  """
101
+ Process inference request exactly like BrickGPT does
 
 
 
102
  """
103
  inputs = data.pop("inputs", data)
104
  parameters = data.pop("parameters", {})
105
 
106
  # Handle different input formats that BrickGPT sends
107
  if isinstance(inputs, dict) and "messages" in inputs:
 
108
  messages = inputs["messages"]
109
  elif isinstance(inputs, list):
 
110
  messages = inputs
111
  elif isinstance(inputs, str):
 
112
  messages = [
113
  {"role": "system", "content": "You are a helpful assistant."},
114
  {"role": "user", "content": inputs}
115
  ]
116
  else:
 
117
  messages = [{"role": "user", "content": str(inputs)}]
118
 
119
  # Check if this is a continuation (has assistant message)
120
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
121
 
122
+ # Extract the actual user instruction
123
+ user_content = ""
124
+ for msg in messages:
125
+ if msg.get("role") == "user":
126
+ content = msg["content"]
127
+ if "### Input:" in content:
128
+ user_content = content.split("### Input:")[-1].strip()
129
+ else:
130
+ user_content = content
131
+ break
132
+
133
+ # Create the proper instruction format (use few_shot for better results)
134
+ if not has_assistant:
135
+ instruction = self.create_instruction_few_shot(user_content)
136
+ messages = [
137
+ {"role": "system", "content": "You are a helpful assistant."},
138
+ {"role": "user", "content": instruction}
139
+ ]
140
+
141
+ # Format input using chat template exactly like BrickGPT
142
+ if has_assistant:
143
+ # For continuation, use continue_final_message=True
144
+ formatted_input = self.tokenizer.apply_chat_template(
145
+ messages,
146
+ tokenize=False,
147
+ continue_final_message=True
148
+ )
149
  else:
150
+ # For new generation, add generation prompt
151
+ formatted_input = self.tokenizer.apply_chat_template(
152
+ messages,
153
+ tokenize=False,
154
+ add_generation_prompt=True
155
+ )
156
 
157
+ # Generation parameters that match BrickGPT's approach
158
+ default_max_tokens = 15 if has_assistant else 50 # BrickGPT generates one brick at a time
 
159
  generation_params = {
160
  "max_new_tokens": parameters.get("max_new_tokens", default_max_tokens),
161
+ "temperature": parameters.get("temperature", 0.6), # BrickGPT default
162
+ "top_k": parameters.get("top_k", 20), # BrickGPT default
163
+ "top_p": parameters.get("top_p", 1.0), # BrickGPT default
164
  "do_sample": parameters.get("do_sample", True),
165
  "pad_token_id": self.tokenizer.pad_token_id,
166
+ "return_dict_in_generate": True
167
  }
168
 
 
 
 
 
 
169
  # Tokenize input
170
  input_ids = self.tokenizer(formatted_input, return_tensors="pt").input_ids.to(self.model.device)
171
+ attention_mask = torch.ones_like(input_ids)
172
 
173
+ # Generate exactly like the local LLM class
174
  with torch.no_grad():
175
+ output_dict = self.model.generate(
176
+ input_ids,
177
+ attention_mask=attention_mask,
178
+ **generation_params
179
+ )
180
 
181
+ # Decode exactly like the local LLM class
182
+ input_length = input_ids.shape[1]
183
+ result_ids = output_dict['sequences'][0][input_length:]
184
+ generated_text = self.tokenizer.decode(result_ids, skip_special_tokens=True)
185
+
186
+ # Clean up the generated text
187
+ generated_text = generated_text.strip()
188
+
189
+ # Remove any trailing continuation artifacts
190
+ if generated_text.endswith("### Output:"):
191
+ generated_text = generated_text[:-11].strip()
192
 
193
  # Extract LEGO instructions
194
  lego_instructions = self.extract_lego_instructions(generated_text)
195
 
196
  return [{
197
  "generated_text": generated_text,
198
+ "lego_instructions": lego_instructions
199
  }]