jeremierostan committed on
Commit
05a24ef
·
verified ·
1 Parent(s): 87ab4e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -14
app.py CHANGED
@@ -3,7 +3,7 @@ import base64
3
  import json
4
  import os
5
  import pathlib
6
- from typing import AsyncGenerator, Literal, Dict, List
7
 
8
  import gradio as gr
9
  import numpy as np
@@ -52,7 +52,6 @@ class GeminiHandler(AsyncStreamHandler):
52
  expected_layout: Literal["mono"] = "mono",
53
  output_sample_rate: int = 24000,
54
  output_frame_size: int = 480,
55
- system_prompt: str = SYSTEM_PROMPTS["default"]
56
  ) -> None:
57
  super().__init__(
58
  expected_layout,
@@ -63,14 +62,13 @@ class GeminiHandler(AsyncStreamHandler):
63
  self.input_queue: asyncio.Queue = asyncio.Queue()
64
  self.output_queue: asyncio.Queue = asyncio.Queue()
65
  self.quit: asyncio.Event = asyncio.Event()
66
- self.system_prompt = system_prompt
67
 
68
  def copy(self) -> "GeminiHandler":
69
  return GeminiHandler(
70
  expected_layout="mono",
71
  output_sample_rate=self.output_sample_rate,
72
  output_frame_size=self.output_frame_size,
73
- system_prompt=self.system_prompt
74
  )
75
 
76
  async def start_up(self):
@@ -85,16 +83,14 @@ class GeminiHandler(AsyncStreamHandler):
85
  self.system_prompt = custom_prompt
86
  else:
87
  api_key, voice_name = None, "Puck"
 
88
 
89
  client = genai.Client(
90
  api_key=api_key or os.getenv("GEMINI_API_KEY"),
91
  http_options={"api_version": "v1alpha"},
92
  )
93
 
94
- # Convert the system prompt to a list as required by the API
95
- system_instruction_list = [self.system_prompt]
96
-
97
- # Create config with system_instruction (not system_instructions)
98
  config = LiveConnectConfig(
99
  response_modalities=["AUDIO"], # type: ignore
100
  speech_config=SpeechConfig(
@@ -104,12 +100,41 @@ class GeminiHandler(AsyncStreamHandler):
104
  )
105
  )
106
  ),
107
- system_instruction=system_instruction_list, # Note: singular "instruction" and passed as a list
108
  )
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  async with client.aio.live.connect(
111
  model="gemini-2.0-flash-exp", config=config
112
  ) as session:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  async for audio in session.start_stream(
114
  stream=self.stream(), mime_type="audio/pcm"
115
  ):
@@ -168,8 +193,8 @@ stream = Stream(
168
  value="default",
169
  ),
170
  gr.Textbox(
171
- label="Custom Prompt (overrides preset if not empty)",
172
- placeholder="Enter a custom system prompt",
173
  value="",
174
  ),
175
  ],
@@ -180,8 +205,8 @@ class InputData(BaseModel):
180
  webrtc_id: str
181
  voice_name: str
182
  api_key: str
183
- prompt_key: str
184
- custom_prompt: str
185
 
186
 
187
  app = FastAPI()
@@ -200,7 +225,6 @@ async def index():
200
  rtc_config = get_twilio_turn_credentials() if get_space() else None
201
  html_content = (current_dir / "index.html").read_text()
202
  html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
203
- # Also inject the system prompts into the HTML
204
  html_content = html_content.replace("__SYSTEM_PROMPTS__", json.dumps(SYSTEM_PROMPTS))
205
  return HTMLResponse(content=html_content)
206
 
 
3
  import json
4
  import os
5
  import pathlib
6
+ from typing import AsyncGenerator, Literal
7
 
8
  import gradio as gr
9
  import numpy as np
 
52
  expected_layout: Literal["mono"] = "mono",
53
  output_sample_rate: int = 24000,
54
  output_frame_size: int = 480,
 
55
  ) -> None:
56
  super().__init__(
57
  expected_layout,
 
62
  self.input_queue: asyncio.Queue = asyncio.Queue()
63
  self.output_queue: asyncio.Queue = asyncio.Queue()
64
  self.quit: asyncio.Event = asyncio.Event()
65
+ self.system_prompt = None
66
 
67
  def copy(self) -> "GeminiHandler":
68
  return GeminiHandler(
69
  expected_layout="mono",
70
  output_sample_rate=self.output_sample_rate,
71
  output_frame_size=self.output_frame_size,
 
72
  )
73
 
74
  async def start_up(self):
 
83
  self.system_prompt = custom_prompt
84
  else:
85
  api_key, voice_name = None, "Puck"
86
+ self.system_prompt = None
87
 
88
  client = genai.Client(
89
  api_key=api_key or os.getenv("GEMINI_API_KEY"),
90
  http_options={"api_version": "v1alpha"},
91
  )
92
 
93
+ # Create basic config
 
 
 
94
  config = LiveConnectConfig(
95
  response_modalities=["AUDIO"], # type: ignore
96
  speech_config=SpeechConfig(
 
100
  )
101
  )
102
  ),
 
103
  )
104
 
105
+ # Get model reference
106
+ model = client.get_model("gemini-2.0-flash-exp")
107
+
108
+ # Apply system prompt if available
109
+ if self.system_prompt:
110
+ try:
111
+ # First try the with_system_instructions method (newer API versions)
112
+ model = model.with_system_instructions(self.system_prompt)
113
+ print(f"Using system prompt via with_system_instructions: {self.system_prompt[:50]}...")
114
+ except Exception as e:
115
+ print(f"Could not apply system prompt via with_system_instructions: {e}")
116
+ # If that fails, we'll handle it in the session
117
+ pass
118
+
119
+ # Create session
120
  async with client.aio.live.connect(
121
  model="gemini-2.0-flash-exp", config=config
122
  ) as session:
123
+ # Whenever a system prompt is set (regardless of whether the earlier
124
+ # attempt succeeded), also try to send it as the first message
125
+ if self.system_prompt:
126
+ try:
127
+ # Try to send system prompt as first message
128
+ await session.send_message(f"SYSTEM: {self.system_prompt}\n\nPlease acknowledge this system instruction.")
129
+ # Wait for a response
130
+ async for response in session.stream_response():
131
+ # Just need one response to confirm it was received
132
+ print("System instruction acknowledged")
133
+ break
134
+ except Exception as e:
135
+ print(f"Could not send system prompt as message: {e}")
136
+
137
+ # Now start the audio stream
138
  async for audio in session.start_stream(
139
  stream=self.stream(), mime_type="audio/pcm"
140
  ):
 
193
  value="default",
194
  ),
195
  gr.Textbox(
196
+ label="Custom Prompt",
197
+ placeholder="Enter a custom system prompt (overrides preset if not empty)",
198
  value="",
199
  ),
200
  ],
 
205
  webrtc_id: str
206
  voice_name: str
207
  api_key: str
208
+ prompt_key: str = ""
209
+ custom_prompt: str = ""
210
 
211
 
212
  app = FastAPI()
 
225
  rtc_config = get_twilio_turn_credentials() if get_space() else None
226
  html_content = (current_dir / "index.html").read_text()
227
  html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
 
228
  html_content = html_content.replace("__SYSTEM_PROMPTS__", json.dumps(SYSTEM_PROMPTS))
229
  return HTMLResponse(content=html_content)
230