Text-to-3D · Transformers · Safetensors · English
Commit 837e58c · jjohnson5253 committed · 1 parent: c7241e5

update handler

Files changed (1):
  1. handler.py +47 -13
handler.py CHANGED

```diff
@@ -48,34 +48,68 @@ class EndpointHandler:
         """
         Process inference request
         data args:
-            inputs (:obj:`str`): The input text
+            inputs (:obj:`str` or :obj:`Dict`): The input text or messages
             parameters (:obj:`Dict`, optional): Parameters for generation
         """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
-        # Format input using chat template for Llama
+        # Handle different input formats that BrickGPT sends
+        if isinstance(inputs, dict) and "messages" in inputs:
+            # BrickGPT format: {"messages": [{"role": "system", ...}, {"role": "user", ...}]}
+            messages = inputs["messages"]
+        elif isinstance(inputs, list):
+            # Direct messages array: [{"role": "system", ...}, {"role": "user", ...}]
+            messages = inputs
+        elif isinstance(inputs, str):
+            # Plain string input - create default messages
+            messages = [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": inputs}
+            ]
+        else:
+            # Fallback
+            messages = [{"role": "user", "content": str(inputs)}]
+
+        # Format input using chat template
         if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
-            # Use chat template if available
-            messages = [{"role": "user", "content": inputs}]
-            formatted_input = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
+            # Check if this is a continuation (has assistant message)
+            has_assistant = any(msg.get("role") == "assistant" for msg in messages)
+
+            if has_assistant:
+                # For continuation, use continue_final_message=True
+                formatted_input = self.tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    continue_final_message=True
+                )
+            else:
+                # For new generation, add generation prompt
+                formatted_input = self.tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True
+                )
         else:
-            # Fallback to direct input
-            formatted_input = inputs
+            # Fallback to direct input (last user message)
+            user_messages = [msg["content"] for msg in messages if msg.get("role") == "user"]
+            formatted_input = user_messages[-1] if user_messages else str(inputs)
 
         # Default parameters optimized for BrickGPT
         generation_params = {
             "max_new_tokens": parameters.get("max_new_tokens", 512),
             "temperature": parameters.get("temperature", 0.6),
-            "top_p": parameters.get("top_p", 0.9),
-            "do_sample": True,
+            "top_k": parameters.get("top_k", 20),
+            "top_p": parameters.get("top_p", 1.0),
+            "do_sample": parameters.get("do_sample", True),
             "pad_token_id": self.tokenizer.pad_token_id,
         }
 
+        # Add stop tokens if provided
+        stop_tokens = parameters.get("stop", [])
+        if stop_tokens:
+            generation_params["stop_strings"] = stop_tokens
+
         # Tokenize input
         input_ids = self.tokenizer(formatted_input, return_tensors="pt").input_ids.to(self.model.device)
 
```
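For context on the new input handling: the handler now routes three request shapes. A quick sketch of what each looks like on the wire (payload contents here are illustrative, not taken from BrickGPT itself):

```python
# Illustrative request payloads for the updated handler; field values are made up.

# 1. BrickGPT-style dict with a "messages" key.
payload_dict = {
    "inputs": {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Build a small chair."},
        ]
    },
    "parameters": {"max_new_tokens": 512, "temperature": 0.6},
}

# 2. Bare messages list. The trailing assistant turn makes has_assistant True,
#    so the handler formats it with continue_final_message=True and the model
#    resumes the partial reply instead of starting a fresh one.
payload_continuation = {
    "inputs": [
        {"role": "user", "content": "Build a small chair."},
        {"role": "assistant", "content": "1x2 brick at (0,0,0)\n"},
    ],
    "parameters": {"stop": ["\n\n"]},
}

# 3. Plain string, which the handler wraps in default system/user messages.
payload_plain = {"inputs": "Build a small chair."}
```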
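The continuation branch relies on `apply_chat_template`'s `continue_final_message` flag, which renders the conversation but leaves the final assistant turn open rather than closing it with an end-of-turn marker. A self-contained sketch of the difference (TinyLlama is only a stand-in with a chat template; the diff does not name the model behind this handler):

```python
from transformers import AutoTokenizer

# Stand-in tokenizer for illustration; any tokenizer with a chat template works.
tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

msgs = [
    {"role": "user", "content": "Build a chair."},
    {"role": "assistant", "content": "1x2 brick at"},
]

# Leaves the assistant turn open: the rendered string ends right after
# "1x2 brick at", so generation continues the partial reply.
print(tok.apply_chat_template(msgs, tokenize=False, continue_final_message=True))

# Closes the user turn and appends a fresh assistant header instead.
print(tok.apply_chat_template(msgs[:1], tokenize=False, add_generation_prompt=True))
```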
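One caveat on the `stop_strings` addition: `transformers` rejects `stop_strings` at generation time unless a tokenizer is also passed to `generate()`, because the strings are matched against decoded text. The diff does not show this handler's `generate()` call, so whether it passes one is an assumption; a minimal standalone sketch of the requirement (gpt2 as a stand-in model):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is a stand-in purely to demonstrate the API requirement.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tok("A brick house has", return_tensors="pt").input_ids
out = model.generate(
    ids,
    max_new_tokens=32,
    stop_strings=["."],   # halt as soon as this string appears in the output
    tokenizer=tok,        # required by transformers whenever stop_strings is set
)
print(tok.decode(out[0], skip_special_tokens=True))
```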