Hammad712 committed on
Commit
e853b4e
·
verified ·
1 Parent(s): c6f287b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ import uvicorn
3
+ import torchaudio
4
+ import torch.nn.functional as F
5
+ import torch
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+ from huggingface_hub import hf_hub_download
9
+ import os
10
+
11
app = FastAPI(title="Pakistani LID AI Engine (Standalone)")

# Fetch the ONNX weights from the Hugging Face Hub. The download is cached,
# so the file is not fetched again on every start.
print("📥 Checking/Downloading ONNX Model from Hugging Face...")
model_path = hf_hub_download(
    repo_id="Hammad712/pakistani-lid-v3-sota",
    filename="pakistani_lid_v3.onnx",
)

# One inference session, restricted to the CPU execution provider, is created
# at import time and reused by every request.
print("🚀 Loading ONNX Session for CPU...")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# Class names; id2label maps each position in this tuple to its name.
labels = ("balochi", "english", "pashto", "sindhi", "urdu")
id2label = dict(enumerate(labels))
22
+
23
def predict_audio(audio_path):
    """Classify the spoken language of an audio file with the ONNX model.

    The audio is loaded with torchaudio, down-mixed to mono, limited to its
    first 15 seconds, resampled to 16 kHz, normalised, and finally padded or
    truncated to exactly 15 s (240 000 samples). The result is passed to the
    module-level ONNX ``session`` together with an attention mask that marks
    the real (non-padded) samples.

    Args:
        audio_path: Path of the audio file to classify.

    Returns:
        A ``(label, probability)`` tuple: the ``id2label`` entry with the
        highest softmax probability, and that probability as a float.

    Raises:
        ValueError: If the file decodes to zero audio samples.
    """
    sample_rate = 16000
    max_samples = sample_rate * 15

    waveform, sr = torchaudio.load(audio_path)

    # Promote a 1-D signal to (1, time) *before* the channel check, so its
    # time axis is never mistaken for a channel axis and averaged away.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Fail with a clear message instead of the opaque error that max() raises
    # on an empty tensor further down.
    if waveform.numel() == 0:
        raise ValueError("Audio file contains no samples.")

    # Trim at the source rate first so no more than 15 s is ever resampled.
    waveform = waveform[:, : int(sr * 15)]
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)

    peak = waveform.abs().max().clamp(min=1e-6)
    # NOTE(review): the mean subtracted here is that of the signal *before*
    # peak scaling, so the result is only zero-mean when the input already
    # is. Left unchanged because the training-time preprocessing is not
    # visible here -- confirm against it before altering.
    waveform = (waveform / peak) - waveform.mean()
    waveform = waveform / waveform.std().clamp(min=1e-6)

    # Resampling can round to slightly more than max_samples; truncate, then
    # right-pad with zeros. The mask is 1 for real samples and 0 for padding.
    waveform = waveform[:, :max_samples]
    length = waveform.shape[1]
    mask = torch.zeros(max_samples, dtype=torch.long)
    mask[:length] = 1
    waveform = F.pad(waveform, (0, max_samples - length))

    ort_inputs = {
        "input_values": waveform.numpy(),
        "attention_mask": mask.unsqueeze(0).numpy(),
    }

    logits = session.run(None, ort_inputs)[0]
    # Softmax with the row maximum subtracted for numerical stability.
    exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

    pred_id = int(np.argmax(probs, axis=1)[0])
    return id2label[pred_id], float(probs[0][pred_id])
55
+
56
@app.post("/predict")
async def predict_language(file: UploadFile = File(...)):
    """Identify the language spoken in an uploaded audio file.

    The upload is written to a uniquely named temporary file, classified with
    ``predict_audio``, and the temporary file is removed whether or not the
    prediction succeeds.

    Args:
        file: Uploaded audio whose name ends in .wav, .mp3, .m4a or .ogg
            (case-insensitive).

    Returns:
        A dict with ``success`` (always True), ``language`` (upper-cased
        label) and ``confidence`` (percentage rounded to two decimals).

    Raises:
        HTTPException: 400 if the file name is missing or has an unsupported
            extension; 500 if saving or classifying the audio fails.
    """
    import tempfile  # local import: only this handler needs it

    allowed_suffixes = (".wav", ".mp3", ".m4a", ".ogg")

    # The client-supplied name is untrusted and may be absent. Only its
    # extension is used; the name itself never becomes part of a path.
    upload_name = (file.filename or "").lower()
    suffix = next(
        (ext for ext in allowed_suffixes if upload_name.endswith(ext)), None
    )
    if suffix is None:
        raise HTTPException(status_code=400, detail="Invalid audio format. Please upload wav, mp3, m4a, or ogg.")

    temp_audio_path = None
    try:
        # A unique temp file avoids collisions between concurrent uploads
        # that share a file name. delete=False lets the file be reopened by
        # path after this handle is closed.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as buffer:
            temp_audio_path = buffer.name
            buffer.write(await file.read())

        lang, confidence = predict_audio(temp_audio_path)

        return {
            "success": True,
            "language": lang.upper(),
            "confidence": round(confidence * 100, 2),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Single cleanup point for both the success and the failure path.
        if temp_audio_path is not None and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
82
+
83
if __name__ == "__main__":
    # Run only when executed as a script: print the local URL, then start
    # uvicorn listening on all interfaces, port 8080.
    print("✨ Server is LIVE at: http://localhost:8080")
    uvicorn.run(app, host="0.0.0.0", port=8080)