mkfallah commited on
Commit
ce98ad4
Β·
verified Β·
1 Parent(s): 0ec3d43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -108
app.py CHANGED
@@ -1,152 +1,84 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
- from rapidfuzz import process, fuzz
4
- import soundfile as sf
5
- import numpy as np
6
  import torch
 
7
 
8
- # ----------------------------
9
- # 1) ASR pipeline (Whisper Persian)
10
- # ----------------------------
11
  asr = pipeline(
12
- "automatic-speech-recognition",
13
  model="vhdm/whisper-large-fa-v1",
14
- device=-1, # CPU. set 0 for GPU
15
  )
16
 
17
- # ----------------------------
18
- # 2) LLM (text generation)
19
- # ----------------------------
20
- llm_model_id = "tiiuae/falcon-rw-1b" # choose a model that fits your env
21
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
22
  llm_model = AutoModelForCausalLM.from_pretrained(
23
  llm_model_id,
24
- torch_dtype=torch.float32,
25
  ).to("cpu")
26
 
27
- def ask_llm(prompt: str, max_new_tokens: int = 200) -> str:
28
  inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
29
  with torch.no_grad():
30
  outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)
31
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
32
 
33
- # ----------------------------
34
- # 3) TTS pipeline (SpeechT5 via transformers pipeline)
35
- # ----------------------------
36
- # Use the text-to-speech pipeline which handles preprocessing and speaker defaults.
37
- tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts", device=-1) # device=-1 -> CPU
38
-
39
- def text_to_speech_save(text: str, out_path: str = "response.wav") -> str:
40
- """
41
- Use the text-to-speech pipeline to synthesize `text` and save to `out_path`.
42
- Returns the path on success or raises exception on failure.
43
- """
44
- # pipeline may return a dict or list depending on versions; handle both
45
- result = tts_pipeline(text)
46
- if isinstance(result, list):
47
- entry = result[0]
48
- else:
49
- entry = result
50
- audio = entry.get("audio") if isinstance(entry, dict) else None
51
- sr = entry.get("sampling_rate", 16000) if isinstance(entry, dict) else 16000
52
-
53
- if audio is None:
54
- # some pipeline versions return numpy array directly
55
- audio = result if isinstance(result, np.ndarray) else None
56
 
57
- if audio is None:
58
- raise RuntimeError("TTS pipeline returned no audio.")
59
 
60
- # ensure numpy array
61
- audio_np = np.asarray(audio)
62
- sf.write(out_path, audio_np, sr)
 
 
63
  return out_path
64
 
65
- # ----------------------------
66
- # 4) Fuzzy replacement (robust)
67
- # ----------------------------
68
- custom_vocab_map = {
69
- "Ω†Ψ±Ψ―": ["Ω†Ψ±Ψ―", "نِرد", "Ω†ΩŽΨ±Ψ―"],
70
- "Ϊ©Ψ§Ω…ΩΎΫŒΩˆΨͺΨ±": ["Ϊ©Ψ§Ω…ΩΎΫŒΩˆΨͺΨ±", "Ϊ©Ψ§Ω…ΩΎΫŒΩˆΨͺΨ±Ω‡"],
71
- "Ω‡ΩˆΨ΄ Ω…Ψ΅Ω†ΩˆΨΉΫŒ": ["Ω‡ΩˆΨ΄ Ω…Ψ΅Ω†ΩˆΨΉΫŒ", "Ω‡ΩˆΨ΄ Ψ΅Ω†ΨΉΨͺی"],
72
- "Ω…Ψ§Ψ΄ΫŒΩ†": ["Ω…Ψ§Ψ΄ΫŒΩ†", "Ω…Ψ§Ψ΄ΫŒΩ†Ω‡"],
73
- }
74
-
75
- def replace_fuzzy(text: str, vocab_map: dict, threshold: int = 85) -> str:
76
- """
77
- Replace near-matches in `text` with canonical targets from vocab_map.
78
- Handles rapidfuzz.extractOne return types (object or tuple).
79
- """
80
- if not text:
81
- return text
82
- for target, alternatives in vocab_map.items():
83
- try:
84
- res = process.extractOne(text, alternatives, scorer=fuzz.partial_ratio)
85
- except Exception:
86
- res = None
87
- if not res:
88
- continue
89
- # res may be an Extracted object or tuple
90
- if hasattr(res, "value") and hasattr(res, "score"):
91
- match = res.value
92
- score = res.score
93
- else:
94
- # tuple like (match, score, idx) or (match, score)
95
- match = res[0]
96
- score = res[1] if len(res) > 1 else 0
97
- if score >= threshold:
98
- # replace only the first occurrence to avoid accidental global replacement
99
- text = text.replace(match, target, 1)
100
- return text
101
-
102
- # ----------------------------
103
- # 5) Full pipeline function
104
- # ----------------------------
105
- def full_pipeline(audio_file: str):
106
- """
107
- audio_file is a filepath (Gradio with type='filepath' sends a path for mic/upload).
108
- Returns (text_output_str, path_to_tts_wav or None).
109
- """
110
  if not audio_file:
111
  return "No audio input detected.", None
112
 
113
- # 1) ASR
114
  try:
115
- asr_result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
116
  except Exception as e:
117
  return f"ASR error: {e}", None
118
 
119
- raw_text = asr_result.get("text", "")
120
- if raw_text is None:
121
- raw_text = ""
122
-
123
- # 2) fuzzy replacement
124
- corrected_text = replace_fuzzy(raw_text, custom_vocab_map, threshold=85)
125
 
126
- # 3) LLM reply
127
  try:
128
- llm_reply = ask_llm(corrected_text)
129
  except Exception as e:
130
- llm_reply = f"LLM error: {e}"
131
 
132
- # 4) TTS (synthesize LLM reply)
133
  try:
134
- audio_out = text_to_speech_save(llm_reply, out_path="response.wav")
135
  except Exception as e:
136
- return f"User said: {corrected_text}\nAssistant generation error: {e}", None
137
 
138
- convo = f"User said: {corrected_text}\nAssistant: {llm_reply}"
139
- return convo, audio_out
140
 
141
- # ----------------------------
142
- # 6) Gradio UI
143
- # ----------------------------
144
  iface = gr.Interface(
145
  fn=full_pipeline,
146
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
147
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
148
  title="Persian Voice Assistant",
149
- description="ASR β†’ LLM β†’ TTS (offline-ready pipelines).",
150
  )
151
 
152
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, SpeechT5Processor, SpeechT5ForTextToSpeech
 
 
 
3
  import torch
4
+ import soundfile as sf
5
 
6
+ # --------------------------
7
+ # 1. ASR (speech to text)
8
+ # --------------------------
9
  asr = pipeline(
10
+ task="automatic-speech-recognition",
11
  model="vhdm/whisper-large-fa-v1",
12
+ device=-1
13
  )
14
 
15
+ # --------------------------
16
+ # 2. Language Model (LLM)
17
+ # --------------------------
18
+ llm_model_id = "tiiuae/falcon-rw-1b"
19
  tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
20
  llm_model = AutoModelForCausalLM.from_pretrained(
21
  llm_model_id,
22
+ torch_dtype=torch.float32
23
  ).to("cpu")
24
 
25
+ def ask_llm(prompt, max_new_tokens=200):
26
  inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
27
  with torch.no_grad():
28
  outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)
29
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
30
 
31
+ # --------------------------
32
+ # 3. TTS (text-to-speech) using SpeechT5
33
+ # --------------------------
34
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
35
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Random speaker embedding (can be replaced with a fixed one for consistency)
38
+ speaker_embedding = torch.randn(1, 512)
39
 
40
+ def text_to_speech(text, out_path="output.wav"):
41
+ inputs = processor(text=text, return_tensors="pt")
42
+ with torch.no_grad():
43
+ speech = tts_model.generate_speech(inputs["input_ids"], speaker_embedding)
44
+ sf.write(out_path, speech.numpy(), 16000)
45
  return out_path
46
 
47
+ # --------------------------
48
+ # 4. Full pipeline function
49
+ # --------------------------
50
+ def full_pipeline(audio_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if not audio_file:
52
  return "No audio input detected.", None
53
 
 
54
  try:
55
+ result = asr(audio_file, chunk_length_s=30, stride_length_s=[5, 5])
56
  except Exception as e:
57
  return f"ASR error: {e}", None
58
 
59
+ user_text = result.get("text", "")
 
 
 
 
 
60
 
 
61
  try:
62
+ llm_response = ask_llm(user_text)
63
  except Exception as e:
64
+ return f"Assistant generation error: {e}", None
65
 
 
66
  try:
67
+ audio_path = text_to_speech(llm_response, "response.wav")
68
  except Exception as e:
69
+ return f"TTS error: {e}", None
70
 
71
+ return f"User said: {user_text}\nAssistant: {llm_response}", audio_path
 
72
 
73
+ # --------------------------
74
+ # 5. Gradio Interface
75
+ # --------------------------
76
  iface = gr.Interface(
77
  fn=full_pipeline,
78
  inputs=gr.Audio(type="filepath", label="Record or upload audio"),
79
  outputs=[gr.Textbox(label="Conversation"), gr.Audio(label="TTS Response")],
80
  title="Persian Voice Assistant",
81
+ description="ASR β†’ LLM β†’ TTS"
82
  )
83
 
84
  if __name__ == "__main__":