Text-to-3D
Transformers
Safetensors
English
jjohnson5253 commited on
Commit
73a8f69
·
1 Parent(s): 262acca

fix me up dady

Browse files
Files changed (1) hide show
  1. handler.py +16 -96
handler.py CHANGED
@@ -1,9 +1,6 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- import re
4
  import os
5
- import json
6
- from pathlib import Path
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
 
@@ -27,108 +24,37 @@ class EndpointHandler:
27
  # Set pad token if not exists
28
  if self.tokenizer.pad_token is None:
29
  self.tokenizer.pad_token = self.tokenizer.eos_token
30
-
31
- # Load few-shot examples (same as in BrickGPT)
32
- self.few_shot_examples = [
33
- {
34
- "caption": "Bed with rectangular base and straight headboard.",
35
- "bricks": "1x2 (13,18,0)\n1x2 (13,2,0)\n2x2 (0,18,0)\n2x2 (0,2,0)\n2x6 (12,14,1)\n2x6 (12,8,1)"
36
- },
37
- {
38
- "caption": "Simple chair with straight backrest and square seat.",
39
- "bricks": "2x2 (5,18,0)\n2x2 (5,13,0)\n2x2 (0,18,0)\n2x2 (0,13,0)\n2x2 (5,18,1)\n2x2 (5,13,1)"
40
- },
41
- {
42
- "caption": "Square table with four legs and a flat surface.",
43
- "bricks": "2x2 (16,18,0)\n2x2 (16,8,0)\n1x1 (15,18,0)\n1x1 (15,9,0)\n2x2 (0,18,0)\n2x2 (0,8,0)"
44
- }
45
- ]
46
-
47
- def create_instruction(self, caption: str) -> str:
48
- """Create instruction exactly like BrickGPT does"""
49
- instruction = ('Create a LEGO model of the input. Format your response as a list of bricks: '
50
- '<brick dimensions> <brick position>, where the brick position is (x,y,z).\n'
51
- 'Allowed brick dimensions are 2x4, 4x2, 2x6, 6x2, 1x2, 2x1, 1x4, 4x1, 1x6, 6x1, 1x8, 8x1, 1x1, 2x2.\n'
52
- 'All bricks are 1 unit tall.\n\n'
53
- '### Input:\n'
54
- f'{caption}')
55
- return instruction
56
-
57
- def create_instruction_few_shot(self, caption: str) -> str:
58
- """Create few-shot instruction exactly like BrickGPT does"""
59
- base_instruction = self.create_instruction(caption)
60
- zero_shot_instructions = (
61
- 'Each line of your output should be a LEGO brick in the format `<brick dimensions> <brick position>`. For example:\n'
62
- '2x4 (2,1,0)\n'
63
- 'DO NOT output any other text. Only output LEGO bricks. The first brick should have a z-coordinate of 0.'
64
- )
65
-
66
- example_prompt = 'Here are some example LEGO models:'
67
- example_instructions = '\n\n'.join(self._create_example_instruction(x) for x in self.few_shot_examples)
68
- few_shot_instructions = (
69
- 'Do NOT copy the examples, but create your own LEGO model for the following input.\n\n'
70
- '### Input:\n'
71
- f'{caption}\n\n'
72
- '### Output:\n'
73
- )
74
-
75
- return '\n\n'.join([base_instruction, zero_shot_instructions, example_prompt,
76
- example_instructions, few_shot_instructions])
77
-
78
- def _create_example_instruction(self, x: dict) -> str:
79
- return f'### Input:\n{x["caption"]}\n\n### Output:\n{x["bricks"]}'
80
-
81
- def extract_lego_instructions(self, text: str) -> List[str]:
82
- """Extract LEGO brick instructions from generated text"""
83
- instructions = []
84
- lines = text.split('\n')
85
-
86
- for line in lines:
87
- line = line.strip()
88
- if not line:
89
- continue
90
-
91
- # Look for BrickGPT format: "NxM (x,y,z)"
92
- brick_pattern = r'(\d+x\d+)\s*\((\d+),(\d+),(\d+)\)'
93
- match = re.search(brick_pattern, line)
94
- if match:
95
- instructions.append(line)
96
-
97
- return instructions
98
 
99
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
100
  """
101
- Process inference request EXACTLY like local BrickGPT does
102
  """
103
  inputs = data.pop("inputs", data)
104
  parameters = data.pop("parameters", {})
105
 
106
- # Handle different input formats that BrickGPT sends
107
  if isinstance(inputs, dict) and "messages" in inputs:
108
  messages = inputs["messages"]
109
  elif isinstance(inputs, list):
110
  messages = inputs
111
- elif isinstance(inputs, str):
 
112
  messages = [
113
  {"role": "system", "content": "You are a helpful assistant."},
114
- {"role": "user", "content": inputs}
115
  ]
116
- else:
117
- messages = [{"role": "user", "content": str(inputs)}]
118
 
119
  # Check if this is a continuation (has assistant message)
120
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
121
 
122
- # Format prompt EXACTLY like local BrickGPT does
123
  if has_assistant:
124
- # For continuation, use continue_final_message=True and return tensors
125
  prompt = self.tokenizer.apply_chat_template(
126
  messages,
127
  continue_final_message=True,
128
  return_tensors='pt'
129
  )
130
  else:
131
- # For new generation, add generation prompt and return tensors
132
  prompt = self.tokenizer.apply_chat_template(
133
  messages,
134
  add_generation_prompt=True,
@@ -139,20 +65,19 @@ class EndpointHandler:
139
  input_ids = prompt.to(self.model.device)
140
  attention_mask = torch.ones_like(input_ids)
141
 
142
- # Generate EXACTLY like local BrickGPT's generate_brick method
143
- # Local BrickGPT uses max_new_tokens=10 for single brick generation
144
  generation_params = {
145
- "max_new_tokens": 10, # EXACTLY like local BrickGPT
146
  "temperature": parameters.get("temperature", 0.6),
147
  "top_k": parameters.get("top_k", 20),
148
  "top_p": parameters.get("top_p", 1.0),
149
  "pad_token_id": self.tokenizer.pad_token_id,
150
- "do_sample": True, # EXACTLY like local LLM
151
- "num_return_sequences": 1, # EXACTLY like local LLM
152
  "return_dict_in_generate": True,
153
  }
154
 
155
- # Generate exactly like the local LLM class
156
  with torch.no_grad():
157
  output_dict = self.model.generate(
158
  input_ids,
@@ -160,17 +85,12 @@ class EndpointHandler:
160
  **generation_params
161
  )
162
 
163
- # Decode EXACTLY like local BrickGPT does
164
  input_length = input_ids.shape[1]
165
  result_ids = output_dict['sequences'][0][input_length:]
166
 
167
- # Local BrickGPT uses skip_special_tokens=True in generate_brick methods
168
- generated_text = self.tokenizer.decode(result_ids, skip_special_tokens=True)
169
-
170
- # Extract LEGO instructions (same as before)
171
- lego_instructions = self.extract_lego_instructions(generated_text)
172
 
173
- return [{
174
- "generated_text": generated_text,
175
- "lego_instructions": lego_instructions
176
- }]
 
1
  from typing import Dict, List, Any
2
  import torch
 
3
  import os
 
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
 
24
  # Set pad token if not exists
25
  if self.tokenizer.pad_token is None:
26
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
29
  """
30
+ Simple handler that mimics local LLM behavior for RemoteLLM
31
  """
32
  inputs = data.pop("inputs", data)
33
  parameters = data.pop("parameters", {})
34
 
35
+ # Handle different input formats that RemoteLLM sends
36
  if isinstance(inputs, dict) and "messages" in inputs:
37
  messages = inputs["messages"]
38
  elif isinstance(inputs, list):
39
  messages = inputs
40
+ else:
41
+ # Fallback - treat as direct text
42
  messages = [
43
  {"role": "system", "content": "You are a helpful assistant."},
44
+ {"role": "user", "content": str(inputs)}
45
  ]
 
 
46
 
47
  # Check if this is a continuation (has assistant message)
48
  has_assistant = any(msg.get("role") == "assistant" for msg in messages)
49
 
50
+ # Apply chat template exactly like BrickGPT does locally
51
  if has_assistant:
 
52
  prompt = self.tokenizer.apply_chat_template(
53
  messages,
54
  continue_final_message=True,
55
  return_tensors='pt'
56
  )
57
  else:
 
58
  prompt = self.tokenizer.apply_chat_template(
59
  messages,
60
  add_generation_prompt=True,
 
65
  input_ids = prompt.to(self.model.device)
66
  attention_mask = torch.ones_like(input_ids)
67
 
68
+ # Generation parameters - use BrickGPT defaults
 
69
  generation_params = {
70
+ "max_new_tokens": parameters.get("max_new_tokens", 10),
71
  "temperature": parameters.get("temperature", 0.6),
72
  "top_k": parameters.get("top_k", 20),
73
  "top_p": parameters.get("top_p", 1.0),
74
  "pad_token_id": self.tokenizer.pad_token_id,
75
+ "do_sample": True,
76
+ "num_return_sequences": 1,
77
  "return_dict_in_generate": True,
78
  }
79
 
80
+ # Generate
81
  with torch.no_grad():
82
  output_dict = self.model.generate(
83
  input_ids,
 
85
  **generation_params
86
  )
87
 
88
+ # Extract new tokens and decode EXACTLY like local LLM
89
  input_length = input_ids.shape[1]
90
  result_ids = output_dict['sequences'][0][input_length:]
91
 
92
+ # CRITICAL: Decode exactly like local LLM (no skip_special_tokens parameter)
93
+ generated_text = self.tokenizer.decode(result_ids)
 
 
 
94
 
95
+ # Return in format RemoteLLM expects
96
+ return [{"generated_text": generated_text}]