david167 committed on
Commit
fcef7cd
·
1 Parent(s): 8106bb9

COMPLETE REWRITE: Clean ChatGPT-style interface with proper response handling

Browse files
Files changed (1) hide show
  1. gradio_app.py +202 -185
gradio_app.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import logging
3
- import time
4
- import asyncio
5
- from typing import List, Optional, Dict, Any
6
  import threading
 
7
 
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -13,233 +11,252 @@ import gradio as gr
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- class ModelManager:
17
- def __init__(self):
18
- self.model = None
19
- self.tokenizer = None
20
- self.device = None
21
- self.model_loaded = False
22
-
23
- # Load model in a separate thread to avoid blocking
24
- self.load_thread = threading.Thread(target=self.load_model_sync)
25
- self.load_thread.daemon = True
26
- self.load_thread.start()
27
-
28
- def load_model_sync(self):
29
- """Load model synchronously in background thread"""
30
- try:
31
- logger.info("Starting model loading...")
32
-
33
- # Check if CUDA is available and force to cuda:0
34
- if torch.cuda.is_available():
35
- torch.cuda.set_device(0)
36
- self.device = "cuda:0"
37
- else:
38
- self.device = "cpu"
39
-
40
- logger.info(f"Using device: {self.device}")
41
-
42
- if self.device == "cuda:0":
43
- logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
44
- logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
45
-
46
- # Get HF token from environment
47
- hf_token = os.getenv("HF_TOKEN")
48
-
49
- logger.info("Loading Llama-3.1-8B-Instruct model...")
50
- base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
51
-
52
- self.tokenizer = AutoTokenizer.from_pretrained(
53
- base_model_name,
54
- use_fast=True,
55
- trust_remote_code=True,
56
- token=hf_token
57
- )
58
-
59
- self.model = AutoModelForCausalLM.from_pretrained(
60
- base_model_name,
61
- torch_dtype=torch.float16 if self.device == "cuda:0" else torch.float32,
62
- device_map={"": 0}, # Force all parameters to GPU 0
63
- trust_remote_code=True,
64
- low_cpu_mem_usage=True,
65
- use_safetensors=True,
66
- token=hf_token
67
- )
68
-
69
- # Ensure model is on the correct device
70
- if self.device == "cuda:0":
71
- self.model = self.model.to(self.device)
72
-
73
- self.model_loaded = True
74
- logger.info("Model loaded successfully!")
75
-
76
- except Exception as e:
77
- logger.error(f"Error loading model: {str(e)}")
78
- self.model_loaded = False
79
 
80
- # Initialize model manager
81
- model_manager = ModelManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def chat_with_model(message, history, temperature):
84
- """Raw chat function for direct model interaction"""
 
 
 
 
 
 
85
  if not message.strip():
86
  return history, ""
87
 
88
  try:
89
- # Use the model directly for raw chat
90
- if not model_manager.model_loaded:
91
- response = "Model not loaded yet. Please wait..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  else:
93
- # Create a simple chat prompt
94
- prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
95
-
96
- {message}
97
-
98
- <|eot_id|><|start_header_id|>assistant<|end_header_id|>
99
-
100
- """
101
-
102
- # Generate response using the model directly
103
- inputs = model_manager.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
104
-
105
- # Force all inputs to the same device as the model
106
- if model_manager.device == "cuda:0":
107
- model_device = next(model_manager.model.parameters()).device
108
- inputs = {k: v.to(model_device) for k, v in inputs.items()}
109
-
110
- with torch.no_grad():
111
- outputs = model_manager.model.generate(
112
- **inputs,
113
- max_new_tokens=4096,
114
- temperature=temperature,
115
- top_p=0.95,
116
- do_sample=True,
117
- num_beams=1,
118
- pad_token_id=model_manager.tokenizer.eos_token_id,
119
- eos_token_id=model_manager.tokenizer.eos_token_id,
120
- early_stopping=False, # Disable early stopping to prevent premature truncation
121
- repetition_penalty=1.1 # Add slight repetition penalty to improve quality
122
- )
123
-
124
- # Decode the generated text and remove the input prompt
125
- full_text = model_manager.tokenizer.decode(outputs[0], skip_special_tokens=True)
126
- # Use a more robust method to extract the response
127
- # Look for the assistant header end and extract everything after it
128
- assistant_start = "<|start_header_id|>assistant<|end_header_id|>"
129
- if assistant_start in full_text:
130
- # Find the position after the assistant header
131
- response_start = full_text.find(assistant_start) + len(assistant_start)
132
- # TEMPORARY: Show full response for debugging
133
- response = f"=== FULL RESPONSE ===\n{full_text}\n=== END ==="
134
- # Original line: response = full_text[response_start:].strip()
135
- logger.info(f"Extracted response length: {len(response)}")
136
- else:
137
- # Fallback: try to remove the original prompt
138
- try:
139
- response = full_text[len(prompt):].strip()
140
- except:
141
- response = full_text.strip()
142
-
143
- # Check if response ends abruptly (might indicate truncation)
144
- if response and not response.endswith(('.', '!', '?', ':', ';')):
145
- logger.warning(f"Response may be truncated - ends with: '{response[-20:]}'")
146
-
147
- if not response:
148
- response = "I couldn't generate a response. Please try a different prompt."
149
-
150
  except Exception as e:
151
- logger.error(f"Error in chat: {str(e)}")
152
- response = f"Error: {str(e)}"
153
-
154
- # Add both user message and bot response to history using new message format
155
- history.append({"role": "user", "content": message})
156
- history.append({"role": "assistant", "content": response})
157
 
158
  return history, ""
159
 
160
- def clear_chat():
161
  """Clear the chat history"""
162
- return [], ""
 
 
 
 
 
 
 
163
 
164
- # CSS for styling
165
  css = """
 
 
 
 
166
  #chatbot {
167
- min-height: 500px;
168
- border: 1px solid #e0e0e0;
169
- border-radius: 8px;
170
- overflow: auto;
171
- background-color: #f9f9f9;
172
  }
173
- .gradio-container {
174
- font-family: 'IBM Plex Sans', sans-serif;
 
 
 
 
 
 
 
 
 
 
175
  }
176
  """
177
 
178
- # Create simplified chat interface
179
- with gr.Blocks(css=css, title="Llama Chat", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
180
  gr.Markdown(
181
  """
182
  # 🦙 Llama Chat
183
- ### Raw interface for Llama-3.1-8B-Instruct
184
 
185
- Direct chat interface for testing prompts and having conversations with the model.
186
  """
187
  )
188
 
189
- # Simple chat interface
190
- chatbot = gr.Chatbot(
191
- elem_id="chatbot",
192
- label="Chat",
193
- show_label=False,
194
- avatar_images=(None, None),
195
- show_share_button=False,
196
- type="messages", # Use new message format
197
- height=500
198
- )
199
-
200
  with gr.Row():
201
  with gr.Column(scale=4):
202
- msg = gr.Textbox(
203
- placeholder="Type your message here...",
204
  show_label=False,
205
- container=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  )
207
- with gr.Column(scale=1):
208
- submit_btn = gr.Button("Send", variant="primary")
209
- with gr.Column(scale=1):
210
- clear_btn = gr.Button("Clear", variant="secondary")
211
-
212
- with gr.Row():
213
- temperature = gr.Slider(
214
- minimum=0.1,
215
- maximum=2.0,
216
- value=0.8,
217
- step=0.1,
218
- label="Temperature",
219
- info="Controls randomness (0.1=focused, 2.0=creative)"
220
- )
221
 
222
  # Event handlers
223
  def respond(message, history, temp):
224
- return chat_with_model(message, history, temp)
225
 
 
226
  msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
227
- submit_btn.click(respond, [msg, chatbot, temperature], [chatbot, msg])
228
- clear_btn.click(clear_chat, outputs=[chatbot, msg])
229
 
230
- # Add footer
231
  gr.Markdown(
232
  """
233
  ---
234
  <div style="text-align: center; color: #666; font-size: 0.9em;">
235
- Built with ❤️ using Gradio and Llama-3.1-8B-Instruct
236
- <a href="/docs" target="_blank">API Documentation</a>
237
  </div>
238
  """
239
  )
240
 
241
  if __name__ == "__main__":
242
- # Launch Gradio interface
243
  demo.launch(
244
  server_name="0.0.0.0",
245
  server_port=7860,
 
1
  import os
2
  import logging
 
 
 
3
  import threading
4
+ from typing import List, Tuple
5
 
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
+ # Global variables for model
15
+ model = None
16
+ tokenizer = None
17
+ device = None
18
+ model_loaded = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ def load_model():
21
+ """Load the Llama model and tokenizer"""
22
+ global model, tokenizer, device, model_loaded
23
+
24
+ try:
25
+ logger.info("Starting model loading...")
26
+
27
+ # Check if CUDA is available and force to cuda:0
28
+ if torch.cuda.is_available():
29
+ torch.cuda.set_device(0)
30
+ device = "cuda:0"
31
+ else:
32
+ device = "cpu"
33
+
34
+ logger.info(f"Using device: {device}")
35
+
36
+ if device == "cuda:0":
37
+ logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
38
+ logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
39
+
40
+ # Get HF token from environment
41
+ hf_token = os.getenv("HF_TOKEN")
42
+
43
+ logger.info("Loading Llama-3.1-8B-Instruct model...")
44
+ model_name = "meta-llama/Llama-3.1-8B-Instruct"
45
+
46
+ # Load tokenizer
47
+ tokenizer = AutoTokenizer.from_pretrained(
48
+ model_name,
49
+ use_fast=True,
50
+ trust_remote_code=True,
51
+ token=hf_token
52
+ )
53
+
54
+ # Load model
55
+ model = AutoModelForCausalLM.from_pretrained(
56
+ model_name,
57
+ torch_dtype=torch.float16 if device == "cuda:0" else torch.float32,
58
+ device_map={"": 0}, # Force all parameters to GPU 0
59
+ trust_remote_code=True,
60
+ low_cpu_mem_usage=True,
61
+ use_safetensors=True,
62
+ token=hf_token
63
+ )
64
+
65
+ # Ensure model is on the correct device
66
+ if device == "cuda:0":
67
+ model = model.to(device)
68
+
69
+ model_loaded = True
70
+ logger.info("Model loaded successfully!")
71
+
72
+ except Exception as e:
73
+ logger.error(f"Error loading model: {str(e)}")
74
+ model_loaded = False
75
 
76
def chat_response(message: str, history: List[List[str]], temperature: float) -> Tuple[List[List[str]], str]:
    """Generate a reply to *message* and append the turn to *history*.

    Args:
        message: The user's new message.
        history: Chat history as a list of [user_msg, assistant_msg] pairs.
        temperature: Sampling temperature for generation.

    Returns:
        (updated history, "") — the empty string clears the input textbox.
    """
    global model, tokenizer, device, model_loaded

    if not model_loaded:
        history.append([message, "🔄 Model is still loading, please wait..."])
        return history, ""

    if not message.strip():
        return history, ""

    try:
        # Rebuild the Llama-3.1 chat template from prior turns so the
        # model keeps conversation context.
        conversation = ""
        for user_msg, assistant_msg in history:
            if user_msg and assistant_msg:
                conversation += f"<|start_header_id|>user<|end_header_id|>\n{user_msg}<|eot_id|>"
                conversation += f"<|start_header_id|>assistant<|end_header_id|>\n{assistant_msg}<|eot_id|>"

        # Add current message
        prompt = f"<|begin_of_text|>{conversation}<|start_header_id|>user<|end_header_id|>\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"

        # Tokenize input
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)

        # Move to correct device
        if device == "cuda:0":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=temperature,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # BUG FIX: the old extraction searched the decoded text for
        # "<|start_header_id|>assistant<|end_header_id|>", but decoding with
        # skip_special_tokens=True strips those markers, so the search never
        # matched and the fallback sliced by len(prompt) — a character count
        # that still includes the stripped special tokens, mis-slicing the
        # reply. Decoding only the newly generated token span is exact and
        # needs no marker search at all.
        prompt_token_count = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        # Defensive cleanup in case a stop marker survives decoding.
        response = response.replace("<|eot_id|>", "").strip()

        if not response:
            response = "I apologize, but I couldn't generate a response. Please try rephrasing your message."

        # Add to history
        history.append([message, response])

    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        history.append([message, f"Error: {str(e)}"])

    return history, ""
142
 
143
def clear_history():
    """Return a fresh, empty chat history list."""
    return list()
146
+
147
# Kick off model loading without blocking app startup.
def load_model_background():
    """Thread target that simply delegates to load_model()."""
    load_model()

# Daemon thread: Gradio can come up immediately, and chat_response
# reports "still loading" until model_loaded flips to True.
model_thread = threading.Thread(
    target=load_model_background,
    daemon=True,
)
model_thread.start()
153
 
154
# Custom CSS for ChatGPT-like appearance: centered container, fixed-height
# scrollable chat pane, and WhatsApp-style user/bot message bubbles.
css = """
.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
}
#chatbot {
    height: 600px !important;
    overflow-y: auto !important;
}
.message {
    padding: 10px !important;
    margin: 5px 0 !important;
    border-radius: 10px !important;
}
.user {
    background-color: #dcf8c6 !important;
    margin-left: 20% !important;
}
.bot {
    background-color: #f1f1f1 !important;
    margin-right: 20% !important;
}
"""
178
 
179
# Build the Gradio UI: chat pane on the left, settings sidebar on the right.
with gr.Blocks(css=css, title="Llama Chat", theme=gr.themes.Soft()) as demo:

    # Header
    gr.Markdown(
        """
        # 🦙 Llama Chat
        ### Powered by Llama-3.1-8B-Instruct

        A clean, ChatGPT-style interface for conversing with the Llama model.
        """
    )

    # Main layout: chat column (wide) + settings column (narrow).
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                show_label=False,
                height=600,
                show_copy_button=True,
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    scale=4,
                    lines=1,
                    max_lines=5,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### ⚙️ Settings")

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Controls creativity (0.1=focused, 2.0=creative)",
            )

            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")

            gr.Markdown(
                """
                ### 💡 Tips
                - Use lower temperature (0.1-0.5) for factual responses
                - Use higher temperature (1.0-2.0) for creative tasks
                - Press Enter to send messages
                - The model maintains conversation context
                """
            )

    # Event handlers — thin wrapper so gradio's callback signature stays
    # decoupled from chat_response's.
    def respond(message, history, temp):
        return chat_response(message, history, temp)

    # Enter key and Send button behave identically; Clear resets both the
    # chat pane and the input textbox.
    msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
    send_btn.click(respond, [msg, chatbot, temperature], [chatbot, msg])
    clear_btn.click(lambda: (clear_history(), ""), outputs=[chatbot, msg])

    # Footer
    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        🚀 Built with Gradio 🦙 Powered by Llama-3.1-8B-Instruct
        </div>
        """
    )
258
 
259
  if __name__ == "__main__":
 
260
  demo.launch(
261
  server_name="0.0.0.0",
262
  server_port=7860,