Nicolás Larenas committed on
Commit
82afe26
·
verified ·
1 Parent(s): 791ba42

Update gradio_interface.py

Browse files
Files changed (1) hide show
  1. gradio_interface.py +178 -62
gradio_interface.py CHANGED
@@ -1,5 +1,7 @@
 
 
1
  import gradio as gr
2
- from ai_model import query_ai_model
3
  import asyncio
4
  import logging
5
  from config import (
@@ -9,7 +11,9 @@ from config import (
9
  DEFAULT_TOP_K,
10
  )
11
  import os
12
- from typing import List, Optional
 
 
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.ERROR)
@@ -19,89 +23,201 @@ def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
19
  return None
20
  return [sequence.strip() for sequence in stop_sequences.split(",")]
21
 
22
def submit_message(message, chat_history):
    """Record *message* as a user turn and clear the input box.

    Returns a pair: the emptied textbox value ("") and the updated
    history, a list of ``{'role': ..., 'content': ...}`` dicts. A falsy
    history (e.g. None on first use) is replaced with a fresh list;
    otherwise the caller's list is mutated in place.
    """
    history = chat_history or []
    history.append({'role': 'user', 'content': message})
    return "", history
26
-
27
async def gradio_chatbot(
    message, chat_history, google_key, temperature, max_output_tokens,
    stop_sequences, top_k, top_p
):
    """Generate the assistant's reply for the latest user turn.

    Resolves the API key (explicit argument first, then the
    GOOGLE_API_KEY environment variable), queries the model with the
    conversation so far, and appends either the reply or an error entry
    to *chat_history*, which is returned for the Chatbot component.

    NOTE(review): ``genai`` is not among this file's visible imports —
    confirm it is brought into scope elsewhere before this runs.
    """
    key = google_key or os.environ.get("GOOGLE_API_KEY")
    if not key:
        chat_history.append({'role': 'assistant', 'content': 'GOOGLE_API_KEY is not set. Please provide your API key.'})
        return chat_history

    genai.configure(api_key=key)

    try:
        # The last history entry is the pending user message; everything
        # before it is conversational context.
        assistant_reply = await query_ai_model(
            message=chat_history[-1]['content'],
            history=chat_history[:-1],
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=preprocess_stop_sequences(stop_sequences),
        )
        chat_history.append(assistant_reply)
    except Exception as e:
        # Surface the failure inside the chat instead of crashing the UI.
        logging.error("Error in gradio_chatbot", exc_info=True)
        chat_history.append({'role': 'assistant', 'content': f"An error occurred: {str(e)}"})

    return chat_history
 
 
 
60
 
61
def create_gradio_interface():
    """Build and return the Gradio Blocks chat UI for the Gemini model.

    Two chained events fire on message submit: submit_message() records
    the user turn and clears the textbox, then gradio_chatbot() queries
    the model and appends the reply to the Chatbot history.
    """
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>Gemini Chat Interface 💬</h1>")
        gr.HTML("<h2 align='center'>Interact with the Gemini AI model.</h2>")

        google_key_component = gr.Textbox(
            label="GOOGLE API KEY",
            value="",
            type="password",
            placeholder="Enter your Google API key",
            # Only ask for a key when the environment doesn't provide one.
            visible=os.environ.get("GOOGLE_API_KEY") is None
        )

        chatbot_component = gr.Chatbot(
            label='Gemini Chatbot',
            height=600,
            # submit_message/gradio_chatbot store dict entries, matching
            # the 'messages' format.
            type='messages'
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here and press Enter"
            )

        # Generation-parameter components. The original body wired these
        # into the event chain without ever defining them (a NameError at
        # build time); they are defined here using the shared config
        # defaults.
        temperature_component = gr.Slider(
            minimum=0, maximum=1.0, value=DEFAULT_TEMPERATURE, step=0.05,
            label="Temperature"
        )
        max_output_tokens_component = gr.Slider(
            minimum=1, maximum=2048, value=DEFAULT_MAX_OUTPUT_TOKENS, step=1,
            label="Max Output Tokens"
        )
        stop_sequences_component = gr.Textbox(
            label="Add stop sequence", value="", type="text",
            placeholder="STOP, END"
        )
        top_k_component = gr.Slider(
            minimum=1, maximum=40, value=DEFAULT_TOP_K, step=1,
            label="Top-K"
        )
        top_p_component = gr.Slider(
            minimum=0, maximum=1, value=DEFAULT_TOP_P, step=0.01,
            label="Top-P"
        )

        msg.submit(
            fn=submit_message,
            inputs=[msg, chatbot_component],
            outputs=[msg, chatbot_component]
        ).then(
            fn=gradio_chatbot,
            inputs=[
                msg,
                chatbot_component,
                google_key_component,
                temperature_component,
                max_output_tokens_component,
                stop_sequences_component,
                top_k_component,
                top_p_component
            ],
            outputs=chatbot_component
        )

    return demo
 
1
+ # gradio_interface.py
2
+
3
  import gradio as gr
4
+ from ai_model import query_ai_model, query_ai_model_vision, preprocess_chat_history
5
  import asyncio
6
  import logging
7
  from config import (
 
11
  DEFAULT_TOP_K,
12
  )
13
  import os
14
+ from typing import List, Dict, Union, Optional
15
+ from PIL import Image
16
+ import time
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.ERROR)
 
23
  return None
24
  return [sequence.strip() for sequence in stop_sequences.split(",")]
25
 
26
def user(text_prompt: str, chatbot: List[tuple]):
    """Append the submitted prompt as a new ``(user, pending)`` chat turn.

    Empty submissions leave the history untouched. Returns "" (to clear
    the textbox) together with the updated history list.
    """
    # Guard clause: nothing to record for an empty prompt.
    if not text_prompt:
        return "", chatbot
    chatbot.append((text_prompt, None))
    return "", chatbot
30
+
31
async def bot(
    google_key: str,
    files: Optional[List[gr.File]],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List[tuple]
):
    """Stream the model's reply for the newest chat turn.

    Async generator wired as the second step of the Gradio event chain:
    it yields progressively longer versions of *chatbot* to produce a
    typing effect. Routes to the vision model when image *files* are
    attached, otherwise to the text model.

    Raises:
        ValueError: when no API key is supplied and GOOGLE_API_KEY is
            unset.
    """
    # BUG FIX: `return <value>` is a SyntaxError inside an async
    # generator (this function contains `yield`), so the empty-history
    # guard must yield the unchanged history and bare-return.
    if not chatbot:
        yield chatbot
        return

    google_key = google_key if google_key else os.environ.get("GOOGLE_API_KEY")
    if not google_key:
        raise ValueError(
            "GOOGLE_API_KEY is not set. "
            "Please follow the instructions in the README to set it up."
        )

    genai.configure(api_key=google_key)

    # Parse the stop sequences once; the original rebuilt the list at
    # every call site.
    stop_sequences_list = preprocess_stop_sequences(stop_sequences=stop_sequences)
    # NOTE(review): a genai.types.GenerationConfig was constructed here
    # but never passed anywhere; removed as dead code — confirm the
    # query_ai_model* helpers apply these settings themselves.

    if files:
        # Vision path: pair the latest text prompt (if any) with the
        # uploaded images.
        text_prompt = [chatbot[-1][0]] if chatbot[-1][0] and isinstance(chatbot[-1][0], str) else []
        image_prompt = [Image.open(file.name).convert('RGB') for file in files]
        response = await query_ai_model_vision(
            text_prompt=text_prompt,
            image_prompt=image_prompt,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences_list
        )
    else:
        # Text path: convert the tuple history into the model's message
        # format before querying.
        messages = preprocess_chat_history(chatbot)
        response = await query_ai_model(
            message=chatbot[-1][0],
            history=messages[:-1],  # Exclude the last user message
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences_list
        )

    # Streaming effect: error replies are emitted whole; normal replies
    # in 10-character slices with a short pause between yields.
    chatbot[-1] = (chatbot[-1][0], "")
    if response['content'].startswith("An error occurred"):
        chatbot[-1] = (chatbot[-1][0], response['content'])
        yield chatbot
    else:
        for i in range(0, len(response['content']), 10):
            section = response['content'][i:i + 10]
            chatbot[-1] = (chatbot[-1][0], chatbot[-1][1] + section)
            await asyncio.sleep(0.01)
            yield chatbot
95
 
96
def create_gradio_interface():
    """Build and return the Gradio Blocks app for the Gemini playground.

    Components are instantiated up front and then ``.render()``-ed
    inside the layout so the same instances can be referenced by the
    event wiring below.
    """
    # Define all components first
    google_key_component = gr.Textbox(
        label="GOOGLE API KEY",
        value="",
        type="password",
        placeholder="Enter your Google API key",
        info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
        # Only shown when the environment doesn't already provide a key.
        visible=os.environ.get("GOOGLE_API_KEY") is None
    )
    # BUG FIX: the user()/bot() handlers build the history as
    # (user, assistant) tuples, so the Chatbot must use its default
    # tuple format; type='messages' expects role/content dicts and would
    # reject every update.
    chatbot_component = gr.Chatbot(
        label='Gemini Chatbot',
        height=600
    )
    text_prompt_component = gr.Textbox(
        placeholder="Hi there! [press Enter]", show_label=False
    )
    upload_button_component = gr.UploadButton(
        label="Upload Images", file_count="multiple", file_types=["image"]
    )
    temperature_component = gr.Slider(
        minimum=0,
        maximum=1.0,
        value=DEFAULT_TEMPERATURE,
        step=0.05,
        label="Temperature",
        info="Controls the randomness of the AI's response."
    )
    max_output_tokens_component = gr.Slider(
        minimum=1,
        maximum=2048,
        value=DEFAULT_MAX_OUTPUT_TOKENS,
        step=1,
        label="Max Output Tokens",
        info="Controls the length of the AI's response."
    )
    stop_sequences_component = gr.Textbox(
        label="Add stop sequence",
        value="",
        type="text",
        placeholder="STOP, END",
        info="Stops the AI response when any of these sequences are generated."
    )
    top_k_component = gr.Slider(
        minimum=1,
        maximum=40,
        value=DEFAULT_TOP_K,
        step=1,
        label="Top-K",
        info="Limits the next token selection to the K most probable tokens."
    )
    top_p_component = gr.Slider(
        minimum=0,
        maximum=1,
        value=DEFAULT_TOP_P,
        step=0.01,
        label="Top-P",
        info="Limits the next token selection to tokens with cumulative probability up to P."
    )

    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>Gemini Playground 💬</h1>")
        gr.HTML("<h2 align='center'>Play with Gemini Pro and Gemini Pro Vision API</h2>")
        gr.HTML("""
        <div style="text-align: center; display: flex; justify-content: center; align-items: center;">
            <a href="https://huggingface.co/spaces/SkalskiP/ChatGemini?duplicate=true">
                <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">
            </a>
            <span>Duplicate the Space and run securely with your
                <a href="https://makersuite.google.com/app/apikey">GOOGLE API KEY</a>.
            </span>
        </div>
        """)
        with gr.Column():
            google_key_component.render()
            chatbot_component.render()
            with gr.Row():
                text_prompt_component.render()
                upload_button_component.render()
            with gr.Accordion("Parameters", open=False):
                temperature_component.render()
                max_output_tokens_component.render()
                stop_sequences_component.render()
                with gr.Accordion("Advanced", open=False):
                    top_k_component.render()
                    top_p_component.render()

        # Connect the user input submission: record the turn, then stream
        # the model's reply.
        # NOTE(review): bot() is a streaming async generator; queue=False
        # may prevent Gradio from delivering intermediate yields — confirm
        # against the installed Gradio version.
        text_prompt_component.submit(
            fn=user,
            inputs=[text_prompt_component, chatbot_component],
            outputs=[text_prompt_component, chatbot_component],
            queue=False
        ).then(
            fn=bot,
            inputs=[
                google_key_component,
                upload_button_component,
                temperature_component,
                max_output_tokens_component,
                stop_sequences_component,
                top_k_component,
                top_p_component,
                chatbot_component
            ],
            outputs=chatbot_component,
            queue=False
        )

        # Handle image uploads: query the vision model directly, pairing
        # the images with the most recent chat turn.
        upload_button_component.upload(
            fn=bot,
            inputs=[
                google_key_component,
                upload_button_component,
                temperature_component,
                max_output_tokens_component,
                stop_sequences_component,
                top_k_component,
                top_p_component,
                chatbot_component
            ],
            outputs=chatbot_component,
            queue=False
        )

    return demo