david167 committed
Commit b2df124 · 1 Parent(s): 0cdc4eb

Fix Gradio interface: Remove chatbot format issues, add proper API endpoint structure

Files changed (2):
  1. gradio_app.py +47 -15
  2. gradio_app_new.py +214 -0
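
Per the commit title, the Space now exposes a proper `/respond` API endpoint (see the `api_interface` block in the diff below). As a rough sketch of how a client might call it with `gradio_client` — the Space id is a placeholder, and the exact positional arguments depend on how the installed Gradio version exposes the `gr.State` history input:

```python
from gradio_client import Client

client = Client("user/space-name")  # placeholder Space id, not part of this commit

# gr.State inputs are usually session-managed and not sent by the client,
# so only message and temperature are passed here; some Gradio versions
# may expect the history list as an explicit argument as well.
result = client.predict(
    "Write two quiz questions about photosynthesis.",  # message
    0.8,                                               # temperature
    api_name="/respond",
)
print(result)  # expected shape: [[{user turn}, {assistant turn}], ""]
```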
gradio_app.py CHANGED
@@ -184,31 +184,63 @@ def generate_response(prompt, temperature=0.8, model_manager=None):
 model_manager = ModelManager()
 
 def respond(message, history, temperature):
-    """Gradio interface function"""
+    """Gradio interface function - fixed for proper format"""
     try:
         response = generate_response(message, temperature, model_manager)
-        history.append([message, response])
-        return history, ""
+        # Return just the response for the simple interface
+        return response
     except Exception as e:
         logger.error(f"Error in respond: {e}")
-        history.append([message, f"Error: {e}"])
-        return history, ""
+        return f"Error: {e}"
+
+# API function for external calls
+def api_respond(message, history=None, temperature=0.8, json_mode=None, template=None):
+    """API endpoint matching original client expectations"""
+    try:
+        response = generate_response(message, temperature, model_manager)
+
+        # Return in original format that client expects
+        return [[
+            {"role": "user", "metadata": None, "content": message, "options": None},
+            {"role": "assistant", "metadata": None, "content": response, "options": None}
+        ], ""]
+    except Exception as e:
+        logger.error(f"API Error: {e}")
+        return [[
+            {"role": "user", "metadata": None, "content": message, "options": None},
+            {"role": "assistant", "metadata": None, "content": f"Error: {e}", "options": None}
+        ], ""]
 
 # Create Gradio interface
 with gr.Blocks(title="Question Generation API") as demo:
-    gr.Markdown("# Question Generation API")
-
-    chatbot = gr.Chatbot(height=400)
-    msg = gr.Textbox(label="Message", placeholder="Enter your prompt...")
-    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
+    gr.Markdown("# Question Generation API - Elegant Architecture")
 
     with gr.Row():
-        submit = gr.Button("Submit", variant="primary")
-        clear = gr.Button("Clear")
+        with gr.Column():
+            message_input = gr.Textbox(label="Message", placeholder="Enter your prompt...", lines=5)
+            temperature_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
+            submit_btn = gr.Button("Generate", variant="primary")
+
+        with gr.Column():
+            response_output = gr.Textbox(label="Response", lines=15, max_lines=30)
+
+    # Simple UI function
+    def ui_respond(message, temperature):
+        return generate_response(message, temperature, model_manager)
+
+    submit_btn.click(ui_respond, inputs=[message_input, temperature_input], outputs=[response_output])
 
-    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
-    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
-    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
+    # Create the API interface that matches the original /respond endpoint
+    api_interface = gr.Interface(
+        fn=api_respond,
+        inputs=[
+            gr.Textbox(label="message"),
+            gr.State(value=[]),  # history
+            gr.Number(value=0.8, label="temperature")
+        ],
+        outputs=gr.JSON(),
+        api_name="respond"
+    )
 
 if __name__ == "__main__":
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
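
For reference, `api_respond` above returns a two-element list: a single user/assistant turn in role-content form, plus an empty string the original client used to clear its textbox. A small illustrative unpacking (the `result` literal simply mirrors the shape shown in the diff; it is not produced by this code):

```python
# Illustrative unpacking of the [[user_msg, assistant_msg], ""] payload
# returned by api_respond; the literal below mirrors the diff's shape.
result = [
    [
        {"role": "user", "metadata": None, "content": "prompt text", "options": None},
        {"role": "assistant", "metadata": None, "content": "model output", "options": None},
    ],
    "",
]

history, _textbox_reset = result
assistant_reply = next(m["content"] for m in history if m["role"] == "assistant")
print(assistant_reply)  # -> "model output"
```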
gradio_app_new.py ADDED
@@ -0,0 +1,214 @@
+import os
+import logging
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import gradio as gr
+import json
+import re
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class ModelManager:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        self.device = None
+        self.model_loaded = False
+        self.load_model()
+
+    def load_model(self):
+        """Load the model and tokenizer"""
+        try:
+            logger.info("Starting model loading...")
+
+            # Check if CUDA is available
+            if torch.cuda.is_available():
+                torch.cuda.set_device(0)
+                self.device = "cuda:0"
+            else:
+                self.device = "cpu"
+            logger.info(f"Using device: {self.device}")
+
+            if self.device == "cuda:0":
+                logger.info(f"GPU: {torch.cuda.get_device_name()}")
+                logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+
+            # Get HF token from environment
+            hf_token = os.getenv("HF_TOKEN")
+
+            logger.info("Loading Llama-3.1-8B-Instruct model...")
+            base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
+
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                base_model_name,
+                use_fast=True,
+                trust_remote_code=True,
+                token=hf_token
+            )
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                base_model_name,
+                torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
+                device_map="auto" if self.device == "cuda:0" else None,
+                trust_remote_code=True,
+                token=hf_token
+            )
+
+            # Set pad token
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            self.model_loaded = True
+            logger.info("✅ Model loaded successfully!")
+
+        except Exception as e:
+            logger.error(f"❌ Error loading model: {str(e)}")
+            self.model_loaded = False
+
+def generate_response(prompt, temperature=0.8, model_manager=None):
+    """ELEGANT AI ARCHITECT SOLUTION - Clean, simple, effective"""
+    if not model_manager or not model_manager.model_loaded:
+        return "Model not loaded"
+
+    try:
+        # Detect request type
+        is_cot_request = any(phrase in prompt.lower() for phrase in [
+            "return exactly this json array",
+            "chain of thinking",
+            "verbatim",
+            "json array (no other text)"
+        ])
+
+        # Get actual model context
+        max_context = getattr(model_manager.model.config, "max_position_embeddings", 8192)
+        logger.info(f"Model context: {max_context} tokens")
+
+        # SIMPLE, CLEAR PROMPT FORMATTING
+        if is_cot_request:
+            system_msg = "You are an expert at generating JSON training data. Return only valid JSON arrays as requested, no additional text."
+        else:
+            system_msg = "You are a helpful AI assistant generating high-quality training data."
+
+        formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{system_msg}
+
+<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}
+
+<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+"""
+
+        # SMART TOKEN ALLOCATION
+        if is_cot_request:
+            # CoT needs substantial output for complete JSON
+            max_new_tokens = 3000  # Generous but not excessive
+            min_new_tokens = 500   # Ensure JSON completion
+        else:
+            max_new_tokens = 1500
+            min_new_tokens = 50
+
+        # Reserve space for input
+        max_input_tokens = max_context - max_new_tokens - 100
+
+        logger.info(f"Token plan: Input≤{max_input_tokens}, Output={min_new_tokens}-{max_new_tokens}")
+
+        # Tokenize
+        inputs = model_manager.tokenizer(
+            formatted_prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_input_tokens
+        )
+
+        # Move to device
+        if model_manager.device == "cuda:0":
+            inputs = {k: v.to(next(model_manager.model.parameters()).device) for k, v in inputs.items()}
+
+        # CLEAN GENERATION
+        with torch.no_grad():
+            outputs = model_manager.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                temperature=temperature,
+                top_p=0.9,
+                do_sample=True,
+                pad_token_id=model_manager.tokenizer.eos_token_id,
+                early_stopping=False,
+                repetition_penalty=1.1
+            )
+
+        # Decode
+        full_response = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Log stats
+        input_len = inputs['input_ids'].shape[1]
+        output_len = outputs[0].shape[0]
+        generated_len = output_len - input_len
+        logger.info(f"Generated {generated_len} tokens (min was {min_new_tokens})")
+
+        # CLEAN EXTRACTION
+        if "<|start_header_id|>assistant<|end_header_id|>" in full_response:
+            response = full_response.split("<|start_header_id|>assistant<|end_header_id|>", 1)[-1].strip()
+        else:
+            # Fallback
+            response = full_response[len(formatted_prompt):].strip()
+
+        # For CoT, extract clean JSON if possible
+        if is_cot_request and '[' in response and ']' in response:
+            # Find the most complete JSON array
+            json_pattern = r'\[(?:[^[\]]+|\[[^\]]*\])*\]'
+            matches = re.findall(json_pattern, response, re.DOTALL)
+
+            if matches:
+                # Pick the longest match (most complete)
+                best_match = max(matches, key=len)
+                # Verify it has reasonable content
+                if '"user"' in best_match and '"assistant"' in best_match:
+                    logger.info(f"Extracted JSON: {len(best_match)} chars")
+                    response = best_match
+
+        logger.info(f"Final response: {len(response)} chars")
+        return response.strip()
+
+    except Exception as e:
+        logger.error(f"Generation error: {e}")
+        return f"Error: {e}"
+
+# Initialize model
+model_manager = ModelManager()
+
+def respond(message, history, temperature):
+    """Gradio interface function"""
+    try:
+        response = generate_response(message, temperature, model_manager)
+        history.append([message, response])
+        return history, ""
+    except Exception as e:
+        logger.error(f"Error in respond: {e}")
+        history.append([message, f"Error: {e}"])
+        return history, ""
+
+# Create Gradio interface
+with gr.Blocks(title="Question Generation API") as demo:
+    gr.Markdown("# Question Generation API")
+
+    chatbot = gr.Chatbot(height=400)
+    msg = gr.Textbox(label="Message", placeholder="Enter your prompt...")
+    temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature")
+
+    with gr.Row():
+        submit = gr.Button("Submit", variant="primary")
+        clear = gr.Button("Clear")
+
+    submit.click(respond, [msg, chatbot, temperature], [chatbot, msg])
+    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
+    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
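
As a quick sanity check on the JSON-extraction step in `generate_response`, the same regex can be exercised standalone; the sample `response` string below is invented for illustration:

```python
import re

# Same pattern as in generate_response; the sample text is invented.
json_pattern = r'\[(?:[^[\]]+|\[[^\]]*\])*\]'
response = 'Sure! [{"user": "What is 2+2?", "assistant": "4"}] Hope that helps.'

matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
    best_match = max(matches, key=len)  # longest match, as in the app
    if '"user"' in best_match and '"assistant"' in best_match:
        print(best_match)  # -> [{"user": "What is 2+2?", "assistant": "4"}]
```

The pattern tolerates one level of bracket nesting, which suffices for the flat role/content arrays the prompts request. On the token budget, `max_input_tokens = max_context - max_new_tokens - 100`; with the 8192-token fallback and the 3000-token CoT budget, that leaves 8192 - 3000 - 100 = 5092 tokens for input (Llama-3.1's config normally reports a much larger `max_position_embeddings`, so the real budget is larger).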