File size: 6,944 Bytes
1e734dc
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
 
 
1e734dc
 
 
 
 
5c0fafe
1e734dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
1e734dc
 
 
 
 
81e7c74
 
 
 
1e734dc
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
1e734dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
1e734dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
1e734dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0fafe
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import math
import tempfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from huggingface_hub import hf_hub_download

# ==========================================
# 1. Global Configurations
# ==========================================
# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub repository and file name of the fine-tuned checkpoint.
MODEL_REPO_ID, MODEL_FILENAME = (
    "TD-jayadeera/model_255",
    "SinhalaPhonoNet_Final_Checkpoint_v4.pth",
)

# 🌟 Name of the local folder holding the teacher's reference audio files
# (this folder must be uploaded to the HF Space together with the app).
REFERENCE_AUDIO_DIR = "reference_audios"

# ==========================================
# 2. Model Architecture (255-Class)
# ==========================================
class SelfAttentionPooling(nn.Module):
    """Collapse a sequence of frame features into one vector using learned attention.

    Each frame receives a scalar score ``V(tanh(W(frame)))`` (``W`` projects to 128
    units). The scores are softmax-normalised over dim 1 and used as weights for a
    weighted sum of the frames.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute names W and V are part of the state_dict keys; keep them.
        self.W = nn.Linear(input_dim, 128)
        self.V = nn.Linear(128, 1)

    def forward(self, x, attention_mask=None):
        """Pool ``x`` over dim 1.

        Args:
            x: frame features; assumed (batch, frames, input_dim) given the dim=1
                reductions below — TODO confirm with callers.
            attention_mask: optional 0/1 mask whose dim 1 may be longer than the
                number of frames (e.g. a mask at raw-input resolution). It is
                resampled to ``x.size(1)`` entries by taking evenly spaced
                positions, and frames that land on a 0 are excluded.

        Returns:
            Tuple ``(pooled, weights)``: ``pooled`` is the attention-weighted sum of
            ``x`` over dim 1, ``weights`` are the softmax weights (last dim of size 1).
        """
        logits = self.V(torch.tanh(self.W(x)))
        if attention_mask is not None:
            n_positions = attention_mask.size(1)
            n_frames = x.size(1)
            # Evenly spaced positions across the mask, truncated down to integer indices.
            picks = torch.linspace(0, n_positions - 1, steps=n_frames).long().to(x.device)
            frame_mask = attention_mask.index_select(1, picks).unsqueeze(-1)
            # A large finite negative (not -inf) keeps the softmax finite even when
            # every frame of an item is masked.
            logits = logits.masked_fill(frame_mask == 0, -1e4)
        weights = logits.softmax(dim=1)
        pooled = (x * weights).sum(dim=1)
        return pooled, weights

class SinhalaPhonoNet(nn.Module):
    """Wav2Vec2-based encoder that returns audio embeddings and class logits.

    Pipeline: wav2vec2 backbone -> softmax-weighted mix of all backbone hidden
    states -> attention pooling over time -> MLP projection to ``embedding_dim``
    -> L2 normalisation -> linear classifier over ``num_classes`` classes.

    NOTE: attribute names and the layer order inside ``fc`` determine the
    state_dict keys, which must match the saved checkpoint exactly.
    """

    def __init__(self, base_model="facebook/wav2vec2-xls-r-300m", embedding_dim=256, num_classes=255):
        super().__init__()
        cfg = Wav2Vec2Config.from_pretrained(base_model, output_hidden_states=True)
        self.config = cfg
        self.backbone = Wav2Vec2Model.from_pretrained(base_model, config=cfg)
        # One learnable mixing weight per entry of ``hidden_states``
        # (num_hidden_layers + 1 of them).
        num_states = cfg.num_hidden_layers + 1
        self.layer_weights = nn.Parameter(torch.ones(num_states))
        width = cfg.hidden_size
        self.attention = SelfAttentionPooling(width)
        self.fc = nn.Sequential(
            nn.Linear(width, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, input_values, attention_mask=None):
        """Encode a batch of waveforms.

        Args:
            input_values: waveform tensor accepted by the wav2vec2 backbone.
            attention_mask: optional mask, forwarded to both the backbone and the
                attention pooling.

        Returns:
            Tuple ``(embeddings, norm_embeddings, logits)``: the projection output,
            its L2-normalised version (along dim 1), and classifier logits computed
            from the normalised embeddings.
        """
        backbone_out = self.backbone(input_values=input_values, attention_mask=attention_mask)
        layers = torch.stack(backbone_out.hidden_states)
        # view(-1, 1, 1, 1) gives one scalar per stacked hidden state, broadcast over
        # the remaining three axes (presumably batch, frames, hidden).
        mix = self.layer_weights.softmax(dim=0).view(-1, 1, 1, 1)
        fused = (layers * mix).sum(dim=0)
        pooled, _ = self.attention(fused, attention_mask)
        embeddings = self.fc(pooled)
        norm_embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings, norm_embeddings, self.classifier(norm_embeddings)

# Global variables: populated by ``lifespan`` at startup and read by ``get_embedding``.
model = None
processor = None

# ==========================================
# 3. Startup Event
# ==========================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the feature extractor and the fine-tuned model before serving.

    Reads the optional ``HF_TOKEN`` environment variable for the feature-extractor
    download, downloads the checkpoint file from ``MODEL_REPO_ID``, restores its
    ``model_state_dict`` into a fresh ``SinhalaPhonoNet`` on ``DEVICE`` and puts it
    in eval mode. On shutdown the global model and processor are released.
    """
    global model, processor
    print("⏳ Loading processor...")

    hf_token = os.getenv("HF_TOKEN")

    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-xls-r-300m", token=hf_token, resume_download=True
    )

    print(f"⏳ Downloading & Loading custom model from HF ({MODEL_REPO_ID})...")
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

    model = SinhalaPhonoNet(num_classes=255).to(DEVICE)
    # Deserialize on the CPU: load_state_dict copies each tensor into the (possibly
    # GPU-resident) model anyway, so mapping straight to DEVICE would only create a
    # temporary duplicate of all weights on the GPU.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # downloaded file; this is only acceptable because MODEL_REPO_ID is trusted.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    best_val_acc = checkpoint.get('best_val_acc', 0)
    # This generator stays suspended at ``yield`` for the server's whole lifetime,
    # so every local still bound here stays alive too. Drop the checkpoint (a full
    # extra copy of the weights plus whatever else was saved with it) before serving.
    del checkpoint
    print(f"✅ SinhalaPhonoNet API Ready! (Accuracy: {best_val_acc*100:.2f}%)")
    try:
        yield
    finally:
        print("🛑 Shutting down API...")
        model = None
        processor = None

app = FastAPI(lifespan=lifespan, title="Sinhala Mithuru HF Space API")

# ==========================================
# 4. Core Logic Functions
# ==========================================
def get_embedding(audio_path):
    """Compute the L2-normalised SinhalaPhonoNet embedding of one audio file.

    The clip is loaded with librosa at 16 kHz. Leading and trailing audio more than
    25 dB below the clip's peak is trimmed. The result is then run through the
    global ``processor`` and ``model`` (both set up at startup by ``lifespan``)
    without gradient tracking.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.

    Returns:
        numpy array holding the normalised embedding, one row for the single clip
        (presumably shape (1, embedding_dim) — confirm against the model).

    Raises:
        ValueError: if the trimmed clip (possibly empty) is too short for the
            wav2vec2 feature encoder to produce even one frame.
    """
    # With the standard wav2vec2 conv feature encoder (kernels 10,3,3,3,3,2,2;
    # strides 5,2,2,2,2,2,2), 400 samples (25 ms at 16 kHz) is the minimum input that
    # yields one frame. Anything shorter fails deep inside the model with an opaque error.
    min_samples = 400

    speech, _ = librosa.load(audio_path, sr=16000)
    speech, _ = librosa.effects.trim(speech, top_db=25)
    if speech.shape[-1] < min_samples:
        raise ValueError(
            f"Audio is too short after trimming silence: {speech.shape[-1]} samples "
            f"(need at least {min_samples}, i.e. 25 ms at 16 kHz)."
        )
    # Request the attention mask explicitly: it is used below, and the extractor only
    # returns it by default when its saved configuration enables it.
    inputs = processor(
        speech,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )
    with torch.no_grad():
        _, norm_emb, _ = model(inputs.input_values.to(DEVICE), inputs.attention_mask.to(DEVICE))
    return norm_emb.cpu().numpy()

# ==========================================
# 5. API Endpoints
# ==========================================
@app.get("/")
def read_root():
    """Liveness endpoint: confirm the API process is up and serving requests."""
    status_report = {
        "status": "Online",
        "message": "SinhalaPhonoNet HF Space API is Running 🚀",
    }
    return status_report

def _resolve_reference_audio(name):
    """Return the real path of reference clip ``name`` under REFERENCE_AUDIO_DIR, or None.

    None is returned in three cases:
    - the name resolves outside that folder (``../`` segments, absolute paths or
      symlinks supplied by the client);
    - the OS rejects it;
    - it does not refer to an existing regular file.
    """
    try:
        base_dir = os.path.realpath(REFERENCE_AUDIO_DIR)
        candidate = os.path.realpath(os.path.join(base_dir, name))
        inside_base = os.path.commonpath([base_dir, candidate]) == base_dir
    except ValueError:
        # Embedded NUL bytes, or (on Windows) a path on a different drive.
        return None
    if inside_base and os.path.isfile(candidate):
        return candidate
    return None


def _distance_to_accuracy(raw_dist, center_point=0.31, steepness=40):
    """Map an embedding distance to a 0-100 score with a decreasing logistic curve.

    ``raw_dist == center_point`` scores 50. Smaller distances approach 100 and larger
    ones approach 0; ``steepness`` controls how sharp the transition is.
    """
    return (1 / (1 + math.exp(steepness * (raw_dist - center_point)))) * 100


def _verdict_for(accuracy):
    """Bucket a 0-100 score into the verdict label returned by the API."""
    if accuracy >= 85:
        return "EXCELLENT"
    if accuracy >= 65:
        return "GOOD"
    return "INCORRECT"


@app.post("/analyze")
async def analyze_pronunciation(
    target_audio_name: str = Form(...),
    student_audio: UploadFile = File(...)
):
    """Score a student's recording against a stored teacher reference clip.

    Form fields:
        target_audio_name: file name of the reference clip inside REFERENCE_AUDIO_DIR.
        student_audio: the uploaded student recording.

    Returns:
        JSON with ``target_word`` (echo of target_audio_name), ``accuracy`` (0-100,
        rounded to 2 dp), ``raw_distance`` (Euclidean distance between the two
        normalised embeddings, 4 dp) and ``verdict``.

    Raises:
        HTTPException 404: the target does not exist, or its name resolves outside
            REFERENCE_AUDIO_DIR.
        HTTPException 400: the uploaded student audio is empty.
        HTTPException 500: any other failure (the error message is passed through).
    """
    # Local import keeps this block self-contained; FastAPI re-exports Starlette's helper.
    from fastapi.concurrency import run_in_threadpool

    # Validate the client-supplied name before touching the upload. It must not be
    # able to escape the reference folder (path traversal).
    teacher_audio_path = _resolve_reference_audio(target_audio_name)
    if teacher_audio_path is None:
        raise HTTPException(
            status_code=404,
            detail=f"Target audio '{target_audio_name}' not found in '{REFERENCE_AUDIO_DIR}' folder."
        )

    student_temp_path = None
    try:
        student_content = await student_audio.read()
        if not student_content:
            raise HTTPException(status_code=400, detail="Uploaded student audio is empty.")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_student:
            # Record the path first so ``finally`` cleans up even if the write fails.
            student_temp_path = temp_student.name
            temp_student.write(student_content)

        # Decoding and model inference are blocking. Run them in the thread pool so
        # the event loop keeps serving other requests (e.g. "/") meanwhile.
        emb_teacher = await run_in_threadpool(get_embedding, teacher_audio_path)
        emb_student = await run_in_threadpool(get_embedding, student_temp_path)
        raw_dist = float(np.linalg.norm(emb_teacher - emb_student))

        accuracy = _distance_to_accuracy(raw_dist)

        return JSONResponse(content={
            "target_word": target_audio_name,
            "accuracy": round(accuracy, 2),
            "raw_distance": round(raw_dist, 4),
            "verdict": _verdict_for(accuracy),
        })

    except HTTPException:
        # Deliberate client errors keep their status code instead of becoming a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        if student_temp_path and os.path.exists(student_temp_path):
            os.remove(student_temp_path)

# ==========================================
# 🌟 6. Hugging Face Space Uvicorn Runner
# ==========================================
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces expect the app on port 7860; listen on all interfaces.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)