david167 committed on
Commit
f093b76
·
1 Parent(s): a7cf970

Add JSON imports for structured response functionality

Browse files
Files changed (1) hide show
  1. gradio_app.py +0 -275
gradio_app.py DELETED
@@ -1,275 +0,0 @@
1
- import os
2
- import logging
3
- import threading
4
- from typing import List, Tuple
5
-
6
- import torch
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
- import gradio as gr
9
-
10
- # Configure logging
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- # Global variables for model
15
- model = None
16
- tokenizer = None
17
- device = None
18
- model_loaded = False
19
-
20
- def load_model():
21
- """Load the Llama model and tokenizer"""
22
- global model, tokenizer, device, model_loaded
23
-
24
- try:
25
- logger.info("Starting model loading...")
26
-
27
- # Check if CUDA is available and force to cuda:0
28
- if torch.cuda.is_available():
29
- torch.cuda.set_device(0)
30
- device = "cuda:0"
31
- else:
32
- device = "cpu"
33
-
34
- logger.info(f"Using device: {device}")
35
-
36
- if device == "cuda:0":
37
- logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
38
- logger.info(f"VRAM Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
39
-
40
- # Get HF token from environment
41
- hf_token = os.getenv("HF_TOKEN")
42
-
43
- logger.info("Loading Llama-3.1-8B-Instruct model...")
44
- model_name = "meta-llama/Llama-3.1-8B-Instruct"
45
-
46
- # Load tokenizer
47
- tokenizer = AutoTokenizer.from_pretrained(
48
- model_name,
49
- use_fast=True,
50
- trust_remote_code=True,
51
- token=hf_token
52
- )
53
-
54
- # Load model
55
- model = AutoModelForCausalLM.from_pretrained(
56
- model_name,
57
- torch_dtype=torch.float16 if device == "cuda:0" else torch.float32,
58
- device_map={"": 0}, # Force all parameters to GPU 0
59
- trust_remote_code=True,
60
- low_cpu_mem_usage=True,
61
- use_safetensors=True,
62
- token=hf_token
63
- )
64
-
65
- # Ensure model is on the correct device
66
- if device == "cuda:0":
67
- model = model.to(device)
68
-
69
- model_loaded = True
70
- logger.info("Model loaded successfully!")
71
-
72
- except Exception as e:
73
- logger.error(f"Error loading model: {str(e)}")
74
- model_loaded = False
75
-
76
- def chat_response(message: str, history: List[List[str]], temperature: float) -> Tuple[List[List[str]], str]:
77
- """Generate a response to the user's message"""
78
- global model, tokenizer, device, model_loaded
79
-
80
- if not model_loaded:
81
- history.append([message, "🔄 Model is still loading, please wait..."])
82
- return history, ""
83
-
84
- if not message.strip():
85
- return history, ""
86
-
87
- try:
88
- # Create Llama chat prompt
89
- conversation = ""
90
- for user_msg, assistant_msg in history:
91
- if user_msg and assistant_msg:
92
- conversation += f"<|start_header_id|>user<|end_header_id|>\n{user_msg}<|eot_id|>"
93
- conversation += f"<|start_header_id|>assistant<|end_header_id|>\n{assistant_msg}<|eot_id|>"
94
-
95
- # Add current message
96
- prompt = f"<|begin_of_text|>{conversation}<|start_header_id|>user<|end_header_id|>\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
97
-
98
- # Tokenize input
99
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
100
-
101
- # Move to correct device
102
- if device == "cuda:0":
103
- inputs = {k: v.to(device) for k, v in inputs.items()}
104
-
105
- # Generate response
106
- with torch.no_grad():
107
- outputs = model.generate(
108
- **inputs,
109
- max_new_tokens=2048,
110
- temperature=temperature,
111
- top_p=0.95,
112
- do_sample=True,
113
- pad_token_id=tokenizer.eos_token_id,
114
- eos_token_id=tokenizer.eos_token_id
115
- )
116
-
117
- # Decode response
118
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # TEMPORARY: Show complete raw output to debug clipping
119
- response = f"""=== RAW MODEL OUTPUT ===
120
- {generated_text}
121
- === END RAW OUTPUT ===
122
-
123
- === PROMPT USED ===
124
- {prompt}
125
- === END PROMPT ==="""
126
-
127
- # Add to history
128
- # Add to history
129
- history.append([message, response])
130
-
131
- except Exception as e:
132
- logger.error(f"Error generating response: {str(e)}")
133
- history.append([message, f"❌ Error: {str(e)}"])
134
-
135
- return history, ""
136
-
137
- def clear_history():
138
- """Clear the chat history"""
139
- return []
140
-
141
- # Load model in background thread
142
- def load_model_background():
143
- load_model()
144
-
145
- model_thread = threading.Thread(target=load_model_background, daemon=True)
146
- model_thread.start()
147
-
148
- # Custom CSS for ChatGPT-like appearance
149
- css = """
150
- .gradio-container {
151
- max-width: 100%; width: 100% !important;
152
- margin: 0; padding: 20px !important;
153
- }
154
- #chatbot {
155
- height: 70vh; min-height: 600px !important;
156
- overflow-y: auto !important;
157
- }
158
- .message {
159
- padding: 12px 16px !important;
160
- margin: 5px 0 !important;
161
- border-radius: 12px; max-width: 85%; word-wrap: break-word !important;
162
- }
163
- .user {
164
- background-color: #dcf8c6 !important;
165
- margin-left: auto; margin-right: 0 !important;
166
- }
167
- .bot {
168
- background-color: #f1f1f1 !important;
169
- margin-left: 0; margin-right: auto !important;
170
- }/* Responsive design for larger screens */
171
- @media (min-width: 1400px) {
172
- .gradio-container {
173
- padding: 40px !important;
174
- }
175
- #chatbot {
176
- height: 75vh !important;
177
- }
178
- }
179
- @media (min-width: 1800px) {
180
- .gradio-container {
181
- padding: 60px !important;
182
- }
183
- #chatbot {
184
- height: 80vh !important;
185
- }
186
- }}
187
- """
188
-
189
- # Create Gradio interface
190
- with gr.Blocks(
191
- css=css,
192
- title="Llama Chat",
193
- theme=gr.themes.Soft()
194
- ) as demo:
195
-
196
- # Header
197
- gr.Markdown(
198
- """
199
- # 🦙 Llama Chat
200
- ### Powered by Llama-3.1-8B-Instruct
201
-
202
- A clean, ChatGPT-style interface for conversing with the Llama model.
203
- """
204
- )
205
-
206
- # Chat interface
207
- with gr.Row():
208
- with gr.Column(scale=4):
209
- chatbot = gr.Chatbot(
210
- label="Chat",
211
- show_label=False,
212
- height=600,
213
- show_copy_button=True
214
- )
215
-
216
- with gr.Row():
217
- msg = gr.Textbox(
218
- placeholder="Type your message here...",
219
- show_label=False,
220
- scale=4,
221
- lines=1,
222
- max_lines=5
223
- )
224
- send_btn = gr.Button("Send", variant="primary", scale=1)
225
-
226
- with gr.Column(scale=1, min_width=250):
227
- gr.Markdown("### ⚙️ Settings")
228
-
229
- temperature = gr.Slider(
230
- minimum=0.1,
231
- maximum=2.0,
232
- value=0.8,
233
- step=0.1,
234
- label="Temperature",
235
- info="Controls creativity (0.1=focused, 2.0=creative)"
236
- )
237
-
238
- clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
239
-
240
- gr.Markdown(
241
- """
242
- ### 💡 Tips
243
- - Use lower temperature (0.1-0.5) for factual responses
244
- - Use higher temperature (1.0-2.0) for creative tasks
245
- - Press Enter to send messages
246
- - The model maintains conversation context
247
- """
248
- )
249
-
250
- # Event handlers
251
- def respond(message, history, temp):
252
- return chat_response(message, history, temp)
253
-
254
- # Connect events
255
- msg.submit(respond, [msg, chatbot, temperature], [chatbot, msg])
256
- send_btn.click(respond, [msg, chatbot, temperature], [chatbot, msg])
257
- clear_btn.click(lambda: (clear_history(), ""), outputs=[chatbot, msg])
258
-
259
- # Footer
260
- gr.Markdown(
261
- """
262
- ---
263
- <div style="text-align: center; color: #666; font-size: 0.9em;">
264
- 🚀 Built with Gradio • 🦙 Powered by Llama-3.1-8B-Instruct
265
- </div>
266
- """
267
- )
268
-
269
- if __name__ == "__main__":
270
- demo.launch(
271
- server_name="0.0.0.0",
272
- server_port=7860,
273
- share=False,
274
- show_error=True
275
- )