# 255_docker_test / main.py
# TD-jayadeera's picture
# Update main.py
# 5c0fafe verified
import os
import math
import tempfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from huggingface_hub import hf_hub_download
# ==========================================
# 1. Global Configurations
# ==========================================
# Name of the local folder that holds the teacher's reference audio files
# (this folder must be uploaded to the HF Space together with the code).
REFERENCE_AUDIO_DIR = "reference_audios"

# Hugging Face Hub repository and file name of the trained checkpoint.
MODEL_FILENAME = "SinhalaPhonoNet_Final_Checkpoint_v4.pth"
MODEL_REPO_ID = "TD-jayadeera/model_255"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ==========================================
# 2. Model Architecture (255-Class)
# ==========================================
class SelfAttentionPooling(nn.Module):
    """Reduce a sequence of frame vectors to one vector with learned attention.

    Every frame is scored by a small network (``W`` -> ``tanh`` -> ``V``); the
    scores are softmax-normalised along dim 1 and used as the weights of a
    weighted sum over that dimension.
    """

    def __init__(self, input_dim):
        """Build the scoring layers for frame vectors of size ``input_dim``."""
        super().__init__()
        # Keep these attribute names: they become the parameter keys that
        # load_state_dict() matches against the saved checkpoint.
        self.W = nn.Linear(input_dim, 128)
        self.V = nn.Linear(128, 1)

    def forward(self, x, attention_mask=None):
        """Pool ``x`` over dim 1.

        Args:
            x: frame features; this code indexes it as (batch, frames, features).
            attention_mask: optional mask indexed as (batch, positions), where
                0 marks positions to ignore. Its length along dim 1 may differ
                from the number of frames in ``x``.

        Returns:
            Tuple ``(pooled, attn_weights)``: the weighted sum over dim 1 and
            the softmax weights that produced it.
        """
        hidden = torch.tanh(self.W(x))
        frame_scores = self.V(hidden)

        if attention_mask is not None:
            # The mask can be longer than the frame axis, so sample it at
            # evenly spaced positions to get one mask value per frame.
            n_frames = x.size(1)
            last_position = attention_mask.size(1) - 1
            picks = torch.linspace(0, last_position, steps=n_frames).long().to(x.device)
            frame_mask = torch.index_select(attention_mask, 1, picks).unsqueeze(-1)
            # A large negative score drives the softmax weight of masked frames to ~0.
            frame_scores = frame_scores.masked_fill(frame_mask == 0, -1e4)

        attn_weights = F.softmax(frame_scores, dim=1)
        pooled = (x * attn_weights).sum(dim=1)
        return pooled, attn_weights
class SinhalaPhonoNet(nn.Module):
    """Wav2Vec2 backbone followed by a layer mix, attention pooling and two heads.

    The backbone's hidden states are blended with softmax-normalised learnable
    weights, pooled over time by ``SelfAttentionPooling``, projected to an
    embedding by ``fc`` and finally classified by ``classifier``.
    """

    def __init__(self, base_model="facebook/wav2vec2-xls-r-300m", embedding_dim=256, num_classes=255):
        """Load the pretrained backbone and create the pooling/projection heads.

        Args:
            base_model: Hugging Face model id passed to ``from_pretrained``.
            embedding_dim: size of the embedding produced by ``fc``.
            num_classes: number of outputs of the classifier head.
        """
        super().__init__()
        self.config = Wav2Vec2Config.from_pretrained(base_model, output_hidden_states=True)
        self.backbone = Wav2Vec2Model.from_pretrained(base_model, config=self.config)

        hidden_size = self.config.hidden_size
        # One mixing weight per stacked hidden state (num_hidden_layers + 1 of them).
        n_states = self.config.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(n_states))
        self.attention = SelfAttentionPooling(hidden_size)

        projection_layers = [
            nn.Linear(hidden_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        ]
        self.fc = nn.Sequential(*projection_layers)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Run the network on a batch of inputs.

        Args:
            input_values: input tensor forwarded to the Wav2Vec2 backbone.
            attention_mask: optional mask forwarded to the backbone and to the
                attention pooling.

        Returns:
            Tuple ``(embeddings, norm_embeddings, logits)``: the raw ``fc``
            output, its L2-normalised version, and the classifier output
            computed from the normalised embedding.
        """
        backbone_out = self.backbone(input_values=input_values, attention_mask=attention_mask)

        # Stack all hidden states on a new leading axis and blend them with
        # softmax-normalised per-state weights.
        layer_stack = torch.stack(backbone_out.hidden_states, dim=0)
        mix = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        blended = (layer_stack * mix).sum(dim=0)

        pooled, _ = self.attention(blended, attention_mask)
        embeddings = self.fc(pooled)
        norm_embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings, norm_embeddings, self.classifier(norm_embeddings)
# Shared inference objects; both stay None until the lifespan handler assigns them.
model = processor = None
# ==========================================
# 3. Startup Event
# ==========================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the feature extractor and the trained model before serving requests.

    Everything before ``yield`` runs at application startup and fills the
    module-level ``model`` and ``processor``; the code after ``yield`` runs at
    shutdown.
    """
    global model, processor

    print("⏳ Loading processor...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

    print(f"⏳ Downloading & Loading custom model from HF ({MODEL_REPO_ID})...")
    checkpoint_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

    model = SinhalaPhonoNet(num_classes=255).to(DEVICE)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code from the downloaded file. Only keep this
    # while the Hub repository is fully trusted; consider weights_only=True if
    # the checkpoint contents allow it.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: disables dropout and uses BatchNorm running statistics.
    model.eval()

    best_acc_pct = checkpoint.get('best_val_acc', 0) * 100
    print(f"✅ SinhalaPhonoNet API Ready! (Accuracy: {best_acc_pct:.2f}%)")

    yield

    print("🛑 Shutting down API...")
# ASGI application; the lifespan handler performs the model loading at startup.
app = FastAPI(
    title="Sinhala Mithuru HF Space API",
    lifespan=lifespan,
)
# ==========================================
# 4. Core Logic Functions
# ==========================================
def get_embedding(audio_path):
    """Return the L2-normalised model embedding of one audio file.

    The file is decoded at 16 kHz, leading and trailing silence (more than
    25 dB below the peak) is trimmed, and the remaining samples are passed
    through the feature extractor and the model without gradient tracking.

    Args:
        audio_path: path of an audio file that ``librosa.load`` can decode.

    Returns:
        NumPy array with the model's normalised embedding output for the clip.

    Raises:
        RuntimeError: if the model or processor has not been loaded yet.
        ValueError: if no samples remain after silence trimming.
    """
    # Fail with a clear message instead of "'NoneType' object is not callable"
    # when the function is used before the startup handler has run.
    if model is None or processor is None:
        raise RuntimeError("Model is not loaded yet; application startup has not completed.")

    speech, _ = librosa.load(audio_path, sr=16000)
    speech, _ = librosa.effects.trim(speech, top_db=25)

    # An empty clip would otherwise fail deep inside the feature extractor or
    # the backbone with an error that does not point at the real cause.
    if speech.size == 0:
        raise ValueError("Audio contains no samples after silence trimming.")

    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        _, norm_emb, _ = model(
            inputs.input_values.to(DEVICE),
            inputs.attention_mask.to(DEVICE),
        )
    return norm_emb.cpu().numpy()
# ==========================================
# 5. API Endpoints
# ==========================================
@app.get("/")
def read_root():
    """Health-check endpoint: report that the service is running."""
    payload = {
        "status": "Online",
        "message": "SinhalaPhonoNet HF Space API is Running 🚀",
    }
    return payload
def _reference_audio_path(target_audio_name):
    """Return the path of a reference clip inside REFERENCE_AUDIO_DIR, or None.

    None is returned when the name points outside the reference folder
    (e.g. "../x" or an absolute path) or does not name an existing file.
    """
    base_dir = os.path.abspath(REFERENCE_AUDIO_DIR)
    candidate = os.path.abspath(os.path.join(base_dir, target_audio_name))
    # The name comes straight from the client: refuse anything that escapes
    # the reference folder after normalisation.
    if not candidate.startswith(base_dir + os.sep):
        return None
    return candidate if os.path.isfile(candidate) else None


def _score_distance(raw_dist, center_point=0.31, steepness=40):
    """Map an embedding distance to ``(accuracy_percent, verdict)``.

    The accuracy is a decreasing logistic function of the distance that equals
    50 at ``center_point``; ``steepness`` controls how fast it falls off.
    """
    accuracy = (1 / (1 + math.exp(steepness * (raw_dist - center_point)))) * 100
    if accuracy >= 85:
        verdict = "EXCELLENT"
    elif accuracy >= 65:
        verdict = "GOOD"
    else:
        verdict = "INCORRECT"
    return accuracy, verdict


@app.post("/analyze")
async def analyze_pronunciation(
    target_audio_name: str = Form(...),
    student_audio: UploadFile = File(...)
):
    """Compare a student's recording with a stored reference recording.

    Args:
        target_audio_name: file name of the reference clip inside
            ``REFERENCE_AUDIO_DIR``.
        student_audio: uploaded recording of the student.

    Returns:
        JSON with ``target_word``, ``accuracy`` (0-100), ``raw_distance``
        (Euclidean distance between the two normalised embeddings) and
        ``verdict`` ("EXCELLENT", "GOOD" or "INCORRECT").

    Raises:
        HTTPException: 404 when the reference clip does not exist (or the name
            points outside the reference folder); 500 for any processing error.
    """
    # Validate the reference first so a bad name costs no disk I/O, and so the
    # 404 is not swallowed by the generic 500 handler below.
    teacher_audio_path = _reference_audio_path(target_audio_name)
    if teacher_audio_path is None:
        raise HTTPException(
            status_code=404,
            detail=f"Target audio '{target_audio_name}' not found in '{REFERENCE_AUDIO_DIR}' folder."
        )

    student_temp_path = None
    try:
        student_content = await student_audio.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_student:
            # Record the path before writing so the file is cleaned up even if
            # the write itself fails.
            student_temp_path = temp_student.name
            temp_student.write(student_content)

        emb_teacher = get_embedding(teacher_audio_path)
        emb_student = get_embedding(student_temp_path)

        raw_dist = float(np.linalg.norm(emb_teacher - emb_student))
        accuracy, verdict = _score_distance(raw_dist)

        return JSONResponse(content={
            "target_word": target_audio_name,
            "accuracy": round(accuracy, 2),
            "raw_distance": round(raw_dist, 4),
            "verdict": verdict
        })
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if student_temp_path and os.path.exists(student_temp_path):
            os.remove(student_temp_path)
# ==========================================
# 🌟 6. Hugging Face Space Uvicorn Runner
# ==========================================
if __name__ == "__main__":
    # Hugging Face Spaces require the server to listen on port 7860.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)