aboalaa147 committed on
Commit
b2cde20
·
verified ·
1 Parent(s): 30f4a9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -47
app.py CHANGED
@@ -1,64 +1,183 @@
1
- import os
 
2
  import torch
3
- import librosa
4
  import soundfile as sf
5
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
6
- import gradio as gr
 
 
 
 
 
 
 
7
 
8
- MODEL_ID = "xLeonSTES/quran-to-text-base"
9
- SAMPLE_RATE = 16000
10
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- @torch.no_grad()
13
- def load_model():
14
- processor = AutoProcessor.from_pretrained(MODEL_ID)
15
- model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)
16
- model.to(DEVICE)
17
- model.eval()
18
- return processor, model
19
 
20
- processor, model = load_model()
 
 
 
 
 
 
 
21
 
22
- def resample_to_16k(path):
 
 
 
 
 
 
23
  audio, sr = sf.read(path)
24
- if audio.ndim > 1:
25
  audio = audio.mean(axis=1)
26
- if sr != SAMPLE_RATE:
27
- audio = librosa.resample(audio.astype('float32'), orig_sr=sr, target_sr=SAMPLE_RATE)
28
- return audio, SAMPLE_RATE
29
-
30
- def transcribe_audio(path):
31
- audio, sr = resample_to_16k(path)
32
- audio = audio / (max(abs(audio)) + 1e-9)
33
- inputs = processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
34
- input_features = inputs.input_features.to(DEVICE)
35
-
36
- with torch.no_grad():
37
- generated_ids = model.generate(input_features)
38
-
39
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
40
- return text
41
-
42
- def run(uploaded_audio, mic_audio):
43
- path = mic_audio or uploaded_audio
44
- if not path:
45
- return "No audio provided"
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  try:
47
- return transcribe_audio(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
49
- return f"Error: {e}"
50
 
51
- with gr.Blocks(title="Quran ASR") as demo:
52
- gr.Markdown("# Quran ASR — Diacritized Transcription\nUpload or record audio, then press Convert.")
 
 
 
53
 
54
  with gr.Row():
55
  with gr.Column():
56
- upload = gr.Audio(type="filepath", label="Upload Audio")
57
- mic = gr.Audio(type="filepath", label="Microphone Recording")
58
- btn = gr.Button("Convert")
 
 
 
59
  with gr.Column():
60
- out = gr.Textbox(label="Output Text", lines=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- btn.click(run, inputs=[upload, mic], outputs=[out])
 
 
 
 
63
 
64
- demo.launch()
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
  import torch
 
4
  import soundfile as sf
5
+ import librosa
6
+ from matplotlib import pyplot as plt
7
+ from transformers import AutoFeatureExtractor, AutoModelForAudioFrameClassification
8
+ from recitations_segmenter import segment_recitations, clean_speech_intervals
9
+ import io
10
+ from PIL import Image
11
+ import tempfile
12
+ import os
13
+ import zipfile
14
 
15
+ # 🔹 ASR client to connect to Space B
16
+ from gradio_client import Client, handle_file
 
17
 
18
# ======================
# Setup device and model
# ======================
# Prefer GPU when available; bfloat16 on CUDA halves memory, float32 on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Loading segmentation model on {device}...")
# Frame-level audio classifier that marks speech vs. silence frames; used by
# segment_recitations() below to split a recitation into speech intervals.
processor = AutoFeatureExtractor.from_pretrained("obadx/recitation-segmenter-v2")
model = AutoModelForAudioFrameClassification.from_pretrained(
    "obadx/recitation-segmenter-v2",
    torch_dtype=dtype,
    device_map=device
)
print("Segmentation model loaded successfully!")

# 🔹 ASR Space (Space B)
# Remote Gradio Space that transcribes each segment.
# If Space B is private, pass hf_token="HF_xxx" to Client().
asr_client = Client("aboalaa1472/Quran_ASR")
35
+
36
# ======================
# Utils
# ======================
def read_audio(path, sampling_rate=16000):
    """Load an audio file as a mono float32 torch tensor at `sampling_rate` Hz."""
    samples, native_sr = sf.read(path)
    # Collapse multi-channel recordings to mono by averaging the channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    # Resample only when the file's native rate differs from the target rate.
    if native_sr != sampling_rate:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=sampling_rate)
    return torch.tensor(samples).float()
46
+
47
def get_interval(x, intervals, idx, sr=16000):
    """Return the slice of samples `x` covered by the idx-th (start_s, end_s) interval."""
    begin_s, finish_s = intervals[idx]
    # Interval bounds are in seconds; convert to sample indices before slicing.
    return x[int(begin_s * sr):int(finish_s * sr)]
51
+
52
def plot_signal(x, intervals, sr=16000):
    """Render the waveform of `x` with red vertical lines at each speech-interval
    boundary, and return the plot as a PIL Image.

    Args:
        x: 1-D waveform (torch.Tensor or numpy array).
        intervals: iterable of (start_s, end_s) pairs in seconds.
        sr: sample rate used to convert seconds into sample indices.
    """
    fig, ax = plt.subplots(figsize=(20, 4))
    if isinstance(x, torch.Tensor):
        x = x.numpy()
    ax.plot(x, linewidth=0.5)
    for s, e in intervals:
        ax.axvline(x=s * sr, color='red', alpha=0.4)
        ax.axvline(x=e * sr, color='red', alpha=0.4)
    # Use the figure's own methods instead of pyplot's implicit "current figure"
    # state: Gradio may run handlers on worker threads, and another thread could
    # change the current figure between plt.* calls, making the original
    # plt.tight_layout()/plt.savefig()/plt.close() target the wrong figure.
    fig.tight_layout()

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # close this specific figure to avoid pyplot memory build-up
    return img
68
+
69
# ======================
# Main processing
# ======================
def process_audio(audio_file, min_silence_ms, min_speech_ms, pad_ms):
    """Segment a recitation into speech intervals, transcribe each segment via
    the remote ASR Space, and package the per-segment WAVs into a ZIP.

    Args:
        audio_file: filepath from the Gradio Audio input, or None.
        min_silence_ms / min_speech_ms / pad_ms: interval-cleaning thresholds
            forwarded to clean_speech_intervals().

    Returns:
        (plot_image, result_text, zip_path, segment_file_paths);
        on failure or missing input: (None, message, None, []).
    """
    if audio_file is None:
        return None, "⚠️ ارفع ملف صوتي", None, []

    try:
        wav = read_audio(audio_file)

        # Frame-level speech classification over the whole waveform.
        sampled_outputs = segment_recitations(
            [wav],
            model,
            processor,
            device=device,
            dtype=dtype,
            batch_size=4,
        )

        # Merge/filter the raw intervals using the user-supplied thresholds;
        # return_seconds=True so intervals are (start_s, end_s) in seconds.
        clean_out = clean_speech_intervals(
            sampled_outputs[0].speech_intervals,
            sampled_outputs[0].is_complete,
            min_silence_duration_ms=min_silence_ms,
            min_speech_duration_ms=min_speech_ms,
            pad_duration_ms=pad_ms,
            return_seconds=True,
        )

        intervals = clean_out.clean_speech_intervals
        plot_img = plot_signal(wav, intervals)

        # NOTE(review): this temp dir is never removed — Gradio needs the files
        # to outlive the call, but they accumulate across requests.
        temp_dir = tempfile.mkdtemp()
        segment_files = []
        full_asr_text = []

        result_text = f"✅ عدد المقاطع: {len(intervals)}\n\n"

        for i in range(len(intervals)):
            seg = get_interval(wav, intervals, i)
            if isinstance(seg, torch.Tensor):
                seg = seg.cpu().numpy()

            # Write each segment as a 16 kHz WAV (read_audio resampled to 16 kHz).
            seg_path = os.path.join(temp_dir, f"segment_{i+1:03d}.wav")
            sf.write(seg_path, seg, 16000)
            segment_files.append(seg_path)

            # 🔹 ASR call to Space B. The same file is passed for both inputs;
            # the previous app.py version shows /run uses `mic_audio or
            # uploaded_audio`, so either slot suffices — confirm against Space B.
            asr_text = asr_client.predict(
                uploaded_audio=handle_file(seg_path),
                mic_audio=handle_file(seg_path),
                api_name="/run"
            )

            full_asr_text.append(asr_text)

            result_text += (
                f"🎵 مقطع {i+1} "
                f"({intervals[i][0]:.2f}s → {intervals[i][1]:.2f}s)\n"
                f"📜 {asr_text}\n\n"
            )

        result_text += "\n🧾 النص الكامل:\n"
        result_text += " ".join(full_asr_text)

        # ZIP all segment WAVs for a single-file download.
        zip_path = os.path.join(temp_dir, "segments.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for f in segment_files:
                zipf.write(f, os.path.basename(f))

        return plot_img, result_text, zip_path, segment_files

    except Exception as e:
        # Surface any failure as text in the UI rather than crashing the handler.
        return None, f" خطأ: {str(e)}", None, []
143
 
144
# ======================
# Gradio UI
# ======================
# Upper bound on per-segment audio players shown in the UI. Segments beyond
# this limit are still included in the ZIP but get no individual widget.
# (Was the magic number 50 repeated in three places.)
MAX_SEGMENTS = 50

with gr.Blocks(title="Quran Segmentation + ASR") as demo:
    gr.Markdown("## 🕌 تقطيع التلاوات + ASR (Quran Text)")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="📤 ارفع التلاوة")
            min_silence = gr.Slider(10, 500, 30, step=10, label="Min Silence (ms)")
            min_speech = gr.Slider(10, 500, 30, step=10, label="Min Speech (ms)")
            padding = gr.Slider(0, 200, 30, step=10, label="Padding (ms)")
            btn = gr.Button("🚀 ابدأ")

        with gr.Column():
            plot_out = gr.Image(label="📈 الإشارة")
            text_out = gr.Textbox(lines=20, label="📜 النص")

    zip_out = gr.File(label="📦 تحميل المقاطع")

    # Pre-created hidden players; process_and_show toggles their visibility.
    segment_outputs = [gr.Audio(visible=False) for _ in range(MAX_SEGMENTS)]

    def process_and_show(audio, ms, sp, pad):
        """Run the pipeline and fan results out to the fixed-size widget list."""
        plot, text, zipf, segments = process_audio(audio, ms, sp, pad)
        outputs = [plot, text, zipf]
        for i in range(MAX_SEGMENTS):
            if i < len(segments):
                outputs.append(gr.Audio(value=segments[i], visible=True))
            else:
                outputs.append(gr.Audio(visible=False))
        return outputs

    btn.click(
        process_and_show,
        inputs=[audio_input, min_silence, min_speech, padding],
        outputs=[plot_out, text_out, zip_out] + segment_outputs
    )

if __name__ == "__main__":
    demo.launch()