import gradio as gr
import numpy as np
import torch
import soundfile as sf
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioFrameClassification
from recitations_segmenter import segment_recitations, clean_speech_intervals
import io
from PIL import Image
import tempfile
import os
import zipfile
# 🔹 ASR client
from gradio_client import Client, handle_file
# 🔹 Arabic Aligner
from arabic_aligner import ArabicAligner  # module containing the aligner code shared earlier
# ======================
# Setup device and model
# ======================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
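# bfloat16 roughly halves memory/bandwidth on GPU; on CPU we fall back to
# float32, where bfloat16 kernels are typically slow or unsupported.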
print(f"Loading model on {device}...")
processor = AutoFeatureExtractor.from_pretrained("obadx/recitation-segmenter-v2")
model = AutoModelForAudioFrameClassification.from_pretrained(
    "obadx/recitation-segmenter-v2",
    torch_dtype=dtype,
    device_map=device,
)
print("Model loaded successfully!")
# 🔹 ASR Space
asr_client = Client("aboalaa1472/Quran_ASR")
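# gradio_client calls the hosted Space over its HTTP API; handle_file() (used
# in the per-segment ASR call below) wraps a local path so the audio file is
# uploaded with each predict() request.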
# ======================
# Utils
# ======================
def read_audio(path, sampling_rate=16000):
    # Load audio, downmix multi-channel input to mono, and resample to the
    # 16 kHz rate the segmentation model expects.
    audio, sr = sf.read(path)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != sampling_rate:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
    return torch.tensor(audio).float()
def get_interval(x, intervals, idx, sr=16000):
    # Slice one speech interval (start/end given in seconds) out of the waveform.
    start = int(intervals[idx][0] * sr)
    end = int(intervals[idx][1] * sr)
    return x[start:end]
def plot_signal(x, intervals, sr=16000):
    # Plot the waveform with red lines marking interval boundaries and return
    # the figure as a PIL image for the Gradio Image output.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(20, 4))
    if isinstance(x, torch.Tensor):
        x = x.numpy()
    ax.plot(x, linewidth=0.5)
    for s, e in intervals:
        ax.axvline(x=s * sr, color='red', alpha=0.4)
        ax.axvline(x=e * sr, color='red', alpha=0.4)
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
    return img
# ======================
# Main processing
# ======================
def process_audio_and_compare(audio_file, reference_text, min_silence_ms, min_speech_ms, pad_ms):
    if audio_file is None:
        return None, "⚠️ Please upload an audio file first", None
    try:
        wav = read_audio(audio_file)
        # Frame-level segmentation: predict speech/silence over the recitation.
        sampled_outputs = segment_recitations(
            [wav],
            model,
            processor,
            device=device,
            dtype=dtype,
            batch_size=4,
        )
        # Merge and filter the raw intervals using the user-selected thresholds.
        clean_out = clean_speech_intervals(
            sampled_outputs[0].speech_intervals,
            sampled_outputs[0].is_complete,
            min_silence_duration_ms=min_silence_ms,
            min_speech_duration_ms=min_speech_ms,
            pad_duration_ms=pad_ms,
            return_seconds=True,
        )
        intervals = clean_out.clean_speech_intervals
        plot_img = plot_signal(wav, intervals)
        temp_dir = tempfile.mkdtemp()
        segment_files = []
        full_asr_text = []
        result_text = f"✅ Number of segments: {len(intervals)}\n\n"
        for i in range(len(intervals)):
            seg = get_interval(wav, intervals, i)
            if isinstance(seg, torch.Tensor):
                seg = seg.cpu().numpy()
            seg_path = os.path.join(temp_dir, f"segment_{i+1:03d}.wav")
            sf.write(seg_path, seg, 16000)
            segment_files.append(seg_path)
            # 🔹 ASR CALL: the remote /run endpoint takes both an upload and a
            # mic input, so the same segment file is passed for both.
            asr_text = asr_client.predict(
                uploaded_audio=handle_file(seg_path),
                mic_audio=handle_file(seg_path),
                api_name="/run"
            )
            full_asr_text.append(asr_text)
            result_text += f"🎵 Segment {i+1} ({intervals[i][0]:.2f}s → {intervals[i][1]:.2f}s)\n📜 {asr_text}\n\n"
        full_asr_text_str = " ".join(full_asr_text)
        result_text += f"\n🧾 Full text:\n{full_asr_text_str}\n\n"
        # 🔹 ArabicAligner comparison
        aligner = ArabicAligner()
        align_results = aligner.align_and_compare(full_asr_text_str, reference_text)
        stats = align_results['statistics']
        result_text += (
            f"📊 Comparison statistics:\n"
            f"- Total reference words: {stats['total_reference_words']}\n"
            f"- Total ASR words: {stats['total_user_words']}\n"
            f"- Total errors: {stats['total_errors']}\n"
            f"  - Word-level errors: {stats['word_level_errors']}\n"
            f"  - Diacritic errors: {stats['diacritic_errors']}\n"
            f"- Accuracy: {stats['accuracy']:.2f}%\n\n"
            f"✏️ Error details:\n"
        )
        for i, error in enumerate(align_results['errors'], 1):
            result_text += (
                f"[{i}] Type: {error.error_type.value.upper()} | "
                f"User: '{error.user_word}' | Expected: '{error.reference_word}' | "
                f"Details: {error.details}\n"
            )
        # Bundle the segment WAVs into one downloadable ZIP.
        zip_path = os.path.join(temp_dir, "segments.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for f in segment_files:
                zipf.write(f, os.path.basename(f))
        return plot_img, result_text, zip_path
    except Exception as e:
        return None, f"❌ Error: {e}", None
# ======================
# Gradio UI
# ======================
with gr.Blocks(title="Quran Segmentation + ASR + Comparison") as demo:
    gr.Markdown("## 🕌 Recitation segmentation + Quranic ASR + comparison against the diacritized text")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="📤 Upload a recitation")
            reference_text_input = gr.Textbox(label="📖 Enter the diacritized Quran text for comparison", lines=10)
            min_silence = gr.Slider(10, 500, 30, step=10, label="Min Silence (ms)")
            min_speech = gr.Slider(10, 500, 30, step=10, label="Min Speech (ms)")
            padding = gr.Slider(0, 200, 30, step=10, label="Padding (ms)")
            btn = gr.Button("🚀 Start")
        with gr.Column():
            plot_out = gr.Image(label="📈 Signal")
            text_out = gr.Textbox(lines=30, label="📜 Results")
            zip_out = gr.File(label="📦 Download segments")
    btn.click(
        fn=process_audio_and_compare,
        inputs=[audio_input, reference_text_input, min_silence, min_speech, padding],
        outputs=[plot_out, text_out, zip_out]
    )

if __name__ == "__main__":
    demo.launch()
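
# ----------------------------------------------------------------------
# A minimal client-side sketch for calling this Space remotely, mirroring the
# asr_client pattern above. "username/space-name" is a hypothetical placeholder
# ID, and the endpoint name should be confirmed with client.view_api() (Gradio
# names endpoints after the bound function by default):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("username/space-name")  # hypothetical Space ID
#   plot_img, report, segments_zip = client.predict(
#       handle_file("recitation.wav"),  # audio_file
#       "...",                          # reference_text (diacritized)
#       30, 30, 30,                     # min_silence_ms, min_speech_ms, pad_ms
#       api_name="/process_audio_and_compare",
#   )
# ----------------------------------------------------------------------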