Hammad712 committed on
Commit
f2e947b
·
verified ·
1 Parent(s): 5003140

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -49
app.py CHANGED
@@ -8,21 +8,14 @@ import onnxruntime as ort
8
  import soundfile as sf
9
  from fastapi import FastAPI, UploadFile, File, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from huggingface_hub import hf_hub_download
12
 
13
- # ==========================================
14
- # 1. Setup Logging
15
- # ==========================================
16
- logging.basicConfig(
17
- level=logging.INFO,
18
- format="%(asctime)s [%(levelname)s] %(message)s",
19
- handlers=[logging.StreamHandler()]
20
- )
21
  logger = logging.getLogger("LID_Engine")
22
 
23
  app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")
24
 
25
- # ✅ CORS Fix for your HTML frontend
26
  app.add_middleware(
27
  CORSMiddleware,
28
  allow_origins=["*"],
@@ -31,54 +24,41 @@ app.add_middleware(
31
  allow_headers=["*"],
32
  )
33
 
34
- # ==========================================
35
- # 2. Model Initialization
36
- # ==========================================
37
- logger.info("Initializing SOTA Engine...")
 
38
  try:
39
- os.makedirs("local_model", exist_ok=True)
 
 
40
 
41
- logger.info("Downloading ONNX files...")
42
- hf_hub_download(repo_id="Hammad712/pakistani-lid-v3-sota", filename="pakistani_lid_v3.onnx.data", local_dir="local_model")
43
- model_path = hf_hub_download(repo_id="Hammad712/pakistani-lid-v3-sota", filename="pakistani_lid_v3.onnx", local_dir="local_model")
44
-
45
- session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
46
- logger.info("✅ Model loaded successfully!")
47
  except Exception as e:
48
- logger.error(f"❌ Initialization failed: {e}")
49
  raise e
50
 
51
  labels = ("balochi", "english", "pashto", "sindhi", "urdu")
52
  id2label = {i: label for i, label in enumerate(labels)}
53
 
54
- # ==========================================
55
- # 3. Inference Logic
56
- # ==========================================
57
  def predict_audio(audio_path):
58
- # 🚨 Using Soundfile to avoid Torchaudio backend errors
59
  data, sr = sf.read(audio_path)
60
-
61
- # Convert to torch tensor [channels, samples]
62
  waveform = torch.from_numpy(data).float()
63
  if waveform.ndim == 2:
64
- waveform = waveform.T # soundfile uses [samples, channels]
65
- waveform = waveform.mean(dim=0, keepdim=True)
66
  else:
67
  waveform = waveform.unsqueeze(0)
68
 
69
- # Resample to 16kHz
70
  if sr != 16000:
71
  waveform = torchaudio.functional.resample(waveform, sr, 16000)
72
 
73
- # Normalize & Clip to 15s
74
  target_frames = 16000 * 15
75
- if waveform.shape[1] > target_frames:
76
- waveform = waveform[:, :target_frames]
77
-
78
  waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
79
  waveform = waveform / waveform.std().clamp(min=1e-6)
80
 
81
- # Create Mask
82
  length = waveform.shape[1]
83
  mask = torch.zeros(target_frames, dtype=torch.long)
84
  if length < target_frames:
@@ -87,7 +67,6 @@ def predict_audio(audio_path):
87
  else:
88
  mask[:] = 1
89
 
90
- # ONNX Inference
91
  ort_inputs = {
92
  "input_values": waveform.numpy(),
93
  "attention_mask": mask.unsqueeze(0).numpy()
@@ -99,29 +78,19 @@ def predict_audio(audio_path):
99
 
100
  return id2label[pred_id], float(probs[0][pred_id])
101
 
102
- # ==========================================
103
- # 4. API Endpoint
104
- # ==========================================
105
  @app.post("/predict")
106
  async def predict(file: UploadFile = File(...)):
107
- logger.info(f"Inference request: {file.filename}")
108
  temp_path = f"temp_{file.filename}"
109
-
110
  try:
111
  with open(temp_path, "wb") as f:
112
  f.write(await file.read())
113
-
114
  lang, conf = predict_audio(temp_path)
115
  os.remove(temp_path)
116
-
117
- logger.info(f"Result: {lang} ({conf:.2%})")
118
  return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
119
-
120
  except Exception as e:
121
- logger.error(f"Prediction error: {e}")
122
  if os.path.exists(temp_path): os.remove(temp_path)
123
  return {"success": False, "error": str(e)}
124
 
125
  @app.get("/")
126
- def health_check():
127
- return {"status": "online", "model": "Pakistani LID V3 SOTA"}
 
8
  import soundfile as sf
9
  from fastapi import FastAPI, UploadFile, File, HTTPException
10
  from fastapi.middleware.cors import CORSMiddleware
 
11
 
12
+ # Setup Logging
13
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
 
 
 
 
 
14
  logger = logging.getLogger("LID_Engine")
15
 
16
  app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")
17
 
18
+ # CORS Fix
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
24
  allow_headers=["*"],
25
  )
26
 
27
+ # Load Model (Baked into the Docker image)
28
+ MODEL_DIR = "local_model"
29
+ MODEL_PATH = os.path.join(MODEL_DIR, "pakistani_lid_v3.onnx")
30
+
31
+ logger.info("🚀 Loading pre-baked ONNX model...")
32
  try:
33
+ # Check if files exist just in case
34
+ if not os.path.exists(MODEL_PATH):
35
+ raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
36
 
37
+ session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
38
+ logger.info("✅ Engine is LIVE and Ready!")
 
 
 
 
39
  except Exception as e:
40
+ logger.error(f"❌ Failed to load model: {e}")
41
  raise e
42
 
43
  labels = ("balochi", "english", "pashto", "sindhi", "urdu")
44
  id2label = {i: label for i, label in enumerate(labels)}
45
 
 
 
 
46
  def predict_audio(audio_path):
 
47
  data, sr = sf.read(audio_path)
 
 
48
  waveform = torch.from_numpy(data).float()
49
  if waveform.ndim == 2:
50
+ waveform = waveform.T.mean(dim=0, keepdim=True)
 
51
  else:
52
  waveform = waveform.unsqueeze(0)
53
 
 
54
  if sr != 16000:
55
  waveform = torchaudio.functional.resample(waveform, sr, 16000)
56
 
 
57
  target_frames = 16000 * 15
58
+ waveform = waveform[:, :target_frames]
 
 
59
  waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
60
  waveform = waveform / waveform.std().clamp(min=1e-6)
61
 
 
62
  length = waveform.shape[1]
63
  mask = torch.zeros(target_frames, dtype=torch.long)
64
  if length < target_frames:
 
67
  else:
68
  mask[:] = 1
69
 
 
70
  ort_inputs = {
71
  "input_values": waveform.numpy(),
72
  "attention_mask": mask.unsqueeze(0).numpy()
 
78
 
79
  return id2label[pred_id], float(probs[0][pred_id])
80
 
 
 
 
81
  @app.post("/predict")
82
  async def predict(file: UploadFile = File(...)):
 
83
  temp_path = f"temp_{file.filename}"
 
84
  try:
85
  with open(temp_path, "wb") as f:
86
  f.write(await file.read())
 
87
  lang, conf = predict_audio(temp_path)
88
  os.remove(temp_path)
 
 
89
  return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
 
90
  except Exception as e:
 
91
  if os.path.exists(temp_path): os.remove(temp_path)
92
  return {"success": False, "error": str(e)}
93
 
94
  @app.get("/")
95
+ def health():
96
+ return {"status": "online"}