aboalaa147 committed · Commit 7505690 · verified · 1 Parent(s): de5ff91

Update app.py

Files changed (1)
  1. app.py +47 -166
app.py CHANGED
@@ -1,183 +1,64 @@
- import gradio as gr
- import numpy as np
  import torch
- import soundfile as sf
  import librosa
- from matplotlib import pyplot as plt
- from transformers import AutoFeatureExtractor, AutoModelForAudioFrameClassification
- from recitations_segmenter import segment_recitations, clean_speech_intervals
- import io
- from PIL import Image
- import tempfile
- import os
- import zipfile
-
- # 🔹 ASR client to connect to Space B
- from gradio_client import Client, handle_file

- # ======================
- # Setup device and model
- # ======================
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

- print(f"Loading segmentation model on {device}...")
- processor = AutoFeatureExtractor.from_pretrained("obadx/recitation-segmenter-v2")
- model = AutoModelForAudioFrameClassification.from_pretrained(
-     "obadx/recitation-segmenter-v2",
-     torch_dtype=dtype,
-     device_map=device
- )
- print("Segmentation model loaded successfully!")

- # 🔹 ASR Space (Space B)
- asr_client = Client("aboalaa1472/Quran_ASR")  # If Space B is private, pass hf_token="HF_xxx"

- # ======================
- # Utils
- # ======================
- def read_audio(path, sampling_rate=16000):
      audio, sr = sf.read(path)
-     if len(audio.shape) > 1:
          audio = audio.mean(axis=1)
-     if sr != sampling_rate:
-         audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
-     return torch.tensor(audio).float()
-
- def get_interval(x, intervals, idx, sr=16000):
-     start = int(intervals[idx][0] * sr)
-     end = int(intervals[idx][1] * sr)
-     return x[start:end]
-
- def plot_signal(x, intervals, sr=16000):
-     fig, ax = plt.subplots(figsize=(20, 4))
-     if isinstance(x, torch.Tensor):
-         x = x.numpy()
-     ax.plot(x, linewidth=0.5)
-     for s, e in intervals:
-         ax.axvline(x=s * sr, color='red', alpha=0.4)
-         ax.axvline(x=e * sr, color='red', alpha=0.4)
-     plt.tight_layout()
-
-     buf = io.BytesIO()
-     plt.savefig(buf, format="png")
-     buf.seek(0)
-     img = Image.open(buf)
-     plt.close()
-     return img
-
- # ======================
- # Main processing
- # ======================
- def process_audio(audio_file, min_silence_ms, min_speech_ms, pad_ms):
-     if audio_file is None:
-         return None, "⚠️ ارفع ملف صوتي", None, []
-
      try:
-         wav = read_audio(audio_file)
-
-         sampled_outputs = segment_recitations(
-             [wav],
-             model,
-             processor,
-             device=device,
-             dtype=dtype,
-             batch_size=4,
-         )
-
-         clean_out = clean_speech_intervals(
-             sampled_outputs[0].speech_intervals,
-             sampled_outputs[0].is_complete,
-             min_silence_duration_ms=min_silence_ms,
-             min_speech_duration_ms=min_speech_ms,
-             pad_duration_ms=pad_ms,
-             return_seconds=True,
-         )
-
-         intervals = clean_out.clean_speech_intervals
-         plot_img = plot_signal(wav, intervals)
-
-         temp_dir = tempfile.mkdtemp()
-         segment_files = []
-         full_asr_text = []
-
-         result_text = f"✅ عدد المقاطع: {len(intervals)}\n\n"
-
-         for i in range(len(intervals)):
-             seg = get_interval(wav, intervals, i)
-             if isinstance(seg, torch.Tensor):
-                 seg = seg.cpu().numpy()
-
-             seg_path = os.path.join(temp_dir, f"segment_{i+1:03d}.wav")
-             sf.write(seg_path, seg, 16000)
-             segment_files.append(seg_path)
-
-             # 🔹 ASR call to Space B
-             asr_text = asr_client.predict(
-                 uploaded_audio=handle_file(seg_path),
-                 mic_audio=handle_file(seg_path),
-                 api_name="/run"
-             )
-
-             full_asr_text.append(asr_text)
-
-             result_text += (
-                 f"🎵 مقطع {i+1} "
-                 f"({intervals[i][0]:.2f}s → {intervals[i][1]:.2f}s)\n"
-                 f"📜 {asr_text}\n\n"
-             )
-
-         result_text += "\n🧾 النص الكامل:\n"
-         result_text += " ".join(full_asr_text)
-
-         # ZIP
-         zip_path = os.path.join(temp_dir, "segments.zip")
-         with zipfile.ZipFile(zip_path, 'w') as zipf:
-             for f in segment_files:
-                 zipf.write(f, os.path.basename(f))
-
-         return plot_img, result_text, zip_path, segment_files
-
      except Exception as e:
-         return None, f"خطأ: {str(e)}", None, []

- # ======================
- # Gradio UI
- # ======================
- with gr.Blocks(title="Quran Segmentation + ASR") as demo:
-     gr.Markdown("## 🕌 تقطيع التلاوات + ASR (Quran Text)")

      with gr.Row():
          with gr.Column():
-             audio_input = gr.Audio(type="filepath", label="📤 ارفع التلاوة")
-             min_silence = gr.Slider(10, 500, 30, step=10, label="Min Silence (ms)")
-             min_speech = gr.Slider(10, 500, 30, step=10, label="Min Speech (ms)")
-             padding = gr.Slider(0, 200, 30, step=10, label="Padding (ms)")
-             btn = gr.Button("🚀 ابدأ")
-
          with gr.Column():
-             plot_out = gr.Image(label="📈 الإشارة")
-             text_out = gr.Textbox(lines=20, label="📜 النص")
-
-     zip_out = gr.File(label="📦 تحميل المقاطع")
-
-     segment_outputs = [gr.Audio(visible=False) for _ in range(50)]
-
-     def process_and_show(audio, ms, sp, pad):
-         plot, text, zipf, segments = process_audio(audio, ms, sp, pad)
-         outputs = [plot, text, zipf]
-         for i in range(50):
-             if i < len(segments):
-                 outputs.append(gr.Audio(value=segments[i], visible=True))
-             else:
-                 outputs.append(gr.Audio(visible=False))
-         return outputs

-     btn.click(
-         process_and_show,
-         inputs=[audio_input, min_silence, min_speech, padding],
-         outputs=[plot_out, text_out, zip_out] + segment_outputs
-     )

- if __name__ == "__main__":
-     demo.launch()
 
+ import os
  import torch
  import librosa
+ import soundfile as sf
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+ import gradio as gr

+ MODEL_ID = "xLeonSTES/quran-to-text-base"
+ SAMPLE_RATE = 16000
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+ @torch.no_grad()
+ def load_model():
+     processor = AutoProcessor.from_pretrained(MODEL_ID)
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)
+     model.to(DEVICE)
+     model.eval()
+     return processor, model

+ processor, model = load_model()

+ def resample_to_16k(path):
      audio, sr = sf.read(path)
+     if audio.ndim > 1:
          audio = audio.mean(axis=1)
+     if sr != SAMPLE_RATE:
+         audio = librosa.resample(audio.astype('float32'), orig_sr=sr, target_sr=SAMPLE_RATE)
+     return audio, SAMPLE_RATE
+
+ def transcribe_audio(path):
+     audio, sr = resample_to_16k(path)
+     audio = audio / (max(abs(audio)) + 1e-9)
+     inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
+     input_features = inputs.input_features.to(DEVICE)
+
+     with torch.no_grad():
+         generated_ids = model.generate(input_features)
+
+     text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     return text
+
+ def run(uploaded_audio, mic_audio):
+     path = mic_audio or uploaded_audio
+     if not path:
+         return "No audio provided"
      try:
+         return transcribe_audio(path)
      except Exception as e:
+         return f"Error: {e}"

+ with gr.Blocks(title="Quran ASR") as demo:
+     gr.Markdown("# Quran ASR — Diacritized Transcription\nUpload or record audio, then press Convert.")

      with gr.Row():
          with gr.Column():
+             upload = gr.Audio(type="filepath", label="Upload Audio")
+             mic = gr.Audio(type="filepath", label="Microphone Recording")
+             btn = gr.Button("Convert")
          with gr.Column():
+             out = gr.Textbox(label="Output Text", lines=10)

+     btn.click(run, inputs=[upload, mic], outputs=[out])

+ demo.launch()
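
For quick testing, here is a minimal sketch of calling the new /run endpoint from another process with gradio_client, mirroring the call pattern the removed asr_client code used. The Space id is an assumption carried over from the removed client line, and recitation.wav is a hypothetical local file:

    from gradio_client import Client, handle_file

    # Assumed Space id, copied from the removed asr_client line; adjust to the actual Space.
    client = Client("aboalaa1472/Quran_ASR")

    text = client.predict(
        uploaded_audio=handle_file("recitation.wav"),  # hypothetical local recording
        mic_audio=None,  # run() falls back to uploaded_audio when the mic input is empty
        api_name="/run",
    )
    print(text)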