"""Pakistani language-identification (LID) inference service.

Exposes a FastAPI app that accepts an uploaded audio file, converts it to
16 kHz mono WAV via FFmpeg, runs a pre-baked ONNX classifier, and returns
the predicted language with a confidence percentage.
"""

import logging
import os
import subprocess
import tempfile

import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio  # noqa: F401 -- kept; may register torch audio backends elsewhere
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware

# Setup Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("LID_Engine")

app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")

# CORS Fix
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Use Absolute Path (Success from previous log!)
MODEL_PATH = "/app/local_model/pakistani_lid_v3.onnx"

logger.info(f"🚀 Loading pre-baked ONNX model from: {MODEL_PATH}")
try:
    session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
    logger.info("✅ Engine is LIVE and Ready!")
except Exception as e:
    logger.error(f"❌ Failed to load model: {e}")
    # Bare `raise` preserves the original traceback (was `raise e`).
    raise

labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = {i: label for i, label in enumerate(labels)}

# Fixed model input window: 15 seconds of 16 kHz mono audio.
TARGET_FRAMES = 16000 * 15


def predict_audio(input_path: str) -> tuple[str, float]:
    """Classify the language spoken in an audio file.

    Args:
        input_path: path to an audio file in any FFmpeg-readable format
            (WebM, OGG, MP3, WAV, ...).

    Returns:
        ``(label, confidence)`` where label is one of ``labels`` and
        confidence is the softmax probability in ``[0, 1]``.

    Raises:
        subprocess.CalledProcessError: if FFmpeg cannot decode the input.
    """
    # Unique temp file per call: the previous fixed "cleaned_audio.wav" name
    # raced between concurrent requests, letting one request read another's audio.
    fd, clean_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        # 🛠️ THE FIX: Use FFmpeg to convert ANY format (WebM, OGG, etc.) to
        # standard 16 kHz mono WAV. This handles the "Format not recognised"
        # error soundfile raises on non-WAV containers.
        subprocess.run(
            ['ffmpeg', '-y', '-i', input_path, '-ar', '16000', '-ac', '1', clean_wav_path],
            check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
        )

        # Now read the standard WAV (mono guaranteed by `-ac 1` above).
        data, sr = sf.read(clean_wav_path)
        waveform = torch.from_numpy(data).float()
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # -> (1, num_samples)

        # Truncate to the model's fixed 15 s window.
        if waveform.shape[1] > TARGET_FRAMES:
            waveform = waveform[:, :TARGET_FRAMES]

        # Peak-normalize, de-mean, scale to unit variance.
        # NOTE(review): the subtracted mean is computed on the *pre*-normalized
        # waveform (original evaluation order) — preserved as-is; confirm it
        # matches the training-time preprocessing before "fixing".
        waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
        waveform = waveform / waveform.std().clamp(min=1e-6)

        # Attention mask: 1 over real samples, 0 over right-padding.
        length = waveform.shape[1]
        mask = torch.zeros(TARGET_FRAMES, dtype=torch.long)
        if length < TARGET_FRAMES:
            mask[:length] = 1
            waveform = F.pad(waveform, (0, TARGET_FRAMES - length))
        else:
            mask[:] = 1

        # ONNX Inference
        ort_inputs = {
            "input_values": waveform.numpy(),
            "attention_mask": mask.unsqueeze(0).numpy(),
        }
        logits = session.run(None, ort_inputs)[0]

        # Numerically stable softmax: subtracting the row max before exp()
        # avoids overflow for large logits without changing the result.
        shifted = logits - logits.max(axis=1, keepdims=True)
        probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=1, keepdims=True)
        pred_id = int(np.argmax(probs, axis=1)[0])
        return id2label[pred_id], float(probs[0][pred_id])
    finally:
        # Single cleanup path for success AND failure (was duplicated in both
        # branches of the original try/except).
        if os.path.exists(clean_wav_path):
            os.remove(clean_wav_path)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Accept an uploaded audio file and return the detected language.

    Response shape (unchanged): ``{"success": True, "language", "confidence"}``
    on success, ``{"success": False, "error"}`` on failure.
    """
    # basename() strips any client-supplied directory components, preventing
    # path traversal via a crafted filename; mkstemp avoids collisions between
    # concurrent uploads with the same name.
    safe_name = os.path.basename(file.filename or "upload")
    fd, temp_path = tempfile.mkstemp(suffix=f"_{safe_name}")
    try:
        with os.fdopen(fd, "wb") as out:
            out.write(await file.read())
        lang, conf = predict_audio(temp_path)
        return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
    except Exception as e:
        logger.error(f"Inference Error: {e}")
        return {"success": False, "error": str(e)}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


@app.get("/")
def health():
    """Liveness probe."""
    return {"status": "online"}