legolasyiu committed on
Commit
8c47ec1
·
verified ·
1 Parent(s): 430aac7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -16
app.py CHANGED
@@ -40,6 +40,51 @@ tts_model = AutoModelForCausalLM.from_pretrained(
40
  torch_dtype="auto",
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # -----------------------------
44
  # PIPELINE FUNCTION
45
  # -----------------------------
@@ -51,23 +96,8 @@ def speech_to_speech(audio_file):
51
  audio, sr = librosa.load(audio_file, sr=TARGET_SR)
52
 
53
  # ---------- STT ----------
54
- stt_inputs = stt_processor(
55
- audio=audio,
56
- sampling_rate=TARGET_SR,
57
- text="Transcribe the audio accurately.",
58
- return_tensors="pt",
59
- ).to(DEVICE)
60
-
61
- with torch.no_grad():
62
- output_ids = stt_model.generate(
63
- **stt_inputs,
64
- max_new_tokens=512,
65
- )
66
 
67
- transcription = stt_processor.decode(
68
- output_ids[0],
69
- skip_special_tokens=True,
70
- )
71
 
72
  # ---------- TTS ----------
73
  tts_inputs = tts_tokenizer(
 
40
  torch_dtype="auto",
41
  )
42
 
43
+ def transcribe_and_translate(audio_file):
44
+ if audio_file is None:
45
+ return "Please upload an audio file."
46
+
47
+ # Save temp file path
48
+ audio_path = audio_file
49
+
50
+ prompt = f"Transcribe the audio accurately."
51
+
52
+ messages = [
53
+ {
54
+ "role": "user",
55
+ "content": [
56
+ {"type": "audio", "audio": audio_path},
57
+ {"type": "text", "text": prompt},
58
+ ]
59
+ }
60
+ ]
61
+
62
+ inputs = processor.apply_chat_template(
63
+ messages,
64
+ add_generation_prompt=True,
65
+ tokenize=True,
66
+ return_dict=True,
67
+ return_tensors="pt"
68
+ )
69
+
70
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
71
+
72
+ with torch.no_grad():
73
+ outputs = model.generate(
74
+ **inputs,
75
+ max_new_tokens=MAX_TOKENS,
76
+ do_sample=False,
77
+ temperature=0.2,
78
+ )
79
+
80
+ decoded = processor.batch_decode(
81
+ outputs,
82
+ skip_special_tokens=True,
83
+ clean_up_tokenization_spaces=True
84
+ )
85
+
86
+ return decoded[0]
87
+
88
  # -----------------------------
89
  # PIPELINE FUNCTION
90
  # -----------------------------
 
96
  audio, sr = librosa.load(audio_file, sr=TARGET_SR)
97
 
98
  # ---------- STT ----------
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ transcription = transcribe_and_translate(audio_file)
 
 
 
101
 
102
  # ---------- TTS ----------
103
  tts_inputs = tts_tokenizer(