legolasyiu committed on
Commit
738d49d
·
verified ·
1 Parent(s): 9b1e5ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -38
app.py CHANGED
@@ -3,15 +3,15 @@ import torch
3
  import librosa
4
  import soundfile as sf
5
  import tempfile
6
- import os
7
 
8
  from transformers import (
9
  AutoProcessor,
10
  AutoModelForImageTextToText,
11
  AutoTokenizer,
12
- AutoModelForCausalLM,
13
  )
14
 
 
 
15
  # -----------------------------
16
  # CONFIG
17
  # -----------------------------
@@ -19,35 +19,47 @@ STT_MODEL_ID = "EpistemeAI/Audiogemma-3N-finetune"
19
  TTS_MODEL_ID = "EpistemeAI/LexiVox"
20
 
21
  TARGET_SR = 16000
 
 
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
  DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
24
 
25
  # -----------------------------
26
- # LOAD MODELS (ONCE)
27
  # -----------------------------
28
  print("Loading STT model...")
29
  processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
30
- model = AutoModelForImageTextToText.from_pretrained(
 
31
  STT_MODEL_ID,
32
  torch_dtype="auto",
33
  device_map="auto",
34
  )
35
 
36
- print("Loading TTS model...")
 
 
 
 
 
 
37
  tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
38
- tts_model = AutoModelForCausalLM.from_pretrained(
39
- TTS_MODEL_ID,
40
- torch_dtype="auto",
41
- )
42
 
43
- def transcribe_and_translate(audio_file):
44
- if audio_file is None:
45
- return "Please upload an audio file."
 
 
 
46
 
47
- # Save temp file path
48
- audio_path = audio_file
49
 
50
- prompt = f"Transcribe the audio accurately in German."
 
 
 
 
51
 
52
  messages = [
53
  {
@@ -55,7 +67,7 @@ def transcribe_and_translate(audio_file):
55
  "content": [
56
  {"type": "audio", "audio": audio_path},
57
  {"type": "text", "text": prompt},
58
- ]
59
  }
60
  ]
61
 
@@ -63,54 +75,58 @@ def transcribe_and_translate(audio_file):
63
  messages,
64
  add_generation_prompt=True,
65
  tokenize=True,
 
66
  return_dict=True,
67
- return_tensors="pt"
68
  )
69
 
70
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
71
 
72
- with torch.no_grad():
73
- outputs = model.generate(
74
  **inputs,
75
  max_new_tokens=MAX_TOKENS,
76
  do_sample=False,
77
  temperature=0.2,
78
  )
79
 
80
- decoded = processor.batch_decode(
81
  outputs,
82
  skip_special_tokens=True,
83
- clean_up_tokenization_spaces=True
84
- )
85
 
86
- return decoded[0]
87
 
88
  # -----------------------------
89
- # PIPELINE FUNCTION
90
  # -----------------------------
91
  def speech_to_speech(audio_file):
92
  if audio_file is None:
93
  return "", None
94
 
95
- # Load + resample
96
- audio, sr = librosa.load(audio_file, sr=TARGET_SR)
97
 
98
  # ---------- STT ----------
99
-
100
- transcription = transcribe_and_translate(audio_file)
101
 
102
  # ---------- TTS ----------
103
  tts_inputs = tts_tokenizer(
104
  transcription,
105
  return_tensors="pt",
106
- )
107
 
108
- with torch.no_grad():
109
- speech = tts_model.generate(**tts_inputs)
 
 
 
 
 
110
 
111
- audio_out = speech.cpu().numpy().squeeze()
112
 
113
- # Save temp wav
114
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
115
  sf.write(tmp.name, audio_out, TARGET_SR)
116
 
@@ -119,14 +135,13 @@ def speech_to_speech(audio_file):
119
  # -----------------------------
120
  # GRADIO UI
121
  # -----------------------------
122
- with gr.Blocks(title="Audiogemma → LexiVox Speech Loop") as demo:
123
  gr.Markdown(
124
  """
125
  # 🎙️ Speech → Text → Speech
126
- **Audiogemma-3N + LexiVox**
127
 
128
- Upload audio or use the microphone.
129
- The system transcribes speech, then speaks it back using an LLM-based TTS.
130
  """
131
  )
132
 
 
3
  import librosa
4
  import soundfile as sf
5
  import tempfile
 
6
 
7
  from transformers import (
8
  AutoProcessor,
9
  AutoModelForImageTextToText,
10
  AutoTokenizer,
 
11
  )
12
 
13
+ from unsloth import FastLanguageModel
14
+
15
  # -----------------------------
16
  # CONFIG
17
  # -----------------------------
 
19
  TTS_MODEL_ID = "EpistemeAI/LexiVox"
20
 
21
  TARGET_SR = 16000
22
+ MAX_TOKENS = 512
23
+
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
26
 
27
  # -----------------------------
28
+ # LOAD STT MODEL
29
  # -----------------------------
30
  print("Loading STT model...")
31
  processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
32
+
33
+ stt_model = AutoModelForImageTextToText.from_pretrained(
34
  STT_MODEL_ID,
35
  torch_dtype="auto",
36
  device_map="auto",
37
  )
38
 
39
+ stt_model.eval()
40
+
41
+ # -----------------------------
42
+ # LOAD TTS MODEL (UNSLOTH)
43
+ # -----------------------------
44
+ print("Loading TTS model with Unsloth...")
45
+
46
  tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
 
 
 
 
47
 
48
+ tts_model, _ = FastLanguageModel.from_pretrained(
49
+ model_name = TTS_MODEL_ID,
50
+ max_seq_length = 4096,
51
+ dtype = DTYPE,
52
+ load_in_4bit = True,
53
+ )
54
 
55
+ FastLanguageModel.for_inference(tts_model)
56
+ tts_model.eval()
57
 
58
+ # -----------------------------
59
+ # STT FUNCTION
60
+ # -----------------------------
61
+ def transcribe(audio_path):
62
+ prompt = "Transcribe the audio accurately in German."
63
 
64
  messages = [
65
  {
 
67
  "content": [
68
  {"type": "audio", "audio": audio_path},
69
  {"type": "text", "text": prompt},
70
+ ],
71
  }
72
  ]
73
 
 
75
  messages,
76
  add_generation_prompt=True,
77
  tokenize=True,
78
+ return_tensors="pt",
79
  return_dict=True,
 
80
  )
81
 
82
+ inputs = {k: v.to(stt_model.device) for k, v in inputs.items()}
83
 
84
+ with torch.inference_mode():
85
+ outputs = stt_model.generate(
86
  **inputs,
87
  max_new_tokens=MAX_TOKENS,
88
  do_sample=False,
89
  temperature=0.2,
90
  )
91
 
92
+ text = processor.batch_decode(
93
  outputs,
94
  skip_special_tokens=True,
95
+ clean_up_tokenization_spaces=True,
96
+ )[0]
97
 
98
+ return text
99
 
100
  # -----------------------------
101
+ # SPEECH → SPEECH PIPELINE
102
  # -----------------------------
103
  def speech_to_speech(audio_file):
104
  if audio_file is None:
105
  return "", None
106
 
107
+ # Ensure audio is valid
108
+ _audio, _ = librosa.load(audio_file, sr=TARGET_SR)
109
 
110
  # ---------- STT ----------
111
+ transcription = transcribe(audio_file)
 
112
 
113
  # ---------- TTS ----------
114
  tts_inputs = tts_tokenizer(
115
  transcription,
116
  return_tensors="pt",
117
+ ).to(tts_model.device)
118
 
119
+ with torch.inference_mode():
120
+ speech_tokens = tts_model.generate(
121
+ **tts_inputs,
122
+ max_new_tokens=2048,
123
+ do_sample=False,
124
+ temperature=0.7,
125
+ )
126
 
127
+ audio_out = speech_tokens.cpu().numpy().squeeze()
128
 
129
+ # Save temporary WAV
130
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
131
  sf.write(tmp.name, audio_out, TARGET_SR)
132
 
 
135
  # -----------------------------
136
  # GRADIO UI
137
  # -----------------------------
138
+ with gr.Blocks(title="Audiogemma → LexiVox (Unsloth)") as demo:
139
  gr.Markdown(
140
  """
141
  # 🎙️ Speech → Text → Speech
142
+ **Audiogemma-3N + LexiVox (Unsloth Accelerated)**
143
 
144
+ Upload audio or use your microphone.
 
145
  """
146
  )
147