# NOTE(review): Hugging Face Space status-page residue ("Spaces: Sleeping")
# captured by the scrape — not part of the application code.
| import os | |
| import math | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import librosa | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from contextlib import asynccontextmanager | |
| from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2FeatureExtractor | |
| from huggingface_hub import hf_hub_download | |
# ==========================================
# 1. Global Configurations
# ==========================================
# Run on GPU when available; model and input tensors are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face Hub repo and checkpoint file name for the fine-tuned model.
MODEL_REPO_ID = "TD-jayadeera/model_255"
MODEL_FILENAME = "SinhalaPhonoNet_Final_Checkpoint_v4.pth"
# Local folder holding the teacher's reference audio files
# (this folder must be uploaded to the HF Space alongside the app).
REFERENCE_AUDIO_DIR = "reference_audios"
| # ========================================== | |
| # 2. Model Architecture (255-Class) | |
| # ========================================== | |
class SelfAttentionPooling(nn.Module):
    """Additive self-attention pooling over the time dimension.

    Each frame is scored by a small two-layer MLP; the scores are
    softmax-normalized over time and used to compute a weighted sum of
    the frames. Returns the pooled vector and the attention weights.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 128)
        self.V = nn.Linear(128, 1)

    def forward(self, x, attention_mask=None):
        # Per-frame scores: (batch, frames, 1).
        frame_scores = self.V(torch.tanh(self.W(x)))
        if attention_mask is not None:
            # The mask is at input-sample resolution while x is at frame
            # resolution; resample it by picking evenly spaced columns.
            picked = torch.linspace(
                0, attention_mask.size(1) - 1, steps=x.size(1)
            ).long().to(x.device)
            frame_mask = torch.index_select(attention_mask, 1, picked).unsqueeze(-1)
            # Large negative fill so padded frames receive ~zero attention
            # (-1e4 rather than -inf stays finite in half precision).
            frame_scores = frame_scores.masked_fill(frame_mask == 0, -1e4)
        attn_weights = F.softmax(frame_scores, dim=1)
        pooled = (x * attn_weights).sum(dim=1)
        return pooled, attn_weights
class SinhalaPhonoNet(nn.Module):
    """Wav2Vec2-XLS-R based embedding network with a 255-way classifier.

    A learned softmax mixture over all backbone hidden layers is pooled
    with self-attention, projected to a compact embedding, L2-normalized,
    and classified.
    """

    def __init__(self, base_model="facebook/wav2vec2-xls-r-300m", embedding_dim=256, num_classes=255):
        super().__init__()
        # output_hidden_states=True: every layer is needed for the learned mix.
        self.config = Wav2Vec2Config.from_pretrained(base_model, output_hidden_states=True)
        self.backbone = Wav2Vec2Model.from_pretrained(base_model, config=self.config)
        # One mixing weight per hidden state (embedding layer + each block).
        self.layer_weights = nn.Parameter(torch.ones(self.config.num_hidden_layers + 1))
        self.attention = SelfAttentionPooling(self.config.hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(self.config.hidden_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Return (embeddings, L2-normalized embeddings, class logits)."""
        backbone_out = self.backbone(input_values=input_values, attention_mask=attention_mask)
        mix = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        all_hidden = torch.stack(backbone_out.hidden_states, dim=0)
        mixed_hidden = (all_hidden * mix).sum(dim=0)
        pooled, _ = self.attention(mixed_hidden, attention_mask)
        embeddings = self.fc(pooled)
        norm_embeddings = F.normalize(embeddings, p=2, dim=1)
        # NOTE: the classifier consumes the *normalized* embedding — this
        # matches how the loaded checkpoint was trained; do not "fix" it.
        logits = self.classifier(norm_embeddings)
        return embeddings, norm_embeddings, logits
# Global variables
# Populated once at startup by the FastAPI lifespan handler below.
model = None       # SinhalaPhonoNet instance (eval mode, on DEVICE)
processor = None   # Wav2Vec2FeatureExtractor for 16 kHz input
| # ========================================== | |
| # 3. Startup Event | |
| # ========================================== | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load the processor and model once at startup.

    Bug fix: `FastAPI(lifespan=...)` requires an async *context manager*;
    a bare async generator function will not work, so the
    `@asynccontextmanager` decorator (already imported above) is applied.
    Everything before `yield` runs at startup; after it, at shutdown.
    """
    global model, processor
    print("⏳ Loading processor...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")
    print(f"⏳ Downloading & Loading custom model from HF ({MODEL_REPO_ID})...")
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    model = SinhalaPhonoNet(num_classes=255).to(DEVICE)
    # weights_only=False: the checkpoint carries extra metadata (e.g.
    # 'best_val_acc'). Acceptable only because it comes from our own repo —
    # never load untrusted checkpoints this way.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✅ SinhalaPhonoNet API Ready! (Accuracy: {checkpoint.get('best_val_acc', 0)*100:.2f}%)")
    yield
    print("🛑 Shutting down API...")
# Application instance; the lifespan handler loads the model before serving.
app = FastAPI(lifespan=lifespan, title="Sinhala Mithuru HF Space API")
| # ========================================== | |
| # 4. Core Logic Functions | |
| # ========================================== | |
def get_embedding(audio_path):
    """Compute the L2-normalized embedding for one audio file.

    Loads the file at 16 kHz, trims leading/trailing silence, and runs the
    global model without gradient tracking. Returns a numpy array
    (the model's normalized-embedding output, moved to CPU).
    """
    waveform, _ = librosa.load(audio_path, sr=16000)
    # Strip silence quieter than 25 dB below peak before embedding.
    waveform, _ = librosa.effects.trim(waveform, top_db=25)
    features = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        _, norm_emb, _ = model(
            features.input_values.to(DEVICE),
            features.attention_mask.to(DEVICE),
        )
    return norm_emb.cpu().numpy()
| # ========================================== | |
| # 5. API Endpoints | |
| # ========================================== | |
@app.get("/")
def read_root():
    """Health-check endpoint confirming the Space is up.

    Bug fix: without the route decorator this function was never exposed
    by FastAPI — it was dead code.
    """
    return {"status": "Online", "message": "SinhalaPhonoNet HF Space API is Running 🚀"}
# NOTE(review): route path assumed — confirm "/analyze" against the client app.
@app.post("/analyze")
async def analyze_pronunciation(
    target_audio_name: str = Form(...),
    student_audio: UploadFile = File(...)
):
    """Compare a student's recording against a stored teacher reference.

    Form fields:
        target_audio_name: file name of the reference clip inside
            REFERENCE_AUDIO_DIR (404 if it does not exist).
        student_audio: the student's uploaded recording.

    Returns JSON with a logistic-scaled accuracy score (0-100), the raw
    embedding distance, and a coarse verdict string.
    """
    student_temp_path = None
    try:
        # Persist the upload to a temp .wav so librosa can read it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_student:
            temp_student.write(await student_audio.read())
            student_temp_path = temp_student.name
        teacher_audio_path = os.path.join(REFERENCE_AUDIO_DIR, target_audio_name)
        if not os.path.exists(teacher_audio_path):
            raise HTTPException(
                status_code=404,
                detail=f"Target audio '{target_audio_name}' not found in '{REFERENCE_AUDIO_DIR}' folder."
            )
        emb_teacher = get_embedding(teacher_audio_path)
        emb_student = get_embedding(student_temp_path)
        raw_dist = float(np.linalg.norm(emb_teacher - emb_student))
        # Map distance -> accuracy with a logistic curve centred at 0.31;
        # steepness 40 makes the pass/fail transition sharp around the centre.
        center_point = 0.31
        steepness = 40
        accuracy = (1 / (1 + math.exp(steepness * (raw_dist - center_point)))) * 100
        if accuracy >= 85:
            verdict = "EXCELLENT"
        elif accuracy >= 65:
            verdict = "GOOD"
        else:
            verdict = "INCORRECT"
        return JSONResponse(content={
            "target_word": target_audio_name,
            "accuracy": round(accuracy, 2),
            "raw_distance": round(raw_dist, 4),
            "verdict": verdict
        })
    except HTTPException:
        # Bug fix: previously the 404 raised above was caught by the generic
        # handler below and re-wrapped as a 500. Re-raise it unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always delete the student's temp file, success or failure.
        if student_temp_path and os.path.exists(student_temp_path):
            os.remove(student_temp_path)
| # ========================================== | |
| # 🌟 6. Hugging Face Space Uvicorn Runner | |
| # ========================================== | |
# ==========================================
# 🌟 6. Hugging Face Space Uvicorn Runner
# ==========================================
if __name__ == "__main__":
    import uvicorn
    # Hugging Face Spaces require mapping to port 7860
    # (0.0.0.0 so the container's port is reachable from outside).
    uvicorn.run(app, host="0.0.0.0", port=7860)