david167 commited on
Commit
f364fe3
·
1 Parent(s): f093b76

Add complete JSON functionality to Gradio interface

Browse files
Files changed (1) hide show
  1. gradio_app.py +301 -0
gradio_app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import threading
4
+ import json
5
+ import re
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import gradio as gr
10
+
11
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state, populated by load_model() on a background thread.
# The Gradio callbacks must check model_loaded before touching model/tokenizer.
model = None          # transformers AutoModelForCausalLM once loaded
tokenizer = None      # matching AutoTokenizer once loaded
device = None         # "cuda:0" or "cpu", chosen when the model loads
model_loaded = False  # set True only after a fully successful load
20
+
21
def load_model():
    """Load the Llama tokenizer and model into the module globals.

    Runs on a background thread at import time so the UI starts
    immediately. Sets ``model_loaded`` to True only after both the
    tokenizer and the model load successfully; any failure is logged and
    leaves ``model_loaded`` False so chat_with_model() can answer with a
    "not loaded yet" message instead of crashing.
    """
    global model, tokenizer, device, model_loaded

    try:
        logger.info("Starting model loading...")

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(0)
            device = "cuda:0"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")

        if use_cuda:
            logger.info(f"GPU: {torch.cuda.get_device_name()}")
            logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

        # Gated repo: an HF access token must be supplied via the environment.
        hf_token = os.getenv("HF_TOKEN")

        logger.info("Loading Llama-3.1-8B-Instruct model...")
        base_model_name = "meta-llama/Llama-3.1-8B-Instruct"

        tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            use_fast=True,
            trust_remote_code=True,
            token=hf_token,
        )

        # BUG FIX: the original always passed device_map={"": 0}, which asks
        # accelerate to place the entire model on GPU 0 even on CPU-only
        # hosts (where no GPU 0 exists), and then redundantly called
        # model.to(device) although device_map had already placed the
        # weights. Pin to GPU 0 only when CUDA is actually available; on
        # CPU, let from_pretrained load in float32 without a device map.
        model_kwargs = dict(
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            token=hf_token,
        )
        if use_cuda:
            model_kwargs["device_map"] = {"": 0}

        model = AutoModelForCausalLM.from_pretrained(base_model_name, **model_kwargs)

        model_loaded = True
        logger.info("Model loaded successfully!")

    except Exception as e:
        # Deliberate best-effort: the UI keeps running and reports the
        # unloaded state instead of the process dying at import time.
        logger.error(f"Error loading model: {str(e)}")
        model_loaded = False
70
+
71
# Start model loading in a separate thread so the Gradio UI comes up
# immediately; chat_with_model() checks model_loaded before generating.
model_thread = threading.Thread(target=load_model)
model_thread.start()
74
+
75
def create_json_prompt(message, template_type):
    """Build a Llama-3 chat prompt that asks the model to answer in JSON.

    ``template_type`` selects one of the predefined instruction/schema
    pairs ("general" or "questions"); any unrecognized value falls back
    to the "general" template.
    """
    general_instruction = "Please respond in valid JSON format with the following structure:"
    general_schema = """{
"response": "your main response here",
"type": "answer|question|explanation|analysis",
"confidence": 0.95,
"metadata": {
"topic": "detected topic",
"complexity": "simple|moderate|complex"
}
}"""

    questions_instruction = "Generate 5 thoughtful questions based on the following statement. Respond in JSON format:"
    questions_schema = """{
"questions": [
"Question 1 here?",
"Question 2 here?",
"Question 3 here?",
"Question 4 here?",
"Question 5 here?"
],
"statement": "original statement",
"difficulty": "mixed",
"total_questions": 5,
"metadata": {
"topic": "detected topic",
"question_types": ["factual", "analytical", "creative"]
}
}"""

    json_templates = {
        "general": {"instruction": general_instruction, "schema": general_schema},
        "questions": {"instruction": questions_instruction, "schema": questions_schema},
    }
    chosen = json_templates.get(template_type, json_templates["general"])

    # Llama-3 instruct chat format: one user turn, then an open assistant
    # header for the model to complete.
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{message}

{chosen["instruction"]}

{chosen["schema"]}

Ensure the response is valid JSON that can be parsed. Do not include any text outside the JSON structure.

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
127
+
128
def prettify_json_response(response_text):
    """Extract the first ``{...}`` span from *response_text* and pretty-print it.

    Returns the re-serialized JSON (2-space indent, non-ASCII preserved)
    when a parseable object is found; otherwise returns the text unchanged.
    """
    match = re.search(r'\{.*\}', response_text, re.DOTALL)
    if match is None:
        # No brace-delimited span at all — nothing to prettify.
        return response_text
    try:
        parsed = json.loads(match.group())
    except json.JSONDecodeError:
        # The span wasn't valid JSON; hand back the raw text untouched.
        return response_text
    return json.dumps(parsed, indent=2, ensure_ascii=False)
142
+
143
def chat_with_model(message, history, temperature, json_mode=False, json_template="general"):
    """Run one chat turn against the loaded model.

    Appends the user turn and the assistant reply to *history* (a Gradio
    "messages"-style list of {"role", "content"} dicts, mutated in place)
    and returns ``(history, "")`` — the empty string clears the textbox.
    Errors are reported as an assistant message rather than raised.
    """
    # Ignore empty/whitespace-only submissions.
    if not message.strip():
        return history, ""

    # The model loads on a background thread; answer gracefully until ready.
    if not model_loaded:
        response = "Model not loaded yet. Please wait..."
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""

    try:
        # Build a Llama-3 chat prompt; JSON mode adds schema instructions.
        if json_mode:
            prompt = create_json_prompt(message, json_template)
        else:
            prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{message}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)

        # Move input tensors to wherever the model's parameters actually live.
        if device == "cuda:0":
            model_device = next(model.parameters()).device
            inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                num_beams=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                early_stopping=False,
                repetition_penalty=1.1
            )

        # BUG FIX: the original decoded the full sequence with
        # skip_special_tokens=True and then tried to split on
        # "<|start_header_id|>assistant<|end_header_id|>" — but those
        # markers ARE special tokens, so the decode never contains them and
        # the fallback `generated_text[len(prompt):]` mis-sliced (the prompt
        # length counts special-token text the decode no longer has).
        # Decode only the newly generated ids instead.
        input_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()

        if json_mode and response:
            response = prettify_json_response(response)

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})

    except Exception as e:
        # Surface the failure in the chat window instead of crashing the UI.
        logger.error(f"Error in chat: {str(e)}")
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": f"Error: {str(e)}"})

    return history, ""
205
+
206
def clear_chat():
    """Reset the conversation: empty history plus a cleared textbox."""
    empty_history, empty_textbox = [], ""
    return empty_history, empty_textbox
208
+
209
# Custom CSS injected into gr.Blocks: make the app span the full browser
# width and keep the chat area (elem_id="chatbot") tall and scrollable.
css = """
.gradio-container {
    max-width: 100% !important;
    width: 100% !important;
    margin: 0 !important;
    padding: 20px !important;
}
#chatbot {
    height: 70vh !important;
    min-height: 600px !important;
    overflow-y: auto !important;
}
"""
222
+
223
# Build the Gradio UI. Component creation order matters: event wiring at
# the bottom references the components defined above it.
with gr.Blocks(css=css, title="Llama Chat", theme=gr.themes.Soft()) as demo:
    # Header / usage notes.
    gr.Markdown(
        """
        # 🦙 Llama Chat
        ### Raw interface for Llama-3.1-8B-Instruct with JSON Mode

        **JSON Response Mode**: Enable for structured outputs!
        - 🎯 **General**: Basic structured responses
        - ❓ **Questions**: Generate question sets from content
        """
    )

    # Conversation display; type="messages" matches the
    # {"role", "content"} dicts that chat_with_model() appends.
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        label="Chat",
        show_label=False,
        avatar_images=(None, None),
        show_share_button=False,
        type="messages",
        height=600,
        render_markdown=True,
        show_copy_button=True
    )

    # Input row: textbox plus Send/Clear buttons.
    with gr.Row():
        with gr.Column(scale=4):
            msg = gr.Textbox(
                placeholder="Type your message here...",
                show_label=False,
                container=False
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Send", variant="primary")
        with gr.Column(scale=1):
            clear_btn = gr.Button("Clear", variant="secondary")

    # Sampling temperature passed straight to model.generate().
    with gr.Row():
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1,
            label="Temperature"
        )

    # JSON-mode controls; the template dropdown starts hidden and is
    # revealed by the json_mode.change handler below.
    with gr.Row():
        with gr.Column(scale=2):
            json_mode = gr.Checkbox(
                label="JSON Response Mode",
                value=False,
                info="Get structured JSON responses"
            )
        with gr.Column(scale=3):
            json_template = gr.Dropdown(
                choices=["general", "questions"],
                value="general",
                label="JSON Template",
                visible=False
            )

    def respond(message, history, temp, json_enabled, json_type):
        # Thin wrapper mapping Gradio's positional inputs onto chat_with_model.
        return chat_with_model(message, history, temp, json_enabled, json_type)

    def toggle_json_template(json_enabled):
        # Show the template dropdown only while JSON mode is checked.
        return gr.update(visible=json_enabled)

    json_mode.change(toggle_json_template, inputs=[json_mode], outputs=[json_template])

    # Enter key and Send button share the same handler; the second output
    # (empty string from chat_with_model) clears the textbox.
    msg.submit(respond, [msg, chatbot, temperature, json_mode, json_template], [chatbot, msg])
    submit_btn.click(respond, [msg, chatbot, temperature, json_mode, json_template], [chatbot, msg])
    clear_btn.click(clear_chat, outputs=[chatbot, msg])
294
+
295
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard HF Spaces port);
    # show_error surfaces handler tracebacks in the browser.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )