datbkpro commited on
Commit
cc93a2a
·
verified ·
1 Parent(s): 6e7d07a

Update services/sambanova_voice_service.py

Browse files
Files changed (1) hide show
  1. services/sambanova_voice_service.py +74 -87
services/sambanova_voice_service.py CHANGED
@@ -12,13 +12,14 @@ from fastrtc import (
12
  ReplyOnStopWords,
13
  Stream,
14
  get_stt_model,
15
- get_twilio_turn_credentials,
16
  )
17
  from gradio.utils import get_space
18
  from pydantic import BaseModel
 
19
 
20
  class SambanovaVoiceService:
21
- """Dịch vụ Voice AI với Sambanova API"""
22
 
23
  def __init__(self):
24
  self.curr_dir = Path(__file__).parent
@@ -32,12 +33,21 @@ class SambanovaVoiceService:
32
  # STT model
33
  self.model = get_stt_model()
34
 
35
- # RTC configuration
36
- self.rtc_configuration = get_twilio_turn_credentials() if get_space() else None
37
-
38
- # FastAPI app
39
- self.app = FastAPI()
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def create_response_handler(self):
42
  """Tạo response handler cho voice streaming"""
43
 
@@ -49,33 +59,49 @@ class SambanovaVoiceService:
49
  gradio_chatbot = gradio_chatbot or []
50
  conversation_state = conversation_state or []
51
 
52
- # Speech to Text
53
- text = self.model.stt(audio)
54
- print("🎤 STT Result:", text)
55
-
56
- # Thêm audio vào chatbot
57
- sample_rate, array = audio
58
- gradio_chatbot.append(
59
- {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
60
- )
61
- yield AdditionalOutputs(gradio_chatbot, conversation_state)
 
 
 
 
 
62
 
63
- # Thêm text vào conversation state
64
- conversation_state.append({"role": "user", "content": text})
65
 
66
- # Gọi Sambanova API
67
- request = self.client.chat.completions.create(
68
- model="Meta-Llama-3.2-3B-Instruct",
69
- messages=conversation_state,
70
- temperature=0.1,
71
- top_p=0.1,
72
- )
73
- response_content = {"role": "assistant", "content": request.choices[0].message.content}
 
 
 
 
74
 
75
- conversation_state.append(response_content)
76
- gradio_chatbot.append(response_content)
77
 
78
- yield AdditionalOutputs(gradio_chatbot, conversation_state)
 
 
 
 
 
 
 
79
 
80
  return response
81
 
@@ -86,65 +112,26 @@ class SambanovaVoiceService:
86
  return Stream(
87
  ReplyOnStopWords(
88
  response_handler,
89
- stop_words=["computer", "hey", "hello", "xin chào"],
90
  input_sample_rate=16000,
91
  ),
92
  mode="send",
93
  modality="audio",
94
- additional_inputs=[gr.Chatbot(type="messages", value=[]), gr.State(value=[])],
95
- additional_outputs=[gr.Chatbot(type="messages", value=[]), gr.State(value=[])],
96
- additional_outputs_handler=lambda *a: (a[2], a[3]),
97
- concurrency_limit=5 if get_space() else None,
98
- time_limit=90 if get_space() else None,
 
 
 
 
 
 
 
 
 
 
 
99
  rtc_configuration=self.rtc_configuration,
100
- )
101
-
102
- def setup_fastapi_routes(self):
103
- """Thiết lập FastAPI routes"""
104
-
105
- class Message(BaseModel):
106
- role: str
107
- content: str
108
-
109
- class InputData(BaseModel):
110
- webrtc_id: str
111
- chatbot: list[Message]
112
- state: list[Message]
113
-
114
- @self.app.get("/")
115
- async def home():
116
- rtc_config = get_twilio_turn_credentials() if get_space() else None
117
- html_content = (self.curr_dir / "templates" / "sambanova_index.html").read_text()
118
- html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
119
- return HTMLResponse(content=html_content)
120
-
121
- @self.app.post("/input_hook")
122
- async def input_hook(data: InputData):
123
- body = data.model_dump()
124
- # stream.set_input(data.webrtc_id, body["chatbot"], body["state"])
125
- return {"status": "ok"}
126
-
127
- def audio_to_base64(file_path):
128
- audio_format = "wav"
129
- with open(file_path, "rb") as audio_file:
130
- encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
131
- return f"data:audio/{audio_format};base64,{encoded_audio}"
132
-
133
- @self.app.get("/outputs")
134
- async def outputs(webrtc_id: str):
135
- async def output_stream():
136
- # async for output in stream.output_stream(webrtc_id):
137
- # chatbot = output.args[0]
138
- # state = output.args[1]
139
- # data = {
140
- # "message": state[-1],
141
- # "audio": audio_to_base64(chatbot[-1]["content"].value["path"])
142
- # if chatbot[-1]["role"] == "user"
143
- # else None,
144
- # }
145
- # yield f"event: output\ndata: {json.dumps(data)}\n\n"
146
- yield f"event: output\ndata: {json.dumps({'message': 'Stream ready'})}\n\n"
147
-
148
- return StreamingResponse(output_stream(), media_type="text/event-stream")
149
-
150
- return self.app
 
12
  ReplyOnStopWords,
13
  Stream,
14
  get_stt_model,
15
+ get_cloudflare_turn_credentials_async, # Sử dụng Cloudflare free
16
  )
17
  from gradio.utils import get_space
18
  from pydantic import BaseModel
19
+ import asyncio
20
 
21
  class SambanovaVoiceService:
22
+ """Dịch vụ Voice AI với Sambanova API - Fixed TURN issue"""
23
 
24
  def __init__(self):
25
  self.curr_dir = Path(__file__).parent
 
33
  # STT model
34
  self.model = get_stt_model()
35
 
36
+ # RTC configuration - Sử dụng Cloudflare free hoặc None
37
+ self.rtc_configuration = asyncio.run(self._get_turn_config())
 
 
 
38
 
39
+ print("✅ Sambanova Voice Service initialized")
40
+
41
+ async def _get_turn_config(self):
42
+ """Lấy TURN configuration - sử dụng Cloudflare free"""
43
+ try:
44
+ config = await get_cloudflare_turn_credentials_async()
45
+ print("✅ Using Cloudflare TURN servers")
46
+ return config
47
+ except Exception as e:
48
+ print(f"⚠️ Cannot get TURN credentials, using None: {e}")
49
+ return None # Sẽ hoạt động trên local network
50
+
51
  def create_response_handler(self):
52
  """Tạo response handler cho voice streaming"""
53
 
 
59
  gradio_chatbot = gradio_chatbot or []
60
  conversation_state = conversation_state or []
61
 
62
+ try:
63
+ # Speech to Text
64
+ text = self.model.stt(audio)
65
+ print("🎤 STT Result:", text)
66
+
67
+ if not text.strip():
68
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
69
+ return
70
+
71
+ # Thêm audio vào chatbot
72
+ sample_rate, array = audio
73
+ gradio_chatbot.append(
74
+ {"role": "user", "content": f"🎤: {text}"} # Simplified - chỉ hiển thị text
75
+ )
76
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
77
 
78
+ # Thêm text vào conversation state
79
+ conversation_state.append({"role": "user", "content": text})
80
 
81
+ # Gọi Sambanova API
82
+ print("🤖 Calling Sambanova API...")
83
+ request = self.client.chat.completions.create(
84
+ model="Meta-Llama-3.2-3B-Instruct",
85
+ messages=conversation_state,
86
+ temperature=0.1,
87
+ top_p=0.1,
88
+ )
89
+ response_content = {
90
+ "role": "assistant",
91
+ "content": request.choices[0].message.content
92
+ }
93
 
94
+ conversation_state.append(response_content)
95
+ gradio_chatbot.append(response_content)
96
 
97
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
98
+
99
+ except Exception as e:
100
+ print(f"❌ Error in response handler: {e}")
101
+ error_msg = {"role": "assistant", "content": f"❌ Lỗi: {str(e)}"}
102
+ gradio_chatbot.append(error_msg)
103
+ conversation_state.append(error_msg)
104
+ yield AdditionalOutputs(gradio_chatbot, conversation_state)
105
 
106
  return response
107
 
 
112
  return Stream(
113
  ReplyOnStopWords(
114
  response_handler,
115
+ stop_words=["computer", "hey", "hello", "xin chào", "llama"],
116
  input_sample_rate=16000,
117
  ),
118
  mode="send",
119
  modality="audio",
120
+ additional_inputs=[
121
+ gr.Chatbot(
122
+ type="messages",
123
+ value=[],
124
+ label="💬 Voice Conversation",
125
+ height=400
126
+ ),
127
+ gr.State(value=[])
128
+ ],
129
+ additional_outputs=[
130
+ gr.Chatbot(type="messages", value=[]),
131
+ gr.State(value=[])
132
+ ],
133
+ additional_outputs_handler=lambda chatbot, state, new_chatbot, new_state: (new_chatbot, new_state),
134
+ concurrency_limit=3,
135
+ time_limit=120,
136
  rtc_configuration=self.rtc_configuration,
137
+ )