import logging
import os
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
import onnxruntime as ort
import soundfile as sf
import subprocess
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("LID_Engine")

# The ASGI application object; routes below are registered on it.
app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")

# CORS: accept cross-origin requests from any origin, with any method and
# any header, so a browser front-end hosted elsewhere can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Absolute path of the ONNX model file that is loaded at import time.
MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"

logger.info("Loading pre-baked ONNX model from: %s", MODEL_PATH)
try:
    # CPU-only inference session, created once and shared by every request.
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
except Exception as e:
    # Without a model the service is useless: log the cause and abort the
    # import, re-raising the original exception with its traceback intact.
    logger.error("Failed to load model: %s", e)
    raise
else:
    logger.info("Engine is LIVE and Ready!")

# Class names; position i is the label reported for model output index i.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = dict(enumerate(labels))
def predict_audio(input_path):
    """Identify the spoken language of an audio file.

    The file is transcoded with FFmpeg to 16 kHz mono WAV, cropped or
    zero-padded to exactly 15 seconds, normalized, and passed to the
    module-level ONNX ``session``.

    Args:
        input_path: Path to an audio file in any format FFmpeg can decode.

    Returns:
        A ``(label, confidence)`` tuple: the ``id2label`` entry with the
        highest score and its softmax probability as a float in [0, 1].

    Raises:
        subprocess.CalledProcessError: If FFmpeg exits with a non-zero status.
        ValueError: If the decoded audio contains no samples.
    """
    import tempfile

    sample_rate = 16000
    target_frames = sample_rate * 15

    # A private scratch directory (instead of a fixed file name in the working
    # directory) keeps concurrent calls from overwriting each other's WAV file,
    # and is removed on exit even when an exception is raised.
    with tempfile.TemporaryDirectory() as workdir:
        clean_wav_path = os.path.join(workdir, "cleaned_audio.wav")
        # FFmpeg converts any container/codec (WebM, OGG, ...) to standard
        # WAV, which avoids soundfile's "Format not recognised" error.
        subprocess.run(
            [
                "ffmpeg", "-y", "-i", input_path,
                "-ar", str(sample_rate), "-ac", "1", clean_wav_path,
            ],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        data, _ = sf.read(clean_wav_path)

    waveform = torch.from_numpy(data).float()
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.numel() == 0:
        raise ValueError("Decoded audio contains no samples")

    # Keep at most the first 15 seconds.
    if waveform.shape[1] > target_frames:
        waveform = waveform[:, :target_frames]

    # Peak-scale, shift, then scale to unit standard deviation.
    # NOTE(review): the mean subtracted here is that of the *unscaled* signal,
    # so the result is not exactly zero-mean. Left unchanged because the model
    # may have been trained with this exact preprocessing -- confirm.
    peak = waveform.abs().max().clamp(min=1e-6)
    waveform = (waveform / peak) - waveform.mean()
    waveform = waveform / waveform.std().clamp(min=1e-6)

    # The mask marks real samples with 1 and right-padding with 0.
    length = waveform.shape[1]
    mask = torch.zeros(target_frames, dtype=torch.long)
    mask[:length] = 1
    if length < target_frames:
        waveform = F.pad(waveform, (0, target_frames - length))

    ort_inputs = {
        "input_values": waveform.numpy(),
        "attention_mask": mask.unsqueeze(0).numpy(),
    }
    logits = session.run(None, ort_inputs)[0]

    # Softmax with the row maximum subtracted first so that large logits
    # cannot overflow in exp(); the probabilities are mathematically unchanged.
    exps = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probs = exps / np.sum(exps, axis=1, keepdims=True)

    pred_id = int(np.argmax(probs, axis=1)[0])
    return id2label[pred_id], float(probs[0][pred_id])
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the language of an uploaded audio file.

    The upload is written to a temporary file, passed to ``predict_audio``,
    and the temporary file is removed afterwards whether or not inference
    succeeded.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        On success ``{"success": True, "language": <UPPERCASE label>,
        "confidence": <percentage rounded to 2 decimals>}``; on any failure
        ``{"success": False, "error": <exception text>}``.
    """
    import tempfile

    # The client-supplied file name is untrusted (it may contain path
    # separators, be missing, or collide with another upload), so only a
    # sanitized extension is reused; the rest of the name is generated.
    _, ext = os.path.splitext(os.path.basename(file.filename or ""))
    suffix = ext if ext[1:].isascii() and ext[1:].isalnum() else ""

    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            prefix="temp_", suffix=suffix, delete=False
        ) as tmp:
            temp_path = tmp.name
            tmp.write(await file.read())

        # NOTE(review): this call is synchronous and blocks the event loop
        # for the duration of transcoding and inference.
        lang, conf = predict_audio(temp_path)
        return {
            "success": True,
            "language": lang.upper(),
            "confidence": round(conf * 100, 2),
        }
    except Exception as e:
        # Failures are reported in the response body rather than as an HTTP
        # error status; the traceback goes to the log for diagnosis.
        logger.exception("Inference Error: %s", e)
        return {"success": False, "error": str(e)}
    finally:
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
@app.get("/")
def health():
    """Liveness probe: report that the service is up."""
    status_payload = {"status": "online"}
    return status_payload