Spaces:
Sleeping
Sleeping
import html
import math
import os
import shutil
from datetime import datetime
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import HfApi, hf_hub_download
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Model
# ==========================================
# 1. Model architecture
# ==========================================
class SelfAttentionPooling(nn.Module):
    """Collapse a sequence of frame features into one vector using learned attention.

    Each frame receives a scalar score ``V(tanh(W(x)))``. The scores are
    softmax-normalised over dim 1 (time) and used to take a weighted sum of
    the frames.

    Args:
        input_dim: size of the last dimension of the features given to ``forward``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute names W / V become state_dict keys of saved checkpoints.
        self.W = nn.Linear(input_dim, 128)
        self.V = nn.Linear(128, 1)

    def forward(self, x, attention_mask=None):
        """Return ``(pooled, attn_weights)``.

        Args:
            x: frame features, presumably ``(batch, frames, input_dim)``; the
                softmax and the weighted sum both run over dim 1.
            attention_mask: optional 0/1 padding mask whose dim 1 may be longer
                than ``x``'s (e.g. at waveform-sample resolution). It is sampled
                at ``x.size(1)`` evenly spaced positions to get one entry per
                frame. Frames whose entry is 0 get a score of -1e4, so they end
                up with (numerically) zero weight.

        Returns:
            pooled: the attention-weighted sum of ``x`` over dim 1.
            attn_weights: the softmax weights (same shape as the scores, last dim 1).
        """
        scores = self.V(torch.tanh(self.W(x)))
        if attention_mask is not None:
            n_frames = x.size(1)
            last_position = attention_mask.size(1) - 1
            frame_positions = torch.linspace(0, last_position, steps=n_frames).long().to(x.device)
            frame_mask = attention_mask.index_select(1, frame_positions).unsqueeze(-1)
            scores = scores.masked_fill(frame_mask == 0, -1e4)
        attn_weights = F.softmax(scores, dim=1)
        pooled = (x * attn_weights).sum(dim=1)
        return pooled, attn_weights
class SinhalaPhonoNet(nn.Module):
    """wav2vec2 backbone with learned layer mixing, attention pooling and a projection head.

    Args:
        base_model: Hugging Face id of the pretrained wav2vec2 checkpoint; its
            config and weights are fetched with ``from_pretrained``.
        embedding_dim: size of the embedding produced by ``fc``.
        num_classes: number of classifier outputs (the released checkpoint uses 255).

    The attribute names below (backbone, layer_weights, attention, fc,
    classifier) and the layout of ``fc`` define the state_dict keys, so they
    must stay as they are for saved checkpoints to load.
    """

    def __init__(self, base_model="facebook/wav2vec2-xls-r-300m", embedding_dim=256, num_classes=255):
        super().__init__()
        # Ask the backbone for every hidden state so they can be mixed in forward().
        self.config = Wav2Vec2Config.from_pretrained(base_model, output_hidden_states=True)
        self.backbone = Wav2Vec2Model.from_pretrained(base_model, config=self.config)
        hidden_size = self.config.hidden_size
        # One mixing weight per entry of outputs.hidden_states
        # (embedding output + one per transformer layer).
        n_hidden_states = self.config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(n_hidden_states))
        self.attention = SelfAttentionPooling(hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Embed a batch of waveforms.

        Returns the same three values as during training:
        ``(embeddings, norm_embeddings, logits)``, where ``norm_embeddings`` is
        ``embeddings`` L2-normalised along dim 1 and ``logits`` is the
        classifier applied to the normalised embeddings.
        """
        outputs = self.backbone(input_values=input_values, attention_mask=attention_mask)
        all_states = torch.stack(outputs.hidden_states, dim=0)
        mix = F.softmax(self.layer_weights, dim=0)
        fused = (all_states * mix.view(-1, 1, 1, 1)).sum(dim=0)
        pooled, _ = self.attention(fused, attention_mask)
        embeddings = self.fc(pooled)
        unit_embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings, unit_embeddings, self.classifier(unit_embeddings)
# ==========================================
# 2. Loading the models (Hugging Face)
# ==========================================
DEVICE = torch.device("cpu")
BASE_MODEL_NAME = "facebook/wav2vec2-xls-r-300m"
# Feature extractor of the same base checkpoint SinhalaPhonoNet is built on;
# turns raw 16 kHz waveforms into model inputs.
PROCESSOR = Wav2Vec2FeatureExtractor.from_pretrained(BASE_MODEL_NAME)

# Local folders are resolved relative to this script's directory.
REFERENCE_AUDIO_DIR = Path(__file__).resolve().parent.joinpath("reference_audios")
SAVED_STUDENT_AUDIO_DIR = Path(__file__).resolve().parent.joinpath("saved_student_audios")
# Characters replaced by "_" when building file names.
UNSAFE_FILENAME_CHARS = '<>:"/\\|?*'

# Optional dataset-upload settings, read from environment variables (Space secrets).
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
STUDENT_AUDIO_DATASET_REPO_ID = os.environ.get("STUDENT_AUDIO_DATASET_REPO_ID")
STUDENT_AUDIO_UPLOAD_SUBDIR = os.environ.get("STUDENT_AUDIO_UPLOAD_SUBDIR", "student_audios").strip("/")
def get_reference_audio_choices(directory=None, extensions=(".wav",)):
    """List the teacher reference recordings as Gradio dropdown choices.

    Args:
        directory: folder to scan; defaults to REFERENCE_AUDIO_DIR.
        extensions: accepted file suffixes (with leading dot). They are compared
            case-insensitively, so ``ka.WAV`` is listed as well as ``ka.wav``.
            (A plain ``glob("*.wav")`` is case-sensitive on Linux.)

    Returns:
        A list of ``(label, path)`` tuples sorted by file stem. ``label`` is the
        file name without its suffix and ``path`` is the file path as a string.
        Returns an empty list when the folder does not exist. Sub-directories
        are skipped even if their name ends in an accepted suffix.
    """
    folder = Path(REFERENCE_AUDIO_DIR if directory is None else directory)
    if not folder.is_dir():
        return []
    wanted = {ext.lower() for ext in extensions}
    audio_files = sorted(
        (p for p in folder.iterdir() if p.is_file() and p.suffix.lower() in wanted),
        # The name is a tie-breaker so that e.g. a.wav / a.WAV order is deterministic.
        key=lambda p: (p.stem, p.name),
    )
    return [(p.stem, str(p)) for p in audio_files]
def safe_filename_part(value, fallback="audio", *, unsafe_chars=None):
    """Turn *value* into a string that is safe to embed in a file name.

    Surrounding whitespace is removed first. Every character in *unsafe_chars*
    (default: UNSAFE_FILENAME_CHARS) and every ASCII control character (NUL,
    newline, tab, DEL, ...) is then replaced by "_". A NUL byte would otherwise
    make ``open``/``shutil.copy2`` raise. Finally, leading and trailing spaces
    and dots are stripped; Windows, for one, drops trailing dots and spaces
    from names.

    Non-ASCII text such as Sinhala is kept as-is, including zero-width joiners
    used in conjuncts. That is why this deliberately does not use
    ``str.isprintable()``.

    Returns:
        The cleaned string, or *fallback* if nothing is left.
    """
    if unsafe_chars is None:
        unsafe_chars = UNSAFE_FILENAME_CHARS
    cleaned = "".join(
        "_" if char in unsafe_chars or ord(char) < 32 or ord(char) == 127 else char
        for char in str(value).strip()
    )
    cleaned = cleaned.strip(" .")
    return cleaned or fallback
def save_successful_student_audio(student_audio, teacher_audio, verdict, accuracy):
    """Copy a passing student recording into SAVED_STUDENT_AUDIO_DIR.

    The copy is named ``<timestamp>_<verdict>_<rounded accuracy>_<teacher stem><suffix>``.
    The ``YYYYmmdd_HHMMSS_micro`` timestamp prefix makes the files sort
    chronologically. The source suffix is kept, or ".wav" when it has none.
    ``shutil.copy2`` also copies file metadata such as timestamps.

    Returns:
        Path of the new copy.

    Raises:
        FileNotFoundError: if *student_audio* does not exist.
    """
    source_path = Path(student_audio)
    if not source_path.exists():
        raise FileNotFoundError(f"Student audio file not found: {source_path}")

    SAVED_STUDENT_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    reference = safe_filename_part(Path(teacher_audio).stem, "teacher")
    extension = source_path.suffix if source_path.suffix else ".wav"
    file_name = "_".join((stamp, verdict.lower(), f"{accuracy:.0f}", reference)) + extension
    destination = SAVED_STUDENT_AUDIO_DIR / file_name
    shutil.copy2(source_path, destination)
    return destination
def upload_student_audio_to_dataset(audio_path, *, private=True):
    """Upload a saved student clip to the configured Hugging Face dataset repo.

    Does nothing unless both HF_TOKEN and STUDENT_AUDIO_DATASET_REPO_ID are set.
    The file keeps its local name and is placed under STUDENT_AUDIO_UPLOAD_SUBDIR
    when that setting is non-empty.

    Args:
        audio_path: local file to upload.
        private: visibility used only when the dataset repo does not exist yet
            and is created here. An existing repo keeps its current visibility.
            Defaults to private because the files are recordings of students'
            voices.

    Returns:
        A browser URL of the uploaded file, or None when uploading is disabled.

    Raises:
        Errors from ``HfApi`` (auth, network, ...) propagate to the caller.
    """
    if not STUDENT_AUDIO_DATASET_REPO_ID or not HF_TOKEN:
        return None
    audio_path = Path(audio_path)
    path_in_repo = audio_path.name
    if STUDENT_AUDIO_UPLOAD_SUBDIR:
        path_in_repo = f"{STUDENT_AUDIO_UPLOAD_SUBDIR}/{audio_path.name}"
    api = HfApi(token=HF_TOKEN)
    # create_repo would otherwise create a PUBLIC dataset of voice recordings.
    api.create_repo(
        repo_id=STUDENT_AUDIO_DATASET_REPO_ID,
        repo_type="dataset",
        private=private,
        exist_ok=True,
    )
    api.upload_file(
        path_or_fileobj=str(audio_path),
        path_in_repo=path_in_repo,
        repo_id=STUDENT_AUDIO_DATASET_REPO_ID,
        repo_type="dataset",
        commit_message=f"Upload {audio_path.name}",
    )
    # Percent-encode each path segment (names may contain Sinhala or spaces).
    quoted_path = "/".join(quote(part) for part in path_in_repo.split("/"))
    # /blob/ is the hub's single-file view; /tree/ is the folder listing.
    return f"https://huggingface.co/datasets/{STUDENT_AUDIO_DATASET_REPO_ID}/blob/main/{quoted_path}"
REPO_ID = "TD-jayadeera/model_255"
MODEL_FILENAME = "SinhalaPhonoNet_Final_Checkpoint_v4.pth"

# Bound to the ready-to-use model only after every loading step succeeds.
# Any failure leaves it as None rather than undefined, or worse, as a
# constructed-but-untrained model that would silently produce meaningless scores.
custom_model = None
try:
    print("⏳ Downloading & Loading Custom Model from Hugging Face...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    _loaded_model = SinhalaPhonoNet(num_classes=255).to(DEVICE)
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, i.e. execute code contained in the downloaded file. This is only
    # acceptable while REPO_ID is a trusted repository. Try weights_only=True if
    # the checkpoint holds only tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    # Only the model weights are taken from the checkpoint dict.
    _loaded_model.load_state_dict(checkpoint["model_state_dict"])
    _loaded_model.eval()
    custom_model = _loaded_model
    print("✅ Custom Model Loaded Successfully!")
except Exception as e:
    print(f"❌ Error loading models: {e}")
# ==========================================
# 3. Main analysis logic
# ==========================================

# Fewest samples (~25 ms at 16 kHz) accepted after silence trimming.
# NOTE(review): matches the receptive field of wav2vec2's convolutional front
# end, which cannot produce a single frame from shorter input and errors out
# inside the model. Re-check if the base model changes.
_MIN_SPEECH_SAMPLES = 400


def _embed_audio(model, path, sampling_rate=16000, top_db=25):
    """Return the L2-normalised embedding (a numpy array) of the audio file at *path*.

    The file is resampled to *sampling_rate*. Leading and trailing audio more
    than *top_db* dB below the clip's peak is trimmed, and the rest is run
    through *model*. Only the middle output of the model (the normalised
    embedding) is kept.

    Raises:
        ValueError: if fewer than ``_MIN_SPEECH_SAMPLES`` samples remain after trimming.
    """
    speech, _ = librosa.load(path, sr=sampling_rate)
    speech, _ = librosa.effects.trim(speech, top_db=top_db)
    if speech.size < _MIN_SPEECH_SAMPLES:
        raise ValueError(
            f"Audio is too short after trimming silence ({speech.size} samples): {Path(path).name}"
        )
    inputs = PROCESSOR(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        _, norm_embedding, _ = model(inputs.input_values, inputs.attention_mask)
    return norm_embedding.cpu().numpy()


def _distance_to_accuracy(raw_dist, center_point=0.31, steepness=40):
    """Map an embedding distance to a 0-100 score with a decreasing logistic curve.

    Calibration for the current model: *center_point* (score 50) lies midway
    between typical match (~0.26) and mismatch (~0.36) distances. Because that
    gap is small, the curve is made steep.
    """
    return (1 / (1 + math.exp(steepness * (raw_dist - center_point)))) * 100


def _verdict_for(accuracy):
    """Return ``(verdict, colour, message)`` for a 0-100 accuracy score."""
    if accuracy >= 85:
        return "EXCELLENT", "green", "ඉතාම නිවැරදියි! 🏆"
    if accuracy >= 65:
        return "GOOD", "orange", "හොඳයි, තව උත්සාහ කරන්න! ⭐"
    return "INCORRECT", "red", "නැවත උත්සාහ කරන්න. ❌"


def _save_and_upload_html(student_audio, teacher_audio, verdict, accuracy):
    """Save a passing clip, try to upload it, and describe the outcome as HTML.

    Save and upload failures are reported inside the returned HTML instead of
    being raised, so a storage problem never hides the pronunciation result.
    Exception text is HTML-escaped because it can contain user-supplied file names.
    """
    try:
        saved_audio_path = save_successful_student_audio(student_audio, teacher_audio, verdict, accuracy)
    except Exception as save_error:
        return (
            "<p style='font-size: 0.9em; color: #c05621;'>"
            f"Student audio save warning: {html.escape(str(save_error))}"
            "</p>"
        )

    try:
        uploaded_audio_url = upload_student_audio_to_dataset(saved_audio_path)
    except Exception as upload_error:
        upload_message = f"<br>Dataset upload warning: {html.escape(str(upload_error))}"
    else:
        if uploaded_audio_url:
            upload_message = (
                f"<br>Uploaded to Dataset: <a href='{html.escape(uploaded_audio_url)}' "
                f"target='_blank' rel='noopener noreferrer'>{html.escape(STUDENT_AUDIO_DATASET_REPO_ID)}</a>"
            )
        else:
            upload_message = "<br>Dataset upload skipped: set HF_TOKEN and STUDENT_AUDIO_DATASET_REPO_ID secrets to persist on Hugging Face."

    return (
        "<p style='font-size: 0.9em; color: #2f855a;'>"
        f"Student audio saved: <b>{html.escape(saved_audio_path.name)}</b>"
        f"{upload_message}"
        "</p>"
    )


def _render_result_html(verdict, color, msg, accuracy, raw_dist, saved_audio_html):
    """Build the coloured result card (verdict, message, score, raw distance)."""
    model_type_str = "Custom SinhalaPhonoNet (255-Class)"
    return f"""
<div style='text-align: center; padding: 20px; border-radius: 10px; background-color: #f0f2f6; border: 2px solid {color};'>
    <p style='color: #555; font-weight: bold;'>භාවිතා කළ මොඩලය: {model_type_str}</p>
    <h2 style='color: {color}; margin-top: 0;'>{verdict}</h2>
    <h3 style='color: #333;'>{msg}</h3>
    <p style='font-size: 1.4em;'>නිරවද්යතාවය: <b>{accuracy:.2f}%</b></p>
    <p style='font-size: 0.9em; color: #666;'>Raw Distance: {raw_dist:.4f}</p>
    {saved_audio_html}
</div>
"""


def process_audio(teacher_audio, student_audio):
    """Score a student's pronunciation against a teacher reference recording.

    Both clips are embedded with the custom SinhalaPhonoNet model. The
    Euclidean distance between their normalised embeddings is mapped to a
    0-100 "accuracy" and a verdict. GOOD and EXCELLENT clips are also saved
    locally and, if configured, uploaded to a Hugging Face dataset.

    Args:
        teacher_audio: path of the selected reference recording (dropdown value).
        student_audio: path of the student's uploaded or recorded audio.

    Returns:
        ``(html_or_prompt, labels)``: an HTML result card, or a prompt or error
        message, plus a verdict-confidence dict for ``gr.Label``. The dict is
        empty whenever no score was produced. Errors are reported in the
        message rather than raised.
    """
    if not teacher_audio and not student_audio:
        return "කරුණාකර ගුරුවරයාගේ reference ශබ්දය තෝරා ඔබේ ශබ්දය ලබා දෙන්න.", {}
    if not teacher_audio:
        return "කරුණාකර ගුරුවරයාගේ reference ශබ්දයක් dropdown එකෙන් තෝරන්න.", {}
    if not student_audio:
        return "කරුණාකර ඔබේ ශබ්දය upload හෝ record කරන්න.", {}

    # The startup loader may have failed (None) or never bound the name at all.
    try:
        model = custom_model
    except NameError:
        model = None
    if model is None:
        return (
            "<p style='color:red;'>Error: the pronunciation model is not loaded; "
            "see the startup logs.</p>",
            {},
        )

    try:
        emb_t = _embed_audio(model, teacher_audio)
        emb_s = _embed_audio(model, student_audio)
        raw_dist = float(np.linalg.norm(emb_t - emb_s))
        accuracy = _distance_to_accuracy(raw_dist)
        verdict, color, msg = _verdict_for(accuracy)

        saved_audio_html = ""
        if verdict in {"GOOD", "EXCELLENT"}:
            saved_audio_html = _save_and_upload_html(student_audio, teacher_audio, verdict, accuracy)

        results_labels = {
            "Excellent (ඉතා විශිෂ්ටයි)": 1.0 if verdict == "EXCELLENT" else 0.0,
            "Good (හොඳයි)": 1.0 if verdict == "GOOD" else 0.0,
            "Needs Work (නැවත උත්සාහ කරන්න)": 1.0 if verdict == "INCORRECT" else 0.0,
        }
        info_html = _render_result_html(verdict, color, msg, accuracy, raw_dist, saved_audio_html)
        return info_html, results_labels
    except Exception as e:
        # Escaped: the message can contain the (user-chosen) uploaded file name.
        return f"<p style='color:red;'>Error: {html.escape(str(e))}</p>", {}
def analyze_custom(t, s):
    """Click handler for the "Analyze" button.

    Args:
        t: teacher reference audio path chosen in the dropdown.
        s: student audio file path.

    Returns:
        Exactly what ``process_audio(t, s)`` returns.
    """
    handler_result = process_audio(t, s)
    return handler_result
# ==========================================
# 4. Gradio UI
# ==========================================
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ සිංහල මිතුරු (Sinhala Mithuru) - Pronunciation Lab")
    # NOTE(review): this subtitle reads "For research purposes, check the
    # difference between the two models here", but only the custom
    # SinhalaPhonoNet model is wired up below. Consider updating the text.
    gr.Markdown("පර්යේෂණ අරමුණු සඳහා මොඩලයන් දෙකෙහි වෙනස මෙතැනින් පරීක්ෂා කරන්න.")

    with gr.Row():
        # Left column: reference selection, student audio, and the trigger button.
        with gr.Column(scale=1):
            reference_audio_choices = get_reference_audio_choices()
            t_input = gr.Dropdown(
                label="ගුරුවරයාගේ ශබ්දය තෝරන්න (Teacher Reference)",
                choices=reference_audio_choices,
                value=None,
                interactive=True,
            )
            s_input = gr.Audio(label="ඔබේ ශබ්දය (Student)", type="filepath")
            btn_custom = gr.Button("Analyze", variant="primary")
        # Right column: result card and verdict label.
        with gr.Column(scale=1):
            result_html = gr.HTML(label="Result Status")
            label_output = gr.Label(label="Verdict Visualization", num_top_classes=1)

    btn_custom.click(
        fn=analyze_custom,
        inputs=[t_input, s_input],
        outputs=[result_html, label_output],
    )

demo.launch()