File size: 3,684 Bytes
360682b
c4b4df8
 
e853b4e
 
 
 
5003140
ab82fde
c4b4df8
 
e853b4e
f2e947b
 
360682b
e853b4e
c4b4df8
 
f2e947b
c4b4df8
 
 
 
 
 
 
e853b4e
ab82fde
5599463
 
ab82fde
360682b
ab82fde
f2e947b
360682b
ab82fde
360682b
e853b4e
 
 
 
ab82fde
 
 
 
 
 
 
 
 
e853b4e
ab82fde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4b4df8
ab82fde
 
 
e853b4e
 
c4b4df8
 
e853b4e
c4b4df8
 
ab82fde
c4b4df8
ab82fde
 
c4b4df8
ab82fde
e853b4e
5599463
c4b4df8
 
 
 
f2e947b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
import os
import subprocess
import tempfile

import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# Configure root logging once for the whole service, then grab the
# service-specific logger used by every handler below.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("LID_Engine")

app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")

# Let browser front-ends on any origin call this API (any method, any header).
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive pairing -- the CORS spec forbids `Access-Control-Allow-Origin: *`
# on credentialed requests, so the middleware has to echo the caller's origin,
# letting any site send credentialed requests. Nothing in this file appears to
# rely on cookies/auth; consider allow_credentials=False or an explicit origin
# list once the front-end's needs are confirmed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Absolute path of the ONNX export inside the container image. Deployments
# may point elsewhere via the LID_MODEL_PATH environment variable; the
# default is the original baked-in location.
MODEL_PATH = os.environ.get("LID_MODEL_PATH", "/app/local_model/pakistani_lid_v3.onnx")

logger.info(f"๐Ÿš€ Loading pre-baked ONNX model from: {MODEL_PATH}")
try:
    # Loaded once at import time and shared by every request; CPU-only.
    session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
    logger.info("โœ… Engine is LIVE and Ready!")
except Exception as e:
    # Fail fast: the service is useless without a model. logger.exception
    # keeps the traceback, and a bare `raise` re-raises the original error
    # without adding this frame to the traceback.
    logger.exception(f"โŒ Failed to load model: {e}")
    raise

# Language names in model-output order; predict_audio maps the argmax index
# of the logits back to a name through id2label.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = dict(enumerate(labels))

_SAMPLE_RATE = 16000  # ffmpeg resamples every input to this rate (Hz)
_MAX_SECONDS = 15     # model input is truncated / zero-padded to this length


def _decode_to_wav(input_path, wav_path, sample_rate=_SAMPLE_RATE):
    """Transcode any ffmpeg-readable file (WebM, OGG, ...) to a mono WAV at *sample_rate*.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero; its stderr is
            captured on the exception and the last line is logged.
    """
    cmd = [
        "ffmpeg", "-nostdin", "-loglevel", "error", "-y",
        "-i", input_path,
        "-ar", str(sample_rate), "-ac", "1",
        wav_path,
    ]
    try:
        # -nostdin: a server-side ffmpeg must never wait on / consume stdin.
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as err:
        lines = (err.stderr or b"").decode(errors="replace").strip().splitlines()
        logger.error("ffmpeg failed to decode %s: %s", input_path, lines[-1] if lines else "no output")
        raise


def _load_waveform(wav_path):
    """Read a WAV file into a float32 tensor of shape (1, num_samples)."""
    data, _ = sf.read(wav_path, dtype="float32")
    if data.ndim > 1:
        # Defensive: ffmpeg was asked for mono, but downmix if a channel axis survives.
        data = data.mean(axis=1)
    if data.size < 2:
        # max() fails on an empty tensor and std() of one sample is nan.
        raise ValueError("Uploaded audio is empty or too short to analyse")
    return torch.from_numpy(data).float().unsqueeze(0)


def _prepare_inputs(waveform, target_frames):
    """Truncate, normalize and zero-pad *waveform* (1, N) to (1, target_frames).

    Returns:
        (input_values float32 array, attention_mask int64 array), both shaped
        (1, target_frames); the mask is 1 over real samples, 0 over padding.
    """
    waveform = waveform[:, :target_frames]

    # NOTE(review): the mean subtracted here is that of the *unscaled* waveform,
    # so the result is not exactly zero-mean. Kept verbatim because it
    # presumably mirrors the training preprocessing -- confirm before changing.
    waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
    waveform = waveform / waveform.std().clamp(min=1e-6)

    length = waveform.shape[1]
    mask = torch.zeros(1, target_frames, dtype=torch.long)
    mask[:, :length] = 1
    if length < target_frames:
        waveform = F.pad(waveform, (0, target_frames - length))
    return waveform.numpy(), mask.numpy()


def _softmax(logits):
    """Row-wise softmax; shifting by the row max avoids exp() overflow (inf/inf = nan)."""
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=1, keepdims=True)


def predict_audio(input_path):
    """Classify the spoken language of an audio file.

    Only the first ``_MAX_SECONDS`` seconds are used; shorter clips are
    zero-padded and masked out.

    Args:
        input_path: path to any file ffmpeg can decode.

    Returns:
        ``(label, confidence)``: a name from ``labels`` and its softmax
        probability as a float in [0, 1].

    Raises:
        subprocess.CalledProcessError: if ffmpeg cannot decode the input.
        ValueError: if the decoded audio has fewer than two samples.
    """
    target_frames = _SAMPLE_RATE * _MAX_SECONDS

    # A private directory per call: concurrent requests / workers can no longer
    # overwrite each other's intermediate WAV, and cleanup is automatic.
    with tempfile.TemporaryDirectory(prefix="lid_") as workdir:
        wav_path = os.path.join(workdir, "cleaned_audio.wav")
        _decode_to_wav(input_path, wav_path)
        waveform = _load_waveform(wav_path)

    input_values, attention_mask = _prepare_inputs(waveform, target_frames)
    logits = session.run(
        None, {"input_values": input_values, "attention_mask": attention_mask}
    )[0]

    probs = _softmax(logits)
    pred_id = int(np.argmax(probs, axis=1)[0])
    return id2label[pred_id], float(probs[0][pred_id])

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Identify the language spoken in an uploaded audio file.

    The upload is written to a private temporary file, classified with
    ``predict_audio``, and the file is always deleted afterwards.

    Returns:
        ``{"success": True, "language": <label upper-cased>, "confidence": <percent, 2 dp>}``
        on success, or ``{"success": False, "error": <message>}`` on any failure.
        Errors are reported in the body, not via the HTTP status code.
    """
    temp_path = None
    try:
        # Never build a filesystem path from the client-supplied filename (it may
        # contain "../", be missing, or collide with a concurrent upload). Keep
        # only its extension, which ffmpeg can use as a format hint.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
        fd, temp_path = tempfile.mkstemp(prefix="upload_", suffix=suffix)
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())

        # NOTE(review): predict_audio is synchronous and CPU-bound, so this
        # async handler blocks the event loop while inference runs.
        lang, conf = predict_audio(temp_path)
        return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}

    except Exception as e:
        # API boundary: log with traceback, report the failure in the body.
        logger.exception("Inference Error: %s", e)
        return {"success": False, "error": str(e)}

    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)

@app.get("/")
def health():
    """Liveness endpoint: unconditionally reports the service as online."""
    return dict(status="online")