Spaces:
Build error
Build error
add single channel
Browse files- Cargo.lock +1 -0
- Cargo.toml +1 -0
- src/base64box.rs +36 -0
- src/main.rs +121 -0
Cargo.lock
CHANGED
|
@@ -1546,6 +1546,7 @@ dependencies = [
|
|
| 1546 |
"aws-sdk-polly",
|
| 1547 |
"aws-sdk-transcribestreaming",
|
| 1548 |
"aws-sdk-translate",
|
|
|
|
| 1549 |
"config",
|
| 1550 |
"futures-util",
|
| 1551 |
"lazy_static",
|
|
|
|
| 1546 |
"aws-sdk-polly",
|
| 1547 |
"aws-sdk-transcribestreaming",
|
| 1548 |
"aws-sdk-translate",
|
| 1549 |
+
"base64 0.21.5",
|
| 1550 |
"config",
|
| 1551 |
"futures-util",
|
| 1552 |
"lazy_static",
|
Cargo.toml
CHANGED
|
@@ -22,6 +22,7 @@ tracing = { version = "0.1", features = [] }
|
|
| 22 |
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
| 23 |
async-trait = "0.1.74"
|
| 24 |
lazy_static = "1.4.0"
|
|
|
|
| 25 |
|
| 26 |
[features]
|
| 27 |
whisper = ["dep:whisper"]
|
|
|
|
| 22 |
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
| 23 |
async-trait = "0.1.74"
|
| 24 |
lazy_static = "1.4.0"
|
| 25 |
+
base64 = { version = "0.21.5", features = [] }
|
| 26 |
|
| 27 |
[features]
|
| 28 |
whisper = ["dep:whisper"]
|
src/base64box.rs
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
| 2 |
+
use serde::de::Error;
|
| 3 |
+
use base64::{Engine as _, alphabet, engine::{GeneralPurpose, general_purpose}};
|
| 4 |
+
use lazy_static::lazy_static;
|
| 5 |
+
|
| 6 |
+
lazy_static! {
    /// Process-wide base64 engine shared by the `Base64Box` serde impls.
    ///
    /// Built from the standard alphabet with the `NO_PAD` configuration.
    static ref ENGINE: GeneralPurpose =
        GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD);
}
|
| 12 |
+
#[derive(Debug)]
|
| 13 |
+
pub struct Base64Box(pub Vec<u8>);
|
| 14 |
+
impl Serialize for Base64Box {
|
| 15 |
+
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
| 16 |
+
serializer.collect_str(&ENGINE.encode(&self.0))
|
| 17 |
+
}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
impl<'de> Deserialize<'de> for Base64Box {
|
| 21 |
+
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
|
| 22 |
+
struct Vis;
|
| 23 |
+
impl serde::de::Visitor<'_> for Vis {
|
| 24 |
+
type Value = Base64Box;
|
| 25 |
+
|
| 26 |
+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
| 27 |
+
formatter.write_str("a base64 string")
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
|
| 31 |
+
ENGINE.decode(v).map(Base64Box).map_err(Error::custom)
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
deserializer.deserialize_str(Vis)
|
| 35 |
+
}
|
| 36 |
+
}
|
src/main.rs
CHANGED
|
@@ -28,10 +28,12 @@ use tracing::{debug, span};
|
|
| 28 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
| 29 |
|
| 30 |
use crate::{config::*, lesson::*};
|
|
|
|
| 31 |
|
| 32 |
mod config;
|
| 33 |
mod lesson;
|
| 34 |
mod asr;
|
|
|
|
| 35 |
|
| 36 |
#[derive(Clone)]
|
| 37 |
struct Context {
|
|
@@ -63,6 +65,7 @@ async fn main() -> Result<(), std::io::Error> {
|
|
| 63 |
.at("/ws/teacher", get(stream_speaker))
|
| 64 |
.at("/ws/lesson-listener", get(stream_listener))
|
| 65 |
.at("/ws/student", get(stream_listener))
|
|
|
|
| 66 |
.at(
|
| 67 |
"lesson-speaker",
|
| 68 |
StaticFileEndpoint::new("./static/index.html"),
|
|
@@ -273,6 +276,124 @@ async fn stream_listener(
|
|
| 273 |
})
|
| 274 |
}
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
fn u8_to_i16(input: &[u8]) -> Vec<i16> {
|
| 277 |
input
|
| 278 |
.chunks_exact(2)
|
|
|
|
| 28 |
use tracing_subscriber::{fmt, prelude::*, EnvFilter};
|
| 29 |
|
| 30 |
use crate::{config::*, lesson::*};
|
| 31 |
+
use crate::base64box::Base64Box;
|
| 32 |
|
| 33 |
mod config;
|
| 34 |
mod lesson;
|
| 35 |
mod asr;
|
| 36 |
+
mod base64box;
|
| 37 |
|
| 38 |
#[derive(Clone)]
|
| 39 |
struct Context {
|
|
|
|
| 65 |
.at("/ws/teacher", get(stream_speaker))
|
| 66 |
.at("/ws/lesson-listener", get(stream_listener))
|
| 67 |
.at("/ws/student", get(stream_listener))
|
| 68 |
+
.at("/ws/voice", get(stream_single))
|
| 69 |
.at(
|
| 70 |
"lesson-speaker",
|
| 71 |
StaticFileEndpoint::new("./static/index.html"),
|
|
|
|
| 276 |
})
|
| 277 |
}
|
| 278 |
|
| 279 |
+
#[derive(Serialize, Debug)]
|
| 280 |
+
enum SingleEvent {
|
| 281 |
+
#[serde(rename = "original")]
|
| 282 |
+
Transcription {
|
| 283 |
+
content: String,
|
| 284 |
+
#[serde(rename = "isFinal")]
|
| 285 |
+
is_final: bool
|
| 286 |
+
},
|
| 287 |
+
#[serde(rename = "translated")]
|
| 288 |
+
Translation { content: String },
|
| 289 |
+
#[serde(rename = "voice")]
|
| 290 |
+
Voice {
|
| 291 |
+
content: Base64Box
|
| 292 |
+
},
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
#[derive(Deserialize, Debug)]
|
| 297 |
+
pub struct SingleQuery {
|
| 298 |
+
id: u32,
|
| 299 |
+
from: String,
|
| 300 |
+
to: String,
|
| 301 |
+
voice: Option<String>,
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
#[handler]
|
| 306 |
+
async fn stream_single(
|
| 307 |
+
ctx: Data<&Context>,
|
| 308 |
+
query: Query<SingleQuery>,
|
| 309 |
+
ws: WebSocket
|
| 310 |
+
) -> impl IntoResponse {
|
| 311 |
+
let lessons_manager = ctx.lessons_manager.clone();
|
| 312 |
+
ws.on_upgrade(|mut socket| async move {
|
| 313 |
+
let Ok(lang) = query.from.parse::<LanguageCode>() else {
|
| 314 |
+
let _ = socket
|
| 315 |
+
.send(Message::Text(format!("invalid language code: {}", query.from)))
|
| 316 |
+
.await;
|
| 317 |
+
return
|
| 318 |
+
};
|
| 319 |
+
let lesson = lessons_manager
|
| 320 |
+
.create_lesson(
|
| 321 |
+
query.id,
|
| 322 |
+
AsrEngine::AWS,
|
| 323 |
+
lang,
|
| 324 |
+
)
|
| 325 |
+
.await;
|
| 326 |
+
|
| 327 |
+
let mut transcribe_rx = lesson.transcript_channel();
|
| 328 |
+
let mut lang_lesson = lesson.get_or_init(query.to.clone()).await;
|
| 329 |
+
let mut translate_rx = lang_lesson.translated_channel();
|
| 330 |
+
let Ok(voice_id) = query.voice.as_deref().unwrap_or("Amy").parse() else {
|
| 331 |
+
let _ = socket
|
| 332 |
+
.send(Message::Text(format!("invalid voice id: {:?}", query.voice)))
|
| 333 |
+
.await;
|
| 334 |
+
return
|
| 335 |
+
};
|
| 336 |
+
let mut voice_lesson = lang_lesson.get_or_init(voice_id).await;
|
| 337 |
+
let mut voice_rx = voice_lesson.voice_channel();
|
| 338 |
+
// let mut lip_sync_rx = voice_lesson.lip_sync_channel();
|
| 339 |
+
|
| 340 |
+
let fut = async {
|
| 341 |
+
loop {
|
| 342 |
+
let evt = select! {
|
| 343 |
+
input = socket.next() => {
|
| 344 |
+
let Some(res) = input else { break };
|
| 345 |
+
let msg = res?;
|
| 346 |
+
if msg.is_close() {
|
| 347 |
+
break
|
| 348 |
+
}
|
| 349 |
+
let Message::Binary(bin) = msg else {
|
| 350 |
+
tracing::warn!("Other: {:?}", msg);
|
| 351 |
+
continue
|
| 352 |
+
};
|
| 353 |
+
let frame = u8_to_i16(&bin);
|
| 354 |
+
lesson.send(frame).await?;
|
| 355 |
+
continue
|
| 356 |
+
},
|
| 357 |
+
transcript_poll = transcribe_rx.recv() => {
|
| 358 |
+
let evt = transcript_poll?;
|
| 359 |
+
if evt.is_final {
|
| 360 |
+
tracing::trace!("Transcribed: {}", evt.transcript);
|
| 361 |
+
}
|
| 362 |
+
SingleEvent::Transcription { content: evt.transcript, is_final: evt.is_final }
|
| 363 |
+
},
|
| 364 |
+
translated_poll = translate_rx.recv() => {
|
| 365 |
+
let translated = translated_poll?;
|
| 366 |
+
SingleEvent::Translation { content: translated }
|
| 367 |
+
},
|
| 368 |
+
voice_poll = voice_rx.recv() => {
|
| 369 |
+
let voice = voice_poll?;
|
| 370 |
+
SingleEvent::Voice { content: Base64Box(voice) }
|
| 371 |
+
},
|
| 372 |
+
};
|
| 373 |
+
|
| 374 |
+
let Ok(json) = serde_json::to_string(&evt) else {
|
| 375 |
+
tracing::warn!("failed to serialize json: {:?}", evt);
|
| 376 |
+
continue
|
| 377 |
+
};
|
| 378 |
+
socket.send(Message::Text(json)).await?
|
| 379 |
+
}
|
| 380 |
+
Ok(())
|
| 381 |
+
};
|
| 382 |
+
|
| 383 |
+
let span = span!(tracing::Level::TRACE, "lesson_speaker", lesson_id = query.id);
|
| 384 |
+
let _ = span.enter();
|
| 385 |
+
let res: anyhow::Result<()> = fut.await;
|
| 386 |
+
match res {
|
| 387 |
+
Ok(()) => {
|
| 388 |
+
tracing::info!("lesson speaker closed");
|
| 389 |
+
}
|
| 390 |
+
Err(e) => {
|
| 391 |
+
tracing::warn!("lesson speaker error: {}", e);
|
| 392 |
+
}
|
| 393 |
+
}
|
| 394 |
+
})
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
fn u8_to_i16(input: &[u8]) -> Vec<i16> {
|
| 398 |
input
|
| 399 |
.chunks_exact(2)
|