GeminiFan207 commited on
Commit
1fa0784
·
verified ·
1 Parent(s): 5fd1ba1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -123
app.py CHANGED
@@ -1,158 +1,337 @@
1
  import gradio as gr
2
  import requests
3
  import json
 
 
 
 
 
 
 
 
4
 
5
- # Default model
6
- MODEL = "ZeppelinCorp/Charm_15"
 
7
 
8
- # Maximum number of messages to keep in chat history
9
- MAX_HISTORY_MESSAGES = 5
 
10
 
11
- # Generate response using Hugging Face Inference API
12
- def generate_response(prompt, model_name, temperature, top_k, top_p, chat_history):
13
- global MODEL
14
-
15
- if model_name != MODEL:
16
- MODEL = model_name
17
-
18
- # Truncate chat history to the last MAX_HISTORY_MESSAGES messages
19
- if len(chat_history) > MAX_HISTORY_MESSAGES:
20
- chat_history = chat_history[-MAX_HISTORY_MESSAGES:]
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Combine chat history into a single prompt
23
- full_prompt = "\n".join([f"User: {msg['content']}" if msg['role'] == 'user' else f"Bot: {msg['content']}" for msg in chat_history])
24
- full_prompt += f"\nUser: {prompt}\nBot:"
25
 
26
- API_URL = f"https://api-inference.huggingface.co/models/{MODEL}"
27
- payload = {
28
- "inputs": full_prompt,
29
- "parameters": {
30
- "max_length": 50,
31
- "num_return_sequences": 1,
32
- "temperature": temperature,
33
- "top_k": top_k,
34
- "top_p": top_p
35
- }
36
  }
37
 
38
  try:
39
- response = requests.post(API_URL, json=payload)
40
- if response.status_code == 200:
41
- generated_text = response.json()[0]['generated_text']
42
- chat_history.append({"role": "user", "content": prompt})
43
- chat_history.append({"role": "assistant", "content": generated_text})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  else:
45
- chat_history.append({"role": "assistant", "content": f"Error: {response.status_code} - {response.text}"})
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
- chat_history.append({"role": "assistant", "content": f"Error generating response: {e}"})
48
-
49
- return chat_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # File upload functions
52
- def upload_image(image):
53
- return f"**Image Uploaded:** {image.name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- def upload_audio(audio):
56
- return f"**Audio Uploaded:** {audio.name}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Reasoning analysis
59
- def analyze_reasoning(input_text):
60
- return f"**Reasoning Analysis:** Analyzing input: {input_text}"
 
 
 
 
 
 
 
 
 
 
61
 
62
- # Switch model function
63
- def switch_model(new_model):
64
- global MODEL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- MODEL = new_model
67
- return f"**Model Switched:** {new_model}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Clear chat function
70
- def clear_chat():
71
- return []
72
 
73
- # Custom CSS for better UI
74
- custom_css = """
75
- .gradio-container {
76
- font-family: Arial, sans-serif;
77
- }
78
- h1 {
79
- color: #4CAF50;
80
- }
81
- .chat-box {
82
- height: 400px;
83
- overflow-y: auto;
84
- border: 1px solid #ccc;
85
- padding: 10px;
86
- border-radius: 5px;
87
- background-color: #f9f9f9;
88
- }
89
- .user-message {
90
- background-color: #e3f2fd;
91
- padding: 8px;
92
- border-radius: 5px;
93
- margin-bottom: 5px;
94
- max-width: 70%;
95
- margin-left: auto;
96
- }
97
- .bot-message {
98
- background-color: #f5f5f5;
99
- padding: 8px;
100
- border-radius: 5px;
101
- margin-bottom: 5px;
102
- max-width: 70%;
103
- margin-right: auto;
104
- }
105
- """
106
 
107
- # Interface setup
108
- with gr.Blocks(css=custom_css) as demo:
109
- gr.Markdown("# Chatbot UI")
110
 
111
  with gr.Tab("Chat"):
112
  # Chat history display
113
- chat_history = gr.Chatbot(label="Chat", elem_classes="chat-box", type="messages")
114
 
115
  with gr.Row():
116
  with gr.Column():
117
- prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=2)
118
- model_dropdown = gr.Dropdown(
119
- ["ZeppelinCorp/Charm_15", "ZeppelinCorp/Smartbloom_1.1", "gpt2", "EleutherAI/gpt-neo-125M"],
120
- label="Select Model",
121
- value=MODEL
122
- )
123
- temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature")
124
- top_k = gr.Slider(minimum=1, maximum=100, value=50, label="Top-K")
125
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-P")
126
  generate_button = gr.Button("Send")
127
- clear_button = gr.Button("Clear Chat")
128
 
129
  # Chat interaction
130
  generate_button.click(
131
- generate_response,
132
- inputs=[prompt_input, model_dropdown, temperature, top_k, top_p, chat_history],
133
- outputs=chat_history
134
  )
135
- clear_button.click(clear_chat, inputs=None, outputs=chat_history)
136
-
137
- with gr.Tab("Upload Image"):
138
- image_input = gr.File(label="Upload Image", file_types=["image"])
139
- image_output = gr.Textbox(label="Upload Status", interactive=False)
140
- image_input.change(upload_image, inputs=image_input, outputs=image_output)
141
-
142
- with gr.Tab("Upload Audio"):
143
- audio_input = gr.File(label="Upload Audio", file_types=["audio"])
144
- audio_output = gr.Textbox(label="Upload Status", interactive=False)
145
- audio_input.change(upload_audio, inputs=audio_input, outputs=audio_output)
146
-
147
- with gr.Tab("Reasoning Analysis"):
148
- reasoning_input = gr.Textbox(label="Enter text for reasoning analysis", placeholder="Type your text here...")
149
- reasoning_output = gr.Textbox(label="Analysis Result", interactive=False)
150
- reasoning_input.change(analyze_reasoning, inputs=reasoning_input, outputs=reasoning_output)
151
-
152
- with gr.Tab("Switch Model"):
153
- model_switch_input = gr.Textbox(label="Enter new model name", placeholder="Type the model name...")
154
- model_switch_output = gr.Textbox(label="Switch Status", interactive=False)
155
- model_switch_input.change(switch_model, inputs=model_switch_input, outputs=model_switch_output)
156
 
157
  # Launch the interface
158
  demo.launch()
 
1
  import gradio as gr
2
  import requests
3
  import json
4
+ import os
5
+ import base64
6
+ from PIL import Image
7
+ import soundfile as sf
8
+ import mimetypes
9
+ import logging
10
+ from io import BytesIO
11
+ import tempfile
12
 
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
+ logger = logging.getLogger(__name__)
16
 
17
+ # Hugging Face API configuration
18
+ HF_API_URL = os.getenv("HF_API_URL") # URL of your Hugging Face model endpoint
19
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN") # Hugging Face API token
20
 
21
+ # Default parameter values
22
+ default_max_tokens = 4096
23
+ default_temperature = 0.7
24
+ default_top_p = 0.9
25
+ default_presence_penalty = 0.0
26
+ default_frequency_penalty = 0.0
27
+
28
+ # Initialize MIME types
29
+ mimetypes.init()
30
+
31
+ def call_hf_endpoint(payload, api_url, api_token, params=None):
32
+ """Call Hugging Face Inference API with the given payload."""
33
+ # Set parameters from the UI inputs or use defaults
34
+ if params is None:
35
+ params = {
36
+ "max_tokens": default_max_tokens,
37
+ "temperature": default_temperature,
38
+ "top_p": default_top_p,
39
+ "presence_penalty": default_presence_penalty,
40
+ "frequency_penalty": default_frequency_penalty
41
+ }
42
 
43
+ # Add parameters to the payload
44
+ if "parameters" not in payload:
45
+ payload["parameters"] = params
46
 
47
+ # Set up headers
48
+ headers = {
49
+ "Authorization": f"Bearer {api_token}",
50
+ "Content-Type": "application/json"
 
 
 
 
 
 
51
  }
52
 
53
  try:
54
+ logger.info(f"Sending request to {api_url}")
55
+ logger.info(f"Using parameters: {params}")
56
+ response = requests.post(api_url, headers=headers, json=payload)
57
+ response.raise_for_status()
58
+ result = response.json()
59
+ logger.info("Received response successfully")
60
+ return result
61
+ except requests.exceptions.HTTPError as error:
62
+ logger.error(f"Request failed with status code: {error.response.status_code}")
63
+ logger.error(f"Error message: {error.response.text}")
64
+ return {"error": error.response.text}
65
+
66
+ def improved_fetch_audio_from_url(url):
67
+ """Fetch audio data from URL and convert to base64."""
68
+ try:
69
+ logger.info(f"Fetching audio from URL: {url}")
70
+ response = requests.get(url, timeout=30)
71
+ response.raise_for_status()
72
+
73
+ # Determine MIME type based on URL
74
+ file_extension = os.path.splitext(url)[1].lower()
75
+ mime_type = None
76
+
77
+ if file_extension == '.wav':
78
+ mime_type = "audio/wav"
79
+ elif file_extension == '.mp3':
80
+ mime_type = "audio/mpeg"
81
+ elif file_extension == '.flac':
82
+ mime_type = "audio/flac"
83
+ elif file_extension in ['.m4a', '.aac']:
84
+ mime_type = "audio/aac"
85
+ elif file_extension == '.ogg':
86
+ mime_type = "audio/ogg"
87
  else:
88
+ # Try to detect the MIME type from headers
89
+ content_type = response.headers.get('Content-Type', '')
90
+ if content_type.startswith('audio/'):
91
+ mime_type = content_type
92
+ else:
93
+ mime_type = "audio/wav" # Default to WAV
94
+
95
+ logger.info(f"Detected MIME type: {mime_type}")
96
+
97
+ # Convert to base64
98
+ base64_audio = base64.b64encode(response.content).decode('utf-8')
99
+ logger.info(f"Successfully encoded audio to base64, length: {len(base64_audio)}")
100
+
101
+ return mime_type, base64_audio
102
  except Exception as e:
103
+ logger.error(f"Error fetching audio from URL: {e}", exc_info=True)
104
+ return None, None
105
+
106
+ def fetch_image_from_url(url):
107
+ """Fetch image data from URL and convert to base64."""
108
+ try:
109
+ logger.info(f"Fetching image from URL: {url}")
110
+ response = requests.get(url)
111
+ response.raise_for_status()
112
+
113
+ # Determine MIME type based on URL
114
+ file_extension = os.path.splitext(url)[1].lower()
115
+ if file_extension in ['.jpg', '.jpeg']:
116
+ mime_type = "image/jpeg"
117
+ elif file_extension == '.png':
118
+ mime_type = "image/png"
119
+ elif file_extension == '.gif':
120
+ mime_type = "image/gif"
121
+ elif file_extension in ['.bmp', '.tiff', '.webp']:
122
+ mime_type = f"image/{file_extension[1:]}"
123
+ else:
124
+ mime_type = "image/jpeg" # Default to JPEG
125
+
126
+ # Convert to base64
127
+ base64_image = base64.b64encode(response.content).decode('utf-8')
128
+
129
+ logger.info(f"Successfully fetched and encoded image, mime type: {mime_type}")
130
+ return mime_type, base64_image
131
+ except Exception as e:
132
+ logger.error(f"Error fetching image from URL: {e}")
133
+ return None, None
134
 
135
+ def encode_base64_from_file(file_path):
136
+ """Encode file content to base64 string and determine MIME type."""
137
+ file_extension = os.path.splitext(file_path)[1].lower()
138
+
139
+ # Map file extensions to MIME types
140
+ if file_extension in ['.jpg', '.jpeg']:
141
+ mime_type = "image/jpeg"
142
+ elif file_extension == '.png':
143
+ mime_type = "image/png"
144
+ elif file_extension == '.gif':
145
+ mime_type = "image/gif"
146
+ elif file_extension in ['.bmp', '.tiff', '.webp']:
147
+ mime_type = f"image/{file_extension[1:]}"
148
+ elif file_extension == '.flac':
149
+ mime_type = "audio/flac"
150
+ elif file_extension == '.wav':
151
+ mime_type = "audio/wav"
152
+ elif file_extension == '.mp3':
153
+ mime_type = "audio/mpeg"
154
+ elif file_extension in ['.m4a', '.aac']:
155
+ mime_type = "audio/aac"
156
+ elif file_extension == '.ogg':
157
+ mime_type = "audio/ogg"
158
+ else:
159
+ mime_type = "application/octet-stream"
160
+
161
+ # Read and encode file content
162
+ with open(file_path, "rb") as file:
163
+ encoded_string = base64.b64encode(file.read()).decode('utf-8')
164
+
165
+ return encoded_string, mime_type
166
 
167
+ def process_message(history, message, conversation_state):
168
+ """Process user message and update both history and internal state."""
169
+ # Extract text and files
170
+ text_content = message["text"] if message["text"] else ""
171
+
172
+ image_files = []
173
+ audio_files = []
174
+
175
+ # Create content array for internal state
176
+ content_items = []
177
+
178
+ # Add text if available
179
+ if text_content:
180
+ content_items.append({"type": "text", "text": text_content})
181
+
182
+ # Process and immediately convert files to base64
183
+ if message["files"] and len(message["files"]) > 0:
184
+ for file_path in message["files"]:
185
+ file_extension = os.path.splitext(file_path)[1].lower()
186
+ file_name = os.path.basename(file_path)
187
+
188
+ # Convert the file to base64 immediately
189
+ base64_content, mime_type = encode_base64_from_file(file_path)
190
+
191
+ # Add to content items for the API
192
+ if mime_type.startswith("image/"):
193
+ content_items.append({
194
+ "type": "image_url",
195
+ "image_url": {
196
+ "url": f"data:{mime_type};base64,{base64_content}"
197
+ }
198
+ })
199
+ image_files.append(file_path)
200
+ elif mime_type.startswith("audio/"):
201
+ content_items.append({
202
+ "type": "audio_url",
203
+ "audio_url": {
204
+ "url": f"data:{mime_type};base64,{base64_content}"
205
+ }
206
+ })
207
+ audio_files.append(file_path)
208
+
209
+ # Only proceed if we have content
210
+ if content_items:
211
+ # Add to Gradio chatbot history (for display)
212
+ history.append({"role": "user", "content": text_content})
213
 
214
+ # Add file messages if present
215
+ for file_path in image_files + audio_files:
216
+ history.append({"role": "user", "content": {"path": file_path}})
217
+
218
+ logger.info(f"Updated history with user message. Current conversation has {len(image_files)} images and {len(audio_files)} audio files")
219
+
220
+ # Add to internal conversation state (with base64 data)
221
+ conversation_state.append({
222
+ "role": "user",
223
+ "content": content_items
224
+ })
225
+
226
+ return history, gr.MultimodalTextbox(value=None, interactive=False), conversation_state
227
 
228
+ def bot_response(history, conversation_state):
229
+ """Generate bot response based on conversation state."""
230
+ if not conversation_state:
231
+ return history, conversation_state
232
+
233
+ # Create the payload
234
+ payload = {
235
+ "inputs": conversation_state
236
+ }
237
+
238
+ # Log the payload for debugging (without base64 data)
239
+ debug_payload = json.loads(json.dumps(payload))
240
+ for item in debug_payload["inputs"]:
241
+ if "content" in item and isinstance(item["content"], list):
242
+ for content_item in item["content"]:
243
+ if "image_url" in content_item:
244
+ parts = content_item["image_url"]["url"].split(",")
245
+ if len(parts) > 1:
246
+ content_item["image_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
247
+ if "audio_url" in content_item:
248
+ parts = content_item["audio_url"]["url"].split(",")
249
+ if len(parts) > 1:
250
+ content_item["audio_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
251
+
252
+ logger.info(f"Sending payload: {json.dumps(debug_payload, indent=2)}")
253
+
254
+ # Call Hugging Face Inference API
255
+ response = call_hf_endpoint(payload, HF_API_URL, HF_API_TOKEN)
256
 
257
+ # Extract text response from the Hugging Face API response
258
+ try:
259
+ if isinstance(response, dict):
260
+ if "generated_text" in response:
261
+ result = response["generated_text"]
262
+ elif "error" in response:
263
+ result = f"Error: {response['error']}"
264
+ else:
265
+ result = f"Received response: {json.dumps(response)}"
266
+ else:
267
+ result = str(response)
268
+ except Exception as e:
269
+ result = f"Error processing response: {str(e)}"
270
+
271
+ # Add bot response to history
272
+ history.append({"role": "assistant", "content": result})
273
+
274
+ # Add to conversation state
275
+ conversation_state.append({
276
+ "role": "assistant",
277
+ "content": [{"type": "text", "text": result}]
278
+ })
279
+
280
+ return history, conversation_state
281
 
282
+ def enable_input():
283
+ """Re-enable the input box after bot responds."""
284
+ return gr.MultimodalTextbox(interactive=True)
285
 
286
+ def update_debug(conversation_state):
287
+ """Update debug output with the last payload that would be sent."""
288
+ if not conversation_state:
289
+ return {}
290
+
291
+ # Create a payload from the conversation
292
+ payload = {
293
+ "inputs": conversation_state
294
+ }
295
+
296
+ # Remove base64 data to avoid cluttering the UI
297
+ sanitized_payload = json.loads(json.dumps(payload))
298
+ for item in sanitized_payload["inputs"]:
299
+ if "content" in item and isinstance(item["content"], list):
300
+ for content_item in item["content"]:
301
+ if "image_url" in content_item:
302
+ parts = content_item["image_url"]["url"].split(",")
303
+ if len(parts) > 1:
304
+ content_item["image_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
305
+ if "audio_url" in content_item:
306
+ parts = content_item["audio_url"]["url"].split(",")
307
+ if len(parts) > 1:
308
+ content_item["audio_url"]["url"] = parts[0] + ",[BASE64_DATA_REMOVED]"
309
+
310
+ return sanitized_payload
 
 
 
 
 
 
 
 
311
 
312
+ # Gradio interface setup
313
+ with gr.Blocks() as demo:
314
+ gr.Markdown("# Chatbot with Hugging Face Models")
315
 
316
  with gr.Tab("Chat"):
317
  # Chat history display
318
+ chat_history = gr.Chatbot(label="Chat")
319
 
320
  with gr.Row():
321
  with gr.Column():
322
+ prompt_input = gr.MultimodalTextbox(label="Enter your prompt", placeholder="Type your message here...", lines=2)
 
 
 
 
 
 
 
 
323
  generate_button = gr.Button("Send")
 
324
 
325
  # Chat interaction
326
  generate_button.click(
327
+ process_message,
328
+ inputs=[chat_history, prompt_input],
329
+ outputs=[chat_history, prompt_input]
330
  )
331
+
332
+ # Debug output
333
+ debug_output = gr.JSON(label="Debug Output")
334
+ demo.load(update_debug, inputs=[chat_history], outputs=debug_output, every=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  # Launch the interface
337
  demo.launch()