# row_test_audio / app.py
# TD-jayadeera's picture
# Update app.py
# 410fdf8 verified
import html
import math
import os
import shutil
from datetime import datetime
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import HfApi, hf_hub_download
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2FeatureExtractor
# ==========================================
# 1. මොඩලයේ ව්‍යුහය (Architecture)
# ==========================================
class SelfAttentionPooling(nn.Module):
    """Collapse a sequence of frame features into one vector via learned attention.

    A small two-layer scorer (``W`` then ``V``) gives every frame a scalar
    score; the scores are softmax-normalised over the frame axis and used as
    weights for a sum over the frames.

    Args:
        input_dim: Size of the last dimension of the features given to
            ``forward``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The names ``W`` and ``V`` appear in the saved state_dict keys of any
        # model that embeds this module, so they must not be renamed.
        self.W = nn.Linear(input_dim, 128)
        self.V = nn.Linear(128, 1)

    def forward(self, x, attention_mask=None):
        """Attention-pool ``x`` over its second axis.

        Args:
            x: Frame features, assumed (batch, frames, input_dim) — dim 1 is
                treated as the time axis.
            attention_mask: Optional mask whose dim 1 may be longer than
                ``x``'s (presumably the sample-level mask from the feature
                extractor). It is resampled to one entry per frame by taking
                evenly spaced, truncated indices; frames whose entry is 0 get
                (practically) zero weight.

        Returns:
            ``(pooled, weights)``: the weighted sum over dim 1 and the
            attention weights (last dimension 1), which sum to 1 over dim 1.
        """
        frame_scores = self.V(self.W(x).tanh())
        if attention_mask is not None:
            # Pick one mask entry per output frame from the (longer) mask.
            frame_positions = torch.linspace(
                0, attention_mask.size(1) - 1, steps=x.size(1)
            ).long().to(x.device)
            frame_mask = attention_mask.index_select(1, frame_positions).unsqueeze(-1)
            # A large negative value (instead of -inf) keeps the softmax finite
            # even if every frame is masked out.
            frame_scores = frame_scores.masked_fill(frame_mask == 0, -1e4)
        weights = frame_scores.softmax(dim=1)
        pooled = (x * weights).sum(dim=1)
        return pooled, weights
class SinhalaPhonoNet(nn.Module):
    """Pronunciation embedding/classification network on a wav2vec2 backbone.

    Pipeline: hidden states from every backbone layer -> learned softmax
    mixture over layers -> attention pooling over time -> MLP projection to
    ``embedding_dim`` -> L2 normalisation -> linear classifier.

    NOTE: the attribute names and the order of layers inside ``fc`` define the
    state_dict keys that the fine-tuned checkpoint is loaded into; keep them
    stable.

    Args:
        base_model: Hugging Face model id whose config and pretrained weights
            initialise the backbone (both fetched with ``from_pretrained`` in
            the constructor).
        embedding_dim: Size of the projected embedding.
        num_classes: Number of classifier outputs.
    """

    # num_classes is set to 255.
    def __init__(self, base_model="facebook/wav2vec2-xls-r-300m", embedding_dim=256, num_classes=255):
        super(SinhalaPhonoNet, self).__init__()
        # output_hidden_states=True so forward() can mix all layers, not only the last.
        self.config = Wav2Vec2Config.from_pretrained(base_model, output_hidden_states=True)
        self.backbone = Wav2Vec2Model.from_pretrained(base_model, config=self.config)
        # One learnable weight per hidden-state tensor; the +1 presumably
        # accounts for the embedding output that precedes the transformer layers.
        self.layer_weights = nn.Parameter(torch.ones(self.config.num_hidden_layers + 1))
        self.attention = SelfAttentionPooling(self.config.hidden_size)
        # Contains BatchNorm1d layers, so single-clip inference requires eval()
        # mode (training-mode batch norm needs more than one sample per batch).
        self.fc = nn.Sequential(
            nn.Linear(self.config.hidden_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim)
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Embed a batch of feature-extractor inputs.

        Args:
            input_values: Waveform tensor produced by the feature extractor
                (assumed (batch, samples) — TODO confirm).
            attention_mask: Optional mask from the feature extractor; passed
                to both the backbone and the attention pooling.

        Returns:
            ``(embeddings, norm_embeddings, logits)``: the projected
            embeddings, their L2-normalised version, and classifier logits
            computed from the normalised embeddings.
        """
        outputs = self.backbone(input_values=input_values, attention_mask=attention_mask)
        # Stack all hidden states along a new leading "layer" axis.
        stacked_hidden_states = torch.stack(outputs.hidden_states, dim=0)
        # Softmax makes the layer weights positive and sum to 1; view() lets
        # them broadcast over the remaining axes (assumed batch, time, hidden).
        weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        weighted_hidden_state = torch.sum(stacked_hidden_states * weights, dim=0)
        pooled, _ = self.attention(weighted_hidden_state, attention_mask)
        embeddings = self.fc(pooled)
        # Returns three values, matching the training-time model (per the author's note).
        norm_embeddings = F.normalize(embeddings, p=2, dim=1)
        logits = self.classifier(norm_embeddings)
        return embeddings, norm_embeddings, logits
# ==========================================
# 2. මොඩලයන් පූරණය කිරීම (Hugging Face)
# ==========================================
# All inference runs on the CPU.
DEVICE = torch.device("cpu")
# Hugging Face id of the pretrained backbone / feature extractor.
BASE_MODEL_NAME = "facebook/wav2vec2-xls-r-300m"
# Feature extractor that turns raw waveforms into model inputs; loaded with
# from_pretrained when this module is imported.
PROCESSOR = Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME)
# Folder of teacher reference *.wav files offered in the UI dropdown.
REFERENCE_AUDIO_DIR = Path(__file__).resolve().parent / "reference_audios"
# Local folder where GOOD/EXCELLENT student attempts are copied.
SAVED_STUDENT_AUDIO_DIR = Path(__file__).resolve().parent / "saved_student_audios"
# Characters replaced with "_" when building filenames (the characters
# Windows forbids in filenames, including both path separators).
UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'
# Optional Hub credentials and target dataset for persisting saved student
# audio; uploading is skipped unless both the token and the repo id are set.
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
STUDENT_AUDIO_DATASET_REPO_ID = os.getenv("STUDENT_AUDIO_DATASET_REPO_ID")
# Folder inside the dataset repo (surrounding slashes removed); an empty
# value means files go to the repo root.
STUDENT_AUDIO_UPLOAD_SUBDIR = os.getenv("STUDENT_AUDIO_UPLOAD_SUBDIR", "student_audios").strip("/")
def get_reference_audio_choices():
    """List the teacher reference recordings for the dropdown.

    Returns:
        A list of ``(label, value)`` pairs, one per ``*.wav`` file directly in
        ``REFERENCE_AUDIO_DIR``, ordered by file stem. The label is the stem
        and the value is the full path as a string. An empty list is returned
        when the folder does not exist.
    """
    reference_dir = REFERENCE_AUDIO_DIR
    if reference_dir.exists():
        wav_files = list(reference_dir.glob("*.wav"))
        wav_files.sort(key=lambda wav: wav.stem)
        return [(wav.stem, str(wav)) for wav in wav_files]
    return []
def safe_filename_part(value, fallback="audio"):
    """Turn *value* into a string that is safe to embed in a filename.

    ``str(value)`` is trimmed of surrounding whitespace. Then every character
    listed in ``UNSAFE_FILENAME_CHARS``, and every ASCII control character
    (U+0000-U+001F and U+007F), is replaced with ``_``. Finally, leading and
    trailing spaces and dots are stripped, so a value such as ``..`` collapses
    to the fallback. Other characters are kept as-is, including Sinhala text
    with its zero-width joiners.

    Args:
        value: Any object; its ``str()`` form is sanitised.
        fallback: Returned when nothing is left after cleaning.

    Returns:
        The sanitised, non-empty string.
    """

    def _is_unsafe(char):
        # Control characters (NUL in particular) make OS path calls fail with
        # "embedded null byte" or produce unreadable names.
        code = ord(char)
        return char in UNSAFE_FILENAME_CHARS or code < 0x20 or code == 0x7F

    cleaned = "".join("_" if _is_unsafe(char) else char for char in str(value).strip())
    cleaned = cleaned.strip(" .")
    return cleaned or fallback
def save_successful_student_audio(student_audio, teacher_audio, verdict, accuracy):
    """Copy a student's recording into ``SAVED_STUDENT_AUDIO_DIR`` under a descriptive name.

    The copy is named
    ``<YYYYmmdd_HHMMSS_microseconds>_<verdict lower-cased>_<accuracy rounded>_<teacher stem><ext>``.
    The extension is taken from the source file (``.wav`` if it has none).
    The copy is made with ``shutil.copy2``, which also copies file metadata.

    Args:
        student_audio: Path of the student's recording.
        teacher_audio: Path of the teacher reference; only its stem is used,
            after passing through ``safe_filename_part``.
        verdict: Verdict label, embedded lower-cased.
        accuracy: Numeric score, embedded with no decimals.

    Returns:
        The ``Path`` of the new copy.

    Raises:
        FileNotFoundError: If *student_audio* does not exist.
    """
    src = Path(student_audio)
    if not src.exists():
        raise FileNotFoundError(f"Student audio file not found: {src}")
    SAVED_STUDENT_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
    name_parts = [
        datetime.now().strftime("%Y%m%d_%H%M%S_%f"),
        verdict.lower(),
        f"{accuracy:.0f}",
        safe_filename_part(Path(teacher_audio).stem, "teacher"),
    ]
    extension = src.suffix if src.suffix else ".wav"
    destination = SAVED_STUDENT_AUDIO_DIR / ("_".join(name_parts) + extension)
    shutil.copy2(src, destination)
    return destination
def upload_student_audio_to_dataset(audio_path):
    """Upload a saved student recording to the configured Hugging Face dataset repo.

    Nothing happens unless both ``STUDENT_AUDIO_DATASET_REPO_ID`` and
    ``HF_TOKEN`` are set. The file keeps its local name and is placed under
    ``STUDENT_AUDIO_UPLOAD_SUBDIR`` (the repo root when that is empty).

    Args:
        audio_path: Path (``str`` or ``Path``) of the local file to upload.

    Returns:
        The web URL of the uploaded file on the Hub, or ``None`` when uploading
        is not configured.

    Raises:
        Whatever ``HfApi.create_repo`` / ``HfApi.upload_file`` raise (network,
        auth, permission errors); the caller reports these as warnings.
    """
    if not STUDENT_AUDIO_DATASET_REPO_ID or not HF_TOKEN:
        return None
    audio_path = Path(audio_path)
    path_in_repo = audio_path.name
    if STUDENT_AUDIO_UPLOAD_SUBDIR:
        path_in_repo = f"{STUDENT_AUDIO_UPLOAD_SUBDIR}/{audio_path.name}"
    api = HfApi(token=HF_TOKEN)
    # Only matters when the repo does not exist yet. Student voice recordings
    # are personal data, so an auto-created dataset is private by default; an
    # existing repo keeps its current visibility (exist_ok=True).
    api.create_repo(
        repo_id=STUDENT_AUDIO_DATASET_REPO_ID,
        repo_type="dataset",
        private=True,
        exist_ok=True,
    )
    api.upload_file(
        path_or_fileobj=str(audio_path),
        path_in_repo=path_in_repo,
        repo_id=STUDENT_AUDIO_DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=f"Upload {audio_path.name}",
    )
    quoted_path = "/".join(quote(part) for part in path_in_repo.split("/"))
    # Individual files are shown under /blob/<revision>/; /tree/ is the folder view.
    return f"https://huggingface.co/datasets/{STUDENT_AUDIO_DATASET_REPO_ID}/blob/main/{quoted_path}"
REPO_ID = "TD-jayadeera/model_255"
MODEL_FILENAME = "SinhalaPhonoNet_Final_Checkpoint_v4.pth"

# Bound to the model only after every loading step below succeeds. A failed
# download or state-dict mismatch therefore never leaves a half-initialised
# network (randomly initialised head, still in train mode) in use; on
# failure this stays None.
custom_model = None
try:
    print("⏳ Downloading & Loading Custom Model from Hugging Face...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    loaded_model = SinhalaPhonoNet(num_classes=255).to(DEVICE)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, i.e. execute code embedded in the downloaded file. This is
    # presumably tolerable only because REPO_ID appears to be the author's own
    # repo; switch to weights_only=True if the checkpoint holds only tensors
    # and plain containers.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    # The training checkpoint bundles the weights under 'model_state_dict';
    # a bare state dict is accepted as well.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    loaded_model.load_state_dict(state_dict)
    loaded_model.eval()
    custom_model = loaded_model
    print("✅ Custom Model Loaded Successfully!")
except Exception as e:
    print(f"❌ Error loading models: {e}")
# ==========================================
# 3. ප්‍රධාන Analysis Logic
# ==========================================
# Minimum number of 16 kHz samples fed to the model. This assumes the standard
# wav2vec2 convolutional front-end, whose receptive field is 400 samples
# (25 ms); shorter input cannot yield a single frame and fails deep inside the
# backbone with an unhelpful error.
_MIN_SPEECH_SAMPLES = 400


def _embed_audio(model, path):
    """Return the L2-normalised embedding of the audio file at *path* as a NumPy array.

    The file is loaded at 16 kHz, and leading/trailing audio more than 25 dB
    below the peak is trimmed before it is passed through *model*.

    Raises:
        ValueError: If (almost) nothing is left after trimming, i.e. the
            recording is silent or extremely short.
    """
    speech, _ = librosa.load(path, sr=16000)
    speech, _ = librosa.effects.trim(speech, top_db=25)
    if speech.size < _MIN_SPEECH_SAMPLES:
        raise ValueError(f"Audio is silent or too short to analyse: {Path(path).name}")
    inputs = PROCESSOR(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        # The model returns (embeddings, norm_embeddings, logits); only the
        # middle value, the normalised embedding, is used for comparison.
        # .get(): the extractor may not return an attention mask.
        _, emb, _ = model(inputs.input_values, inputs.get("attention_mask"))
    return emb.cpu().numpy()


def _accuracy_from_distance(raw_dist, center_point=0.31, steepness=40):
    """Map an embedding distance to a 0-100 score with a decreasing logistic curve.

    Calibrated for the current checkpoint. ``center_point`` is the midpoint
    between typical match (0.26) and mismatch (0.36) distances, where the
    score is exactly 50. Because that gap is small, a steep slope is used to
    spread it over most of the 0-100 range.
    """
    return (1 / (1 + math.exp(steepness * (raw_dist - center_point)))) * 100


def _verdict_for(accuracy):
    """Return ``(verdict, colour, message)`` for a 0-100 accuracy score."""
    if accuracy >= 85:
        return "EXCELLENT", "green", "ඉතාම නිවැරදියි! 🏆"
    if accuracy >= 65:
        return "GOOD", "orange", "හොඳයි, තව උත්සාහ කරන්න! ⭐"
    return "INCORRECT", "red", "නැවත උත්සාහ කරන්න. ❌"


def _archive_attempt_html(student_audio, teacher_audio, verdict, accuracy):
    """Save a successful attempt, try to upload it, and describe the outcome as HTML.

    Save and upload failures are reported in the returned snippet instead of
    being raised, so the score is still shown to the user.
    """
    try:
        saved_audio_path = save_successful_student_audio(student_audio, teacher_audio, verdict, accuracy)
    except Exception as save_error:
        return (
            "<p style='font-size: 0.9em; color: #c05621;'>"
            f"Student audio save warning: {html.escape(str(save_error))}"
            "</p>"
        )
    try:
        uploaded_audio_url = upload_student_audio_to_dataset(saved_audio_path)
    except Exception as upload_error:
        upload_message = f"<br>Dataset upload warning: {html.escape(str(upload_error))}"
    else:
        if uploaded_audio_url:
            upload_message = (
                f"<br>Uploaded to Dataset: <a href='{html.escape(uploaded_audio_url)}'"
                f" target='_blank' rel='noopener noreferrer'>{html.escape(STUDENT_AUDIO_DATASET_REPO_ID)}</a>"
            )
        else:
            upload_message = "<br>Dataset upload skipped: set HF_TOKEN and STUDENT_AUDIO_DATASET_REPO_ID secrets to persist on Hugging Face."
    return (
        "<p style='font-size: 0.9em; color: #2f855a;'>"
        f"Student audio saved: <b>{html.escape(saved_audio_path.name)}</b>"
        f"{upload_message}"
        "</p>"
    )


def process_audio(teacher_audio, student_audio):
    """Score a student's pronunciation against a teacher reference recording.

    Both recordings are embedded with ``custom_model``. The Euclidean distance
    between the two normalised embeddings is mapped to a 0-100 accuracy and a
    verdict. GOOD/EXCELLENT attempts are also saved locally, and uploaded to
    the Hub dataset when that is configured.

    Args:
        teacher_audio: Path of the selected reference recording, or None.
        student_audio: Path of the uploaded/recorded student audio, or None.

    Returns:
        ``(html_or_message, label_confidences)`` for the result HTML and the
        ``gr.Label`` component. The dict is empty when no score was produced.
    """
    if not teacher_audio and not student_audio:
        return "කරුණාකර ගුරුවරයාගේ reference ශබ්දය තෝරා ඔබේ ශබ්දය ලබා දෙන්න.", {}
    if not teacher_audio:
        return "කරුණාකර ගුරුවරයාගේ reference ශබ්දයක් dropdown එකෙන් තෝරන්න.", {}
    if not student_audio:
        return "කරුණාකර ඔබේ ශබ්දය upload හෝ record කරන්න.", {}

    # Looked up defensively: if loading failed at startup the global may be
    # None or missing altogether, which would otherwise surface as an opaque
    # NameError/TypeError.
    model = globals().get("custom_model")
    if model is None:
        return "<p style='color:red;'>Error: the pronunciation model is not loaded; check the startup logs.</p>", {}

    try:
        emb_t = _embed_audio(model, teacher_audio)
        emb_s = _embed_audio(model, student_audio)
        raw_dist = float(np.linalg.norm(emb_t - emb_s))
        accuracy = _accuracy_from_distance(raw_dist)
        verdict, color, msg = _verdict_for(accuracy)

        saved_audio_html = ""
        if verdict in {"GOOD", "EXCELLENT"}:
            saved_audio_html = _archive_attempt_html(student_audio, teacher_audio, verdict, accuracy)

        results_labels = {
            "Excellent (ඉතා විශිෂ්ටයි)": 1.0 if verdict == "EXCELLENT" else 0.0,
            "Good (හොඳයි)": 1.0 if verdict == "GOOD" else 0.0,
            "Needs Work (නැවත උත්සාහ කරන්න)": 1.0 if verdict == "INCORRECT" else 0.0
        }
        model_type_str = "Custom SinhalaPhonoNet (255-Class)"
        info_html = f"""
        <div style='text-align: center; padding: 20px; border-radius: 10px; background-color: #f0f2f6; border: 2px solid {color};'>
        <p style='color: #555; font-weight: bold;'>භාවිතා කළ මොඩලය: {model_type_str}</p>
        <h2 style='color: {color}; margin-top: 0;'>{verdict}</h2>
        <h3 style='color: #333;'>{msg}</h3>
        <p style='font-size: 1.4em;'>නිරවද්‍යතාවය: <b>{accuracy:.2f}%</b></p>
        <p style='font-size: 0.9em; color: #666;'>Raw Distance: {raw_dist:.4f}</p>
        {saved_audio_html}
        </div>
        """
        return info_html, results_labels
    except Exception as e:
        return f"<p style='color:red;'>Error: {html.escape(str(e))}</p>", {}
def analyze_custom(t, s):
    """Click handler for the Analyze button.

    Forwards the selected teacher reference path *t* and the student audio
    path *s* unchanged to ``process_audio`` and returns its
    ``(html, labels)`` result.
    """
    outcome = process_audio(t, s)
    return outcome
# ==========================================
# 4. Gradio UI
# ==========================================
# Single-page UI. Left column: teacher reference picker, student audio input
# and the Analyze button. Right column: the verdict HTML and a label widget.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ සිංහල මිතුරු (Sinhala Mithuru) - Pronunciation Lab")
    # NOTE(review): this subtitle reads "for research purposes, check the
    # difference between the two models here", but only the custom model is
    # wired up below — consider updating the text.
    gr.Markdown("පර්යේෂණ අරමුණු සඳහා මොඩලයන් දෙකෙහි වෙනස මෙතැනින් පරීක්ෂා කරන්න.")
    with gr.Row():
        with gr.Column(scale=1):
            # (label, value) pairs: the file stem is shown and the full path
            # string is what the handler receives as teacher_audio.
            reference_audio_choices = get_reference_audio_choices()
            t_input = gr.Dropdown(
                choices=reference_audio_choices,
                value=None,
                label="ගුරුවරයාගේ ශබ්දය තෝරන්න (Teacher Reference)",
                interactive=True,
            )
            # type="filepath": the handler receives a path string.
            s_input = gr.Audio(type="filepath", label="ඔබේ ශබ්දය (Student)")
            btn_custom = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            result_html = gr.HTML(label="Result Status")
            # Shows only the top verdict from the one-hot confidence dict.
            label_output = gr.Label(num_top_classes=1, label="Verdict Visualization")
    btn_custom.click(fn=analyze_custom, inputs=[t_input, s_input], outputs=[result_html, label_output])
# Starts the server at module level (there is no __main__ guard).
demo.launch()