david167 committed on
Commit
0cdc4eb
·
1 Parent(s): 14f445d

ELEGANT API REWRITE: Clean architecture, smart token allocation, proper JSON extraction - eliminate placeholder generation

Browse files
Files changed (2) hide show
  1. gradio_app.py +107 -215
  2. gradio_app_old.py +322 -0
gradio_app.py CHANGED
@@ -3,6 +3,8 @@ import logging
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import gradio as gr
 
 
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
@@ -49,34 +51,51 @@ class ModelManager:
49
  self.model = AutoModelForCausalLM.from_pretrained(
50
  base_model_name,
51
  torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
52
- device_map={"": 0} if self.device == "cuda:0" else None,
53
  trust_remote_code=True,
54
- low_cpu_mem_usage=True,
55
- use_safetensors=True,
56
  token=hf_token
57
  )
58
 
59
- if self.device == "cuda:0":
60
- self.model = self.model.to(self.device)
 
61
 
62
  self.model_loaded = True
63
- logger.info("Model loaded successfully!")
64
 
65
  except Exception as e:
66
- logger.error(f"Error loading model: {str(e)}")
67
  self.model_loaded = False
68
 
69
- # Initialize model manager
70
- model_manager = ModelManager()
 
 
71
 
72
- def generate_response(prompt, temperature=0.8):
73
- """Simple function to generate a response from a prompt"""
74
- if not model_manager.model_loaded:
75
- return "Model not loaded yet. Please wait..."
76
-
77
  try:
78
- # Create the Llama-3.1 chat format
79
- formatted_prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  {prompt}
82
 
@@ -84,239 +103,112 @@ def generate_response(prompt, temperature=0.8):
84
 
85
  """
86
 
87
- # Determine context window and USE ABSOLUTE MAXIMUM
88
- try:
89
- max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072) # Llama 3.1 supports up to 131k
90
- except Exception:
91
- max_ctx = 131072 # Use maximum possible
92
-
93
- logger.info(f"Model max context: {max_ctx} tokens")
94
-
95
- # Detect if this is a Chain of Thinking request
96
- is_cot_request = ("chain-of-thinking" in prompt.lower() or
97
- "chain of thinking" in prompt.lower() or
98
- "Return exactly this JSON array" in prompt or
99
- ("verbatim" in prompt.lower() and "json array" in prompt.lower()))
100
-
101
- # MAXIMIZE GENERATION TOKENS - use most of context for generation
102
  if is_cot_request:
103
- # For CoT, use MAXIMUM possible generation tokens
104
- gen_max_new_tokens = 16384 # Very high limit for complete responses
105
- min_tokens = 2000 # High minimum to force complete generation
106
- # Allow most of context for input
107
- allowed_input_tokens = max_ctx - gen_max_new_tokens - 100 # Small safety buffer
108
- logger.info(f"CoT REQUEST - MAXIMIZED: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")
109
  else:
110
- # Standard requests
111
- gen_max_new_tokens = 8192
112
- min_tokens = 200
113
- allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
 
 
 
114
 
115
- # Tokenize the input with safe truncation
116
  inputs = model_manager.tokenizer(
117
  formatted_prompt,
118
  return_tensors="pt",
119
  truncation=True,
120
- max_length=allowed_input_tokens
121
  )
122
 
123
- # Move inputs to the same device as the model
124
  if model_manager.device == "cuda:0":
125
- model_device = next(model_manager.model.parameters()).device
126
- inputs = {k: v.to(model_device) for k, v in inputs.items()}
127
 
128
- # Generate response with MAXIMUM settings
129
  with torch.no_grad():
130
  outputs = model_manager.model.generate(
131
  **inputs,
132
- max_new_tokens=gen_max_new_tokens,
133
- min_new_tokens=min_tokens,
134
  temperature=temperature,
135
- top_p=0.95,
136
  do_sample=True,
137
- num_beams=1,
138
  pad_token_id=model_manager.tokenizer.eos_token_id,
139
- eos_token_id=model_manager.tokenizer.eos_token_id,
140
- early_stopping=False, # Never stop early
141
- repetition_penalty=1.05,
142
- no_repeat_ngram_size=0,
143
- length_penalty=1.0,
144
- # Force generation to continue
145
- use_cache=True
146
  )
147
 
148
- # Decode the response
149
- generated_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
150
-
151
- # Log generation details for debugging
152
- input_length = inputs['input_ids'].shape[1]
153
- output_length = outputs[0].shape[0]
154
- generated_length = output_length - input_length
155
- logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
156
 
157
- if generated_length < min_tokens:
158
- logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")
159
-
160
- # Post-decode guard: if a top-level JSON array closes, trim to the first full array
161
- # This helps prevent trailing prose like 'assistant' or 'Message'.
162
- try:
163
- # Track both bracket and brace depth to find first complete JSON structure
164
- bracket_depth = 0 # [ ]
165
- brace_depth = 0 # { }
166
- in_string = False
167
- escape_next = False
168
- start_idx = None
169
- end_idx = None
170
-
171
- for i, ch in enumerate(generated_text):
172
- # Handle string escaping
173
- if escape_next:
174
- escape_next = False
175
- continue
176
-
177
- if ch == '\\':
178
- escape_next = True
179
- continue
180
-
181
- # Track if we're inside a string
182
- if ch == '"' and not escape_next:
183
- in_string = not in_string
184
- continue
185
-
186
- # Only count brackets/braces outside of strings
187
- if not in_string:
188
- if ch == '[':
189
- if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
190
- start_idx = i
191
- bracket_depth += 1
192
- elif ch == ']':
193
- bracket_depth = max(0, bracket_depth - 1)
194
- if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
195
- end_idx = i
196
- break
197
- elif ch == '{':
198
- brace_depth += 1
199
- elif ch == '}':
200
- brace_depth = max(0, brace_depth - 1)
201
-
202
- if start_idx is not None and end_idx is not None and end_idx > start_idx:
203
- # Extract just the complete JSON array
204
- json_text = generated_text[start_idx:end_idx+1]
205
- logger.info(f"Extracted complete JSON array of length {len(json_text)}")
206
- generated_text = json_text
207
- elif start_idx is not None:
208
- # Found start but no end - response was truncated
209
- logger.warning("JSON array started but never closed - response truncated")
210
- # Try to extract what we have and let the client handle it
211
- generated_text = generated_text[start_idx:]
212
- except Exception as e:
213
- logger.warning(f"Error in JSON extraction: {e}")
214
- pass
215
 
216
- # Extract just the assistant's response
217
- if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
218
- response = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
219
  else:
220
- # Better fallback: look for the start of actual content (JSON or text)
221
- import re
222
-
223
- # Look for JSON array or object start
224
- json_match = re.search(r'(\[|\{)', generated_text)
225
- if json_match and json_match.start() > len(formatted_prompt) // 2:
226
- response = generated_text[json_match.start():].strip()
227
- else:
228
- # Look for the end of the prompt pattern
229
- prompt_end_patterns = [
230
- "<|end_header_id|>",
231
- "<|eot_id|>",
232
- "assistant",
233
- "\n\n"
234
- ]
235
-
236
- response = generated_text
237
- for pattern in prompt_end_patterns:
238
- if pattern in generated_text:
239
- parts = generated_text.split(pattern)
240
- if len(parts) > 1:
241
- # Take the last substantial part
242
- candidate = parts[-1].strip()
243
- if len(candidate) > 20: # Ensure it's not too short
244
- response = candidate
245
- break
246
-
247
- # Ultimate fallback - just return everything after a reasonable point
248
- if response == generated_text:
249
- # Skip approximately the prompt length but be conservative
250
- skip_chars = min(len(formatted_prompt) // 2, len(generated_text) // 3)
251
- response = generated_text[skip_chars:].strip()
252
 
253
- logger.info(f"Generated response length: {len(response)} characters")
254
- return response
 
 
 
 
 
 
 
 
 
 
 
255
 
 
 
 
256
  except Exception as e:
257
- logger.error(f"Error generating response: {str(e)}")
258
- return f"Error: {str(e)}"
 
 
 
259
 
260
  def respond(message, history, temperature):
261
- """Gradio interface function for chat"""
262
- response = generate_response(message, temperature)
263
-
264
- # Update history
265
- history.append({"role": "user", "content": message})
266
- history.append({"role": "assistant", "content": response})
267
-
268
- return history, ""
 
269
 
270
- # Create the Gradio interface
271
  with gr.Blocks(title="Question Generation API") as demo:
272
- gr.Markdown("# Simple LLM API")
273
- gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")
 
 
 
274
 
275
  with gr.Row():
276
- with gr.Column(scale=4):
277
- chatbot = gr.Chatbot(
278
- label="Chat",
279
- type="messages",
280
- height=400
281
- )
282
- msg = gr.Textbox(
283
- label="Message",
284
- placeholder="Enter your prompt here...",
285
- lines=3
286
- )
287
- with gr.Row():
288
- submit = gr.Button("Send", variant="primary")
289
- clear = gr.Button("Clear")
290
-
291
- with gr.Column(scale=1):
292
- temperature = gr.Slider(
293
- minimum=0.1,
294
- maximum=2.0,
295
- value=0.8,
296
- step=0.1,
297
- label="Temperature",
298
- info="Higher = more creative"
299
- )
300
- gr.Markdown("""
301
- ### API Usage
302
- This model accepts any prompt and returns a response.
303
-
304
- For JSON responses, include instructions in your prompt like:
305
- - "Return as a JSON array"
306
- - "Format as JSON"
307
- - "List as JSON"
308
-
309
- The model will follow your instructions.
310
- """)
311
 
312
- # Set up event handlers
313
  submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
314
  msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
315
  clear.click(lambda: ([], ""), outputs=[chatbot, msg])
316
 
317
  if __name__ == "__main__":
318
- demo.launch(
319
- server_name="0.0.0.0",
320
- server_port=7860,
321
- share=False
322
- )
 
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import gradio as gr
6
+ import json
7
+ import re
8
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
 
51
  self.model = AutoModelForCausalLM.from_pretrained(
52
  base_model_name,
53
  torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
54
+ device_map="auto" if self.device == "cuda:0" else None,
55
  trust_remote_code=True,
 
 
56
  token=hf_token
57
  )
58
 
59
+ # Set pad token
60
+ if self.tokenizer.pad_token is None:
61
+ self.tokenizer.pad_token = self.tokenizer.eos_token
62
 
63
  self.model_loaded = True
64
+ logger.info("Model loaded successfully!")
65
 
66
  except Exception as e:
67
+ logger.error(f"Error loading model: {str(e)}")
68
  self.model_loaded = False
69
 
70
+ def generate_response(prompt, temperature=0.8, model_manager=None):
71
+ """ELEGANT AI ARCHITECT SOLUTION - Clean, simple, effective"""
72
+ if not model_manager or not model_manager.model_loaded:
73
+ return "Model not loaded"
74
 
 
 
 
 
 
75
  try:
76
+ # Detect request type
77
+ is_cot_request = any(phrase in prompt.lower() for phrase in [
78
+ "return exactly this json array",
79
+ "chain of thinking",
80
+ "verbatim",
81
+ "json array (no other text)"
82
+ ])
83
+
84
+ # Get actual model context
85
+ max_context = getattr(model_manager.model.config, "max_position_embeddings", 8192)
86
+ logger.info(f"Model context: {max_context} tokens")
87
+
88
+ # SIMPLE, CLEAR PROMPT FORMATTING
89
+ if is_cot_request:
90
+ system_msg = "You are an expert at generating JSON training data. Return only valid JSON arrays as requested, no additional text."
91
+ else:
92
+ system_msg = "You are a helpful AI assistant generating high-quality training data."
93
+
94
+ formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
95
+
96
+ {system_msg}
97
+
98
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
99
 
100
  {prompt}
101
 
 
103
 
104
  """
105
 
106
+ # SMART TOKEN ALLOCATION
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if is_cot_request:
108
+ # CoT needs substantial output for complete JSON
109
+ max_new_tokens = 3000 # Generous but not excessive
110
+ min_new_tokens = 500 # Ensure JSON completion
 
 
 
111
  else:
112
+ max_new_tokens = 1500
113
+ min_new_tokens = 50
114
+
115
+ # Reserve space for input
116
+ max_input_tokens = max_context - max_new_tokens - 100
117
+
118
+ logger.info(f"Token plan: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")
119
 
120
+ # Tokenize
121
  inputs = model_manager.tokenizer(
122
  formatted_prompt,
123
  return_tensors="pt",
124
  truncation=True,
125
+ max_length=max_input_tokens
126
  )
127
 
128
+ # Move to device
129
  if model_manager.device == "cuda:0":
130
+ inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}
 
131
 
132
+ # CLEAN GENERATION
133
  with torch.no_grad():
134
  outputs = model_manager.model.generate(
135
  **inputs,
136
+ max_new_tokens=max_new_tokens,
137
+ min_new_tokens=min_new_tokens,
138
  temperature=temperature,
139
+ top_p=0.9,
140
  do_sample=True,
 
141
  pad_token_id=model_manager.tokenizer.eos_token_id,
142
+ early_stopping=False,
143
+ repetition_penalty=1.1
 
 
 
 
 
144
  )
145
 
146
+ # Decode
147
+ full_response = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
148
 
149
+ # Log stats
150
+ input_len = inputs['input_ids'].shape[1]
151
+ output_len = outputs[0].shape[0]
152
+ generated_len = output_len - input_len
153
+ logger.info(f"Generated {generated_len} tokens (min was {min_new_tokens})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # CLEAN EXTRACTION
156
+ if "<|start_header_id|>assistant<|end_header_id|>" in full_response:
157
+ response = full_response.split("<|start_header_id|>assistant<|end_header_id|>", 1)[-1].strip()
158
  else:
159
+ # Fallback
160
+ response = full_response[len(formatted_prompt):].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ # For CoT, extract clean JSON if possible
163
+ if is_cot_request and '[' in response and ']' in response:
164
+ # Find the most complete JSON array
165
+ json_pattern = r'\[(?:[^[\]]+|\[[^\]]*\])*\]'
166
+ matches = re.findall(json_pattern, response, re.DOTALL)
167
+
168
+ if matches:
169
+ # Pick the longest match (most complete)
170
+ best_match = max(matches, key=len)
171
+ # Verify it has reasonable content
172
+ if '"user"' in best_match and '"assistant"' in best_match:
173
+ logger.info(f"Extracted JSON: {len(best_match)} chars")
174
+ response = best_match
175
 
176
+ logger.info(f"Final response: {len(response)} chars")
177
+ return response.strip()
178
+
179
  except Exception as e:
180
+ logger.error(f"Generation error: {e}")
181
+ return f"Error: {e}"
182
+
183
# Construct the singleton ModelManager at import time so the model is
# already loading before the first request arrives.
model_manager = ModelManager()
185
 
186
def respond(message, history, temperature):
    """Chat callback for Gradio: append the (user, reply) pair and clear the box."""
    try:
        reply = generate_response(message, temperature, model_manager)
    except Exception as e:
        # Surface the failure in the chat itself rather than crashing the UI.
        logger.error(f"Error in respond: {e}")
        reply = f"Error: {e}"
    history.append([message, reply])
    return history, ""
196
 
197
# Build the Gradio UI: a chatbot, a prompt box, a temperature slider,
# and submit/clear buttons wired to the respond() callback.
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Question Generation API")

    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Message", placeholder="Enter your prompt...")
    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")

    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    # The button click and pressing Enter in the textbox share one handler.
    chat_inputs = [msg, chatbot, temperature]
    chat_outputs = [chatbot, msg]
    submit.click(respond, chat_inputs, chat_outputs)
    msg.submit(respond, chat_inputs, chat_outputs)
    clear.click(lambda: ([], ""), outputs=chat_outputs)

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
gradio_app_old.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import gradio as gr
6
+
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
class ModelManager:
    """Loads and holds the Llama-3.1-8B-Instruct model and tokenizer.

    Attributes:
        model: The causal-LM instance, or None until load_model() succeeds.
        tokenizer: The matching tokenizer, or None.
        device: "cuda:0" when a GPU is available, otherwise "cpu".
        model_loaded: True only after a fully successful load.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Eagerly load at construction so the app is ready to serve.
        self.load_model()

    def load_model(self):
        """Load the model and tokenizer; never raises (flips model_loaded instead)."""
        try:
            logger.info("Starting model loading...")

            # Pin GPU 0 explicitly so the device_map below agrees with self.device.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
                self.device = "cuda:0"
            else:
                self.device = "cpu"
            logger.info(f"Using device: {self.device}")

            if self.device == "cuda:0":
                logger.info(f"GPU: {torch.cuda.get_device_name()}")
                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

            # Gated repo: the access token comes from the environment.
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Llama-3.1-8B-Instruct model...")
            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

            self.tokenizer = AutoTokenizer.from_pretrained(
                base_model_name,
                use_fast=True,
                trust_remote_code=True,
                token=hf_token
            )
            # BUG FIX: Llama tokenizers ship without a pad token; generate()
            # needs one for padding, so fall back to EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
                device_map={"": 0} if self.device == "cuda:0" else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                token=hf_token
            )
            # BUG FIX: removed the old redundant `self.model.to(self.device)` —
            # device_map={"": 0} already places every weight on GPU 0, and the
            # extra move wastes time and can spike memory during startup.

            self.model_loaded = True
            logger.info("Model loaded successfully!")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            self.model_loaded = False
+
69
# Module-level singleton: loading starts as soon as the app module is imported.
model_manager = ModelManager()
71
+
72
def _extract_first_json_array(text):
    """Return the first complete top-level JSON array found in *text*.

    If an array opens but never closes (truncated generation), return the tail
    from its opening bracket; if no array starts at all, return *text*
    unchanged. A hand-rolled scanner is used (rather than json.loads) because
    the model output usually has prose around the array; brackets inside JSON
    strings are ignored via in_string/escape tracking.
    """
    bracket_depth = 0   # [ ]
    brace_depth = 0     # { }
    in_string = False
    escape_next = False
    start_idx = None

    for i, ch in enumerate(text):
        if escape_next:
            escape_next = False
            continue
        if ch == '\\':
            escape_next = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        # Only count brackets/braces outside of strings.
        if in_string:
            continue
        if ch == '[':
            if bracket_depth == 0 and brace_depth == 0 and start_idx is None:
                start_idx = i
            bracket_depth += 1
        elif ch == ']':
            bracket_depth = max(0, bracket_depth - 1)
            if bracket_depth == 0 and brace_depth == 0 and start_idx is not None:
                return text[start_idx:i + 1]
        elif ch == '{':
            brace_depth += 1
        elif ch == '}':
            brace_depth = max(0, brace_depth - 1)

    # Array started but never closed -> keep what we have; none -> unchanged.
    return text[start_idx:] if start_idx is not None else text


def generate_response(prompt, temperature=0.8):
    """Generate a completion for *prompt* using the global model_manager.

    Returns the assistant text (trimmed to the first complete JSON array when
    one is present) or a human-readable error string; never raises.
    """
    if not model_manager.model_loaded:
        return "Model not loaded yet. Please wait..."

    try:
        # Llama-3.1 chat template, single user turn.
        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{prompt}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        # Context window of the loaded checkpoint (Llama 3.1 supports 131k).
        try:
            max_ctx = getattr(model_manager.model.config, "max_position_embeddings", 131072)
        except Exception:
            max_ctx = 131072
        logger.info(f"Model max context: {max_ctx} tokens")

        # Chain-of-thinking style requests get a much larger output budget.
        lowered = prompt.lower()
        is_cot_request = ("chain-of-thinking" in lowered or
                          "chain of thinking" in lowered or
                          "Return exactly this JSON array" in prompt or
                          ("verbatim" in lowered and "json array" in lowered))

        if is_cot_request:
            gen_max_new_tokens = 16384   # room for complete JSON transcripts
            min_tokens = 2000            # force substantial output
        else:
            gen_max_new_tokens = 8192
            min_tokens = 200
        # Whatever remains (minus a small safety buffer) goes to the input.
        allowed_input_tokens = max_ctx - gen_max_new_tokens - 100
        logger.info(f"Token plan: min_tokens={min_tokens}, max_new_tokens={gen_max_new_tokens}, input_limit={allowed_input_tokens}")

        inputs = model_manager.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=allowed_input_tokens
        )

        # generate() requires inputs on the same device as the weights.
        if model_manager.device == "cuda:0":
            model_device = next(model_manager.model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

        # NOTE: early_stopping/length_penalty were dropped — they only apply to
        # beam search and num_beams=1 is used here; both were at defaults anyway.
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=gen_max_new_tokens,
                min_new_tokens=min_tokens,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=model_manager.tokenizer.eos_token_id,
                eos_token_id=model_manager.tokenizer.eos_token_id,
                repetition_penalty=1.05,
                no_repeat_ngram_size=0,
                use_cache=True
            )

        input_length = inputs['input_ids'].shape[1]
        generated_length = outputs[0].shape[0] - input_length
        logger.info(f"Generation stats - Input: {input_length} tokens, Generated: {generated_length} tokens, Min required: {min_tokens}")
        if generated_length < min_tokens:
            logger.warning(f"Generated {generated_length} tokens but minimum was {min_tokens} - response may be truncated")

        # BUG FIX: decode only the tokens generated after the prompt. The old
        # code decoded the full sequence with skip_special_tokens=True, which
        # strips the "<|start_header_id|>assistant<|end_header_id|>" marker, so
        # its split-on-marker branch could never match and it fell into fragile
        # character-offset heuristics to strip the echoed prompt.
        response = model_manager.tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()

        # Trim trailing prose (e.g. a stray 'assistant') around a JSON array.
        response = _extract_first_json_array(response)

        logger.info(f"Generated response length: {len(response)} characters")
        return response

    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return f"Error: {str(e)}"
259
+
260
def respond(message, history, temperature):
    """Run one chat turn: generate a reply and record both sides in *history*."""
    reply = generate_response(message, temperature)
    # The chatbot uses type="messages", so history is a list of role dicts.
    history.extend([
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ])
    return history, ""
269
+
270
# Build the Gradio UI: chat column on the left, controls and help on the right.
with gr.Blocks(title="Question Generation API") as demo:
    gr.Markdown("# Simple LLM API")
    gr.Markdown("Send a prompt and get a response. No templates, just direct model interaction.")

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=400)
            msg = gr.Textbox(label="Message", placeholder="Enter your prompt here...", lines=3)
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )
            gr.Markdown("""
            ### API Usage
            This model accepts any prompt and returns a response.

            For JSON responses, include instructions in your prompt like:
            - "Return as a JSON array"
            - "Format as JSON"
            - "List as JSON"

            The model will follow your instructions.
            """)

    # The Send button and pressing Enter share one handler; Clear resets both.
    chat_inputs = [msg, chatbot, temperature]
    chat_outputs = [chatbot, msg]
    submit.click(respond, chat_inputs, chat_outputs)
    msg.submit(respond, chat_inputs, chat_outputs)
    clear.click(lambda: ([], ""), outputs=chat_outputs)

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )