Nicolás Larenas committed on
Commit
3ffcf44
·
verified ·
1 Parent(s): dc0560b

Update gradio_interface.py

Browse files
Files changed (1) hide show
  1. gradio_interface.py +101 -193
gradio_interface.py CHANGED
@@ -2,224 +2,132 @@ import gradio as gr
2
  from ai_model import query_ai_model
3
  import asyncio
4
  from config import (
5
- SYSTEM_MESSAGE,
6
- DEFAULT_MAX_NEW_TOKENS,
7
  DEFAULT_TEMPERATURE,
8
  DEFAULT_TOP_P,
9
  DEFAULT_TOP_K,
10
  )
11
  import os
12
- from typing import List, Tuple, Optional, Union
13
- from PIL import Image
14
-
15
- IMAGE_CACHE_DIRECTORY = "/tmp"
16
- IMAGE_WIDTH = 512
17
- CHAT_HISTORY = List[Tuple[Optional[Union[Tuple[str], str]], Optional[str]]]
18
 
19
  def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
20
  if not stop_sequences:
21
  return None
22
  return [sequence.strip() for sequence in stop_sequences.split(",")]
23
 
24
- def preprocess_image(image: Image.Image) -> Image.Image:
25
- image_height = int(image.height * IMAGE_WIDTH / image.width)
26
- return image.resize((IMAGE_WIDTH, image_height))
27
-
28
- def cache_pil_image(image: Image.Image) -> str:
29
- import uuid
30
- image_filename = f"{uuid.uuid4()}.jpeg"
31
- os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
32
- image_path = os.path.join(IMAGE_CACHE_DIRECTORY, image_filename)
33
- image.save(image_path, "JPEG")
34
- return image_path
35
-
36
- def upload(files: Optional[List[gr.File]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
37
- if files is not None:
38
- for file in files:
39
- image = Image.open(file.name).convert('RGB')
40
- image = preprocess_image(image)
41
- image_path = cache_pil_image(image)
42
- chatbot.append(((image_path,), None))
43
- return chatbot
44
-
45
- def user(text_prompt: str, chatbot: CHAT_HISTORY):
46
- if text_prompt:
47
- chatbot.append((text_prompt, None))
48
- return "", chatbot
49
-
50
- # Function to handle chatbot conversation
51
- def gradio_chatbot(
52
- google_key: str,
53
- files: Optional[List[gr.File]],
54
- temperature: float,
55
- max_new_tokens: int,
56
- stop_sequences: str,
57
- top_k: int,
58
- top_p: float,
59
- chatbot: CHAT_HISTORY
60
- ):
61
- if len(chatbot) == 0:
62
- return chatbot
63
 
 
 
 
 
 
64
  google_key = google_key if google_key else os.environ.get("GOOGLE_API_KEY")
65
  if not google_key:
66
- raise ValueError(
67
- "GOOGLE_API_KEY is not set. Please provide your API key."
68
- )
69
 
70
  genai.configure(api_key=google_key)
 
71
  stop_sequences_list = preprocess_stop_sequences(stop_sequences)
72
 
73
- # Prepare messages for AI model
74
- messages = []
75
- for user_message, model_message in chatbot:
76
- if isinstance(user_message, tuple):
77
- # Handle image inputs if necessary
78
- pass
79
- elif user_message is not None:
80
- messages.append({'role': 'user', 'content': user_message})
81
- if model_message is not None:
82
- messages.append({'role': 'assistant', 'content': model_message})
83
-
84
- # Handle vision model if files are uploaded
85
- use_vision = bool(files)
86
- image_files = [file.name for file in files] if files else None
87
-
88
- # Run AI model
89
- loop = asyncio.get_event_loop()
90
- response_text = loop.run_until_complete(
91
- query_ai_model(
92
- message=messages[-1]['content'],
93
- history=messages[:-1],
94
- system_message=SYSTEM_MESSAGE,
95
- max_new_tokens=max_new_tokens,
96
- temperature=temperature,
97
- top_p=top_p,
98
- top_k=top_k,
99
- stop_sequences=stop_sequences_list,
100
- )
101
  )
102
- chatbot[-1] = (chatbot[-1][0], response_text)
103
- return chatbot
104
 
105
- def create_gradio_interface():
106
- google_key_component = gr.Textbox(
107
- label="GOOGLE API KEY",
108
- value="",
109
- type="password",
110
- placeholder="...",
111
- info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
112
- visible=os.environ.get("GOOGLE_API_KEY") is None
113
- )
114
- chatbot_component = gr.Chatbot(
115
- label='Gemini',
116
- avatar_images=None,
117
- height=400,
118
- type='messages' # Use the 'messages' type for better compatibility
119
- )
120
- text_prompt_component = gr.Textbox(
121
- placeholder="Hi there! [press Enter]", show_label=False
122
- )
123
- upload_button_component = gr.UploadButton(
124
- label="Upload Images", file_count="multiple", file_types=["image"]
125
- )
126
- temperature_component = gr.Slider(
127
- minimum=0,
128
- maximum=1.0,
129
- value=DEFAULT_TEMPERATURE,
130
- step=0.05,
131
- label="Temperature",
132
- info="Controls the randomness of the AI's response."
133
- )
134
- max_output_tokens_component = gr.Slider(
135
- minimum=1,
136
- maximum=2048,
137
- value=DEFAULT_MAX_NEW_TOKENS,
138
- step=1,
139
- label="Max new tokens",
140
- info="Controls the length of the AI's response."
141
- )
142
- stop_sequences_component = gr.Textbox(
143
- label="Add stop sequence",
144
- value="",
145
- type="text",
146
- placeholder="STOP, END",
147
- info="Stops the AI response when any of these sequences are generated."
148
- )
149
- top_k_component = gr.Slider(
150
- minimum=1,
151
- maximum=40,
152
- value=DEFAULT_TOP_K,
153
- step=1,
154
- label="Top-K",
155
- info="Limits the next token selection to the K most probable tokens."
156
- )
157
- top_p_component = gr.Slider(
158
- minimum=0,
159
- maximum=1,
160
- value=DEFAULT_TOP_P,
161
- step=0.01,
162
- label="Top-P",
163
- info="Limits the next token selection to tokens with cumulative probability up to P."
164
- )
165
 
166
- user_inputs = [
167
- text_prompt_component,
168
- chatbot_component
169
- ]
170
-
171
- bot_inputs = [
172
- google_key_component,
173
- upload_button_component,
174
- temperature_component,
175
- max_output_tokens_component,
176
- stop_sequences_component,
177
- top_k_component,
178
- top_p_component,
179
- chatbot_component
180
- ]
181
 
 
182
  with gr.Blocks() as demo:
183
- gr.HTML("<h1 align='center'>Gemini Playground 💬</h1>")
184
- gr.HTML("<h2 align='center'>Play with Gemini Pro and Gemini Pro Vision API</h2>")
185
- gr.HTML("""
186
- <div style="text-align: center; display: flex; justify-content: center; align-items: center;">
187
- <a href="https://huggingface.co/spaces/SkalskiP/ChatGemini?duplicate=true">
188
- <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">
189
- </a>
190
- <span>Duplicate the Space and run securely with your
191
- <a href="https://makersuite.google.com/app/apikey">GOOGLE API KEY</a>.
192
- </span>
193
- </div>
194
- """)
195
- with gr.Column():
196
- google_key_component.render()
197
- chatbot_component.render()
198
- with gr.Row():
199
- text_prompt_component.render()
200
- upload_button_component.render()
201
- with gr.Accordion("Parameters", open=False):
202
- temperature_component.render()
203
- max_output_tokens_component.render()
204
- stop_sequences_component.render()
205
- with gr.Accordion("Advanced", open=False):
206
- top_k_component.render()
207
- top_p_component.render()
208
-
209
- text_prompt_component.submit(
210
- fn=user,
211
- inputs=user_inputs,
212
- outputs=[text_prompt_component, chatbot_component],
213
- queue=False
214
- ).then(
215
- fn=gradio_chatbot, inputs=bot_inputs, outputs=[chatbot_component],
216
  )
217
 
218
- upload_button_component.upload(
219
- fn=upload,
220
- inputs=[upload_button_component, chatbot_component],
221
- outputs=[chatbot_component],
222
- queue=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  )
224
 
225
  return demo
 
2
  from ai_model import query_ai_model
3
  import asyncio
4
  from config import (
5
+ DEFAULT_MAX_OUTPUT_TOKENS,
 
6
  DEFAULT_TEMPERATURE,
7
  DEFAULT_TOP_P,
8
  DEFAULT_TOP_K,
9
  )
10
  import os
11
+ from typing import List, Optional
 
 
 
 
 
12
 
13
def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Split a comma-separated string into a list of trimmed stop sequences.

    Returns None for an empty input so callers can pass the result straight
    through as "no stop sequences configured".
    """
    if not stop_sequences:
        return None
    parts = stop_sequences.split(",")
    return [part.strip() for part in parts]
17
 
18
def submit_message(message, chat_history):
    """Append the user's message to the chat history (messages format).

    Returns ("", updated_history) so the textbox is cleared after submit.
    Blank or whitespace-only messages are ignored instead of being appended
    as empty user turns (matching the guard the previous `user` handler had).
    """
    chat_history = chat_history or []
    if message and message.strip():
        chat_history.append({'role': 'user', 'content': message})
    return "", chat_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
async def gradio_chatbot(
    message, chat_history, google_key, temperature, max_output_tokens,
    stop_sequences, top_k, top_p
):
    """Query the AI model with the accumulated history and append its reply.

    `message` is intentionally unused: submit_message has already appended it
    to `chat_history` before this callback runs in the `.then()` chain.
    Returns the updated history (messages format) for the Chatbot component.
    """
    chat_history = chat_history or []
    if not chat_history:
        # Nothing to answer yet (e.g. a blank submission was ignored) — no-op
        # instead of crashing on chat_history[-1] below.
        return chat_history

    google_key = google_key if google_key else os.environ.get("GOOGLE_API_KEY")
    if not google_key:
        # Surface the problem in the chat itself rather than raising.
        return chat_history + [{'role': 'assistant', 'content': 'GOOGLE_API_KEY is not set. Please provide your API key.'}]

    # Imported here because the module never imports `genai` at the top level;
    # without this the original code raised NameError at runtime.
    import google.generativeai as genai
    genai.configure(api_key=google_key)

    stop_sequences_list = preprocess_stop_sequences(stop_sequences)

    # Query the AI model with the last user message and the prior turns.
    assistant_reply = await query_ai_model(
        message=chat_history[-1]['content'],  # last user message
        history=chat_history[:-1],            # everything before it
        max_output_tokens=max_output_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        stop_sequences=stop_sequences_list,
    )

    # NOTE(review): assumes query_ai_model returns a messages-format dict
    # ({'role': 'assistant', 'content': ...}) — confirm against ai_model.py.
    chat_history.append(assistant_reply)

    return chat_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
def create_gradio_interface():
    """Build and return the Gradio Blocks app for the Gemini chat UI.

    Wires a two-step submit flow: first echo the user's message into the
    history (submit_message), then query the model with the updated history
    (gradio_chatbot). Removes the unused `gr.State()` the original created
    but never connected to any event.
    """
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>Gemini Chat Interface 💬</h1>")
        gr.HTML("<h2 align='center'>Interact with the Gemini AI model.</h2>")

        # Only shown when the key is not already supplied via the environment.
        google_key_component = gr.Textbox(
            label="GOOGLE API KEY",
            value="",
            type="password",
            placeholder="Enter your Google API key",
            visible=os.environ.get("GOOGLE_API_KEY") is None
        )

        chatbot_component = gr.Chatbot(
            label='Gemini Chatbot',
            height=600,
            type='messages'  # openai-style {'role', 'content'} dict history
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here and press Enter"
            )

        with gr.Accordion("Advanced Parameters", open=False):
            temperature_component = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=DEFAULT_TEMPERATURE,
                step=0.05,
                label="Temperature"
            )
            max_output_tokens_component = gr.Slider(
                minimum=1,
                maximum=2048,
                value=DEFAULT_MAX_OUTPUT_TOKENS,
                step=1,
                label="Max Output Tokens"
            )
            stop_sequences_component = gr.Textbox(
                label="Stop Sequences",
                value="",
                placeholder="Comma-separated stop sequences"
            )
            top_k_component = gr.Slider(
                minimum=1,
                maximum=40,
                value=DEFAULT_TOP_K,
                step=1,
                label="Top-K"
            )
            top_p_component = gr.Slider(
                minimum=0,
                maximum=1,
                value=DEFAULT_TOP_P,
                step=0.01,
                label="Top-P"
            )

        # Step 1 clears the textbox and appends the user turn; step 2 runs
        # the model against the updated history. `msg` is passed along for
        # signature compatibility even though gradio_chatbot ignores it.
        msg.submit(
            fn=submit_message,
            inputs=[msg, chatbot_component],
            outputs=[msg, chatbot_component]
        ).then(
            fn=gradio_chatbot,
            inputs=[
                msg,
                chatbot_component,
                google_key_component,
                temperature_component,
                max_output_tokens_component,
                stop_sequences_component,
                top_k_component,
                top_p_component
            ],
            outputs=chatbot_component
        )

    return demo