Nicolás Larenas committed on
Commit
a6692e9
·
verified ·
1 Parent(s): 56ce1d2

Update gradio_interface.py

Browse files
Files changed (1) hide show
  1. gradio_interface.py +26 -182
gradio_interface.py CHANGED
@@ -1,190 +1,34 @@
1
- # gradio_interface.py
2
-
3
  import gradio as gr
4
- from ai_model import query_ai_model, preprocess_chat_history
5
- import asyncio
6
- import logging
7
- from config import (
8
- DEFAULT_MAX_OUTPUT_TOKENS,
9
- DEFAULT_TEMPERATURE,
10
- DEFAULT_TOP_P,
11
- DEFAULT_TOP_K,
12
- )
13
- import os
14
- from typing import List, Optional # <-- Added this import
15
- import time
16
-
17
- # Configure logging
18
- logging.basicConfig(level=logging.ERROR)
19
-
20
def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Parse a comma-separated string of stop sequences.

    Returns ``None`` when the input is empty/falsy; otherwise a list of the
    comma-separated tokens with surrounding whitespace stripped.
    """
    if not stop_sequences:
        return None
    tokens = stop_sequences.split(",")
    return [token.strip() for token in tokens]
-
25
def preprocess_chat_history(
    history: List[tuple]
) -> List[dict]:
    """Flatten ``(user, assistant)`` pairs into role/content message dicts.

    ``None`` entries on either side of a pair are skipped; order within each
    pair is user first, then assistant.
    """
    messages: List[dict] = []
    for user_text, assistant_text in history:
        for role, content in (("user", user_text), ("assistant", assistant_text)):
            if content is not None:
                messages.append({'role': role, 'content': content})
    return messages
-
36
def user(text_prompt: str, chatbot: List[tuple]):
    """Handle a textbox submission.

    Appends the prompt to the chat history as ``(prompt, None)`` when it is
    non-empty, and returns ``("", history)`` so the textbox is cleared.
    """
    if not text_prompt:
        return "", chatbot
    chatbot.append((text_prompt, None))
    return "", chatbot
-
41
async def bot(
    google_key: str,
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List[tuple]
):
    """Async generator that answers the latest user message in *chatbot*.

    Yields the updated ``(user, assistant)`` tuple history repeatedly to give
    a streaming effect: the assistant reply is appended ~10 characters at a
    time with a short sleep between yields.

    Args:
        google_key: API key; falls back to the GOOGLE_API_KEY env var.
        temperature, max_output_tokens, top_k, top_p: sampling parameters
            forwarded to query_ai_model.
        stop_sequences: comma-separated stop strings (parsed below).
        chatbot: chat history; the last tuple holds the pending user message.
    """
    # Nothing to answer yet — yield the (empty) history unchanged.
    if len(chatbot) == 0:
        yield chatbot
        return

    google_key = google_key if google_key else os.environ.get("GOOGLE_API_KEY")
    if not google_key:
        # Surface the missing-key problem in the chat itself rather than raising.
        chatbot.append(('GOOGLE_API_KEY is not set. Please provide your API key.', None))
        yield chatbot
        return

    # NOTE(review): `genai` is not imported anywhere in this file's visible
    # import block — as written this line raises NameError, which is then
    # swallowed by the broad `except` below and shown as a chat error.
    # Confirm whether `import google.generativeai as genai` was dropped, or
    # whether key configuration now belongs inside query_ai_model.
    genai.configure(api_key=google_key)

    stop_sequences_list = preprocess_stop_sequences(stop_sequences)

    try:
        # Process the latest user message
        message = chatbot[-1][0]
        history = preprocess_chat_history(chatbot[:-1])  # Exclude the last message

        # Query the AI model
        assistant_reply = await query_ai_model(
            message=message,
            history=history,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences_list
        )

        # Update the chat history with the assistant's reply
        # (reset the pending slot to an empty string before streaming into it)
        chatbot[-1] = (chatbot[-1][0], "")
        # assumes query_ai_model returns a dict with a 'content' key — TODO confirm
        response_content = assistant_reply['content']

        # Streaming effect: reveal the reply in 10-character chunks.
        for i in range(0, len(response_content), 10):
            section = response_content[i:i + 10]
            chatbot[-1] = (chatbot[-1][0], chatbot[-1][1] + section)
            await asyncio.sleep(0.01)
            yield chatbot
    except Exception as e:
        # Any failure is logged and rendered into the chat as an error message.
        logging.error("Error in bot function", exc_info=True)
        error_message = f"An error occurred: {str(e)}"
        chatbot[-1] = (chatbot[-1][0], error_message)
        yield chatbot
95
 
 
96
def create_gradio_interface():
    """Build and return the Gradio Blocks demo for the Gemini chat UI.

    Components are instantiated first, then laid out inside ``gr.Blocks``
    via ``.render()``. The textbox submit event first runs ``user`` (append
    the prompt, clear the box) and then ``bot`` (stream the reply).
    """
    # Define all components first
    google_key_component = gr.Textbox(
        label="GOOGLE API KEY",
        value="",
        type="password",
        placeholder="Enter your Google API key",
        info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
        # Hide the key field entirely when the environment already provides one.
        visible=os.environ.get("GOOGLE_API_KEY") is None
    )
    chatbot_component = gr.Chatbot(
        label='Gemini Chatbot',
        height=600,
        # FIX: was type='messages', but user()/bot() in this file build the
        # history as (user, assistant) tuples, not role/content dicts —
        # 'tuples' matches the actual history format used by the callbacks.
        type='tuples'
    )
    text_prompt_component = gr.Textbox(
        placeholder="Hi there! [press Enter]", show_label=False
    )
    temperature_component = gr.Slider(
        minimum=0,
        maximum=1.0,
        value=DEFAULT_TEMPERATURE,
        step=0.05,
        label="Temperature",
        info="Controls the randomness of the AI's response."
    )
    max_output_tokens_component = gr.Slider(
        minimum=1,
        maximum=2048,
        value=DEFAULT_MAX_OUTPUT_TOKENS,
        step=1,
        label="Max Output Tokens",
        info="Controls the length of the AI's response."
    )
    stop_sequences_component = gr.Textbox(
        label="Add stop sequence",
        value="",
        type="text",
        placeholder="STOP, END",
        info="Stops the AI response when any of these sequences are generated."
    )
    top_k_component = gr.Slider(
        minimum=1,
        maximum=40,
        value=DEFAULT_TOP_K,
        step=1,
        label="Top-K",
        info="Limits the next token selection to the K most probable tokens."
    )
    top_p_component = gr.Slider(
        minimum=0,
        maximum=1,
        value=DEFAULT_TOP_P,
        step=0.01,
        label="Top-P",
        info="Limits the next token selection to tokens with cumulative probability up to P."
    )

    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>Gemini Chat Interface 💬</h1>")
        gr.HTML("<h2 align='center'>Interact with the Gemini AI model.</h2>")
        with gr.Column():
            google_key_component.render()
            chatbot_component.render()
            text_prompt_component.render()
            with gr.Accordion("Parameters", open=False):
                temperature_component.render()
                max_output_tokens_component.render()
                stop_sequences_component.render()
                with gr.Accordion("Advanced", open=False):
                    top_k_component.render()
                    top_p_component.render()

        # Connect the user input submission: user() runs synchronously to
        # echo the prompt, then bot() streams the model's answer.
        text_prompt_component.submit(
            fn=user,
            inputs=[text_prompt_component, chatbot_component],
            outputs=[text_prompt_component, chatbot_component],
            queue=False
        ).then(
            fn=bot,
            inputs=[
                google_key_component,
                temperature_component,
                max_output_tokens_component,
                stop_sequences_component,
                top_k_component,
                top_p_component,
                chatbot_component
            ],
            outputs=chatbot_component,
            queue=False
        )

    return demo
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from ai_model import query_ai_model # Import AI model query function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
# Function to handle chatbot conversation
def gradio_chatbot(text, history):
    """Append the user's message to *history*, query the AI model, and
    return the updated history as ``(user, assistant)`` tuples.

    Args:
        text: the submitted user message.
        history: list of ``(user, assistant)`` tuples (mutated in place).

    Returns:
        The same history list with the new exchange filled in.
    """
    # FIX: asyncio was used here but never imported at module level in this
    # file (only gradio and ai_model are imported) — NameError at runtime.
    import asyncio

    history.append((text, None))
    # NOTE(review): query_ai_model is called with only the message text here,
    # without history or sampling parameters — confirm its signature accepts
    # a single positional argument.
    # NOTE(review): asyncio.run fails if an event loop is already running
    # (gradio may invoke handlers from async context) — verify in deployment.
    response = asyncio.run(query_ai_model(text))
    history[-1] = (text, response)
    return history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Gradio interface
12
def create_gradio_interface():
    """Assemble and return the Gradio Blocks demo for the chatbot.

    A chat display and a prompt textbox are created up front, rendered inside
    a ``gr.Blocks`` layout, and the textbox submission is wired to
    ``gradio_chatbot``.
    """
    chat_display = gr.Chatbot()
    prompt_box = gr.Textbox(placeholder="Ask a question and press Enter", show_label=False)

    # Lay out the components and wire up the event handler.
    with gr.Blocks() as demo:
        gr.HTML("<h1>Unified Gemini Chatbot</h1>")
        chat_display.render()
        prompt_box.render()

        # Submitting the textbox routes (text, history) through the chatbot handler.
        prompt_box.submit(
            fn=gradio_chatbot,
            inputs=[prompt_box, chat_display],
            outputs=chat_display
        )

    return demo
30
+
31
# Launch Gradio interface
def run_gradio_interface():
    """Build the interface and launch it with public sharing and debug on."""
    demo = create_gradio_interface()
    demo.launch(share=True, debug=True)