Hammad712 commited on
Commit
c4b4df8
Β·
verified Β·
1 Parent(s): 18866cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -68
app.py CHANGED
@@ -1,15 +1,16 @@
1
  import logging
2
- from fastapi import FastAPI, UploadFile, File, HTTPException
 
3
  import torchaudio
4
  import torch.nn.functional as F
5
- import torch
6
  import numpy as np
7
  import onnxruntime as ort
 
 
8
  from huggingface_hub import hf_hub_download
9
- import os
10
 
11
  # ==========================================
12
- # 1. Setup Production Logging
13
  # ==========================================
14
  logging.basicConfig(
15
  level=logging.INFO,
@@ -18,65 +19,67 @@ logging.basicConfig(
18
  )
19
  logger = logging.getLogger("LID_Engine")
20
 
21
- app = FastAPI(title="Pakistani LID AI Engine (Production)")
 
 
 
 
 
 
 
 
 
22
 
23
  # ==========================================
24
- # 2. Model Initialization (Fixing ONNX .data issue)
25
  # ==========================================
26
- logger.info("Initializing Application...")
27
  try:
28
- # Creating a local directory so ONNX doesn't get confused in HF hidden cache
29
  os.makedirs("local_model", exist_ok=True)
30
-
31
- logger.info("Downloading ONNX Data weights to local folder...")
32
- hf_hub_download(
33
- repo_id="Hammad712/pakistani-lid-v3-sota",
34
- filename="pakistani_lid_v3.onnx.data",
35
- local_dir="local_model"
36
- )
37
 
38
- logger.info("Downloading ONNX Structure to local folder...")
39
- hf_hub_download(
40
- repo_id="Hammad712/pakistani-lid-v3-sota",
41
- filename="pakistani_lid_v3.onnx",
42
- local_dir="local_model"
43
- )
44
 
45
- logger.info("Loading ONNX Session for CPU...")
46
- # Explicitly point to the local file we just downloaded
47
- local_model_path = os.path.join("local_model", "pakistani_lid_v3.onnx")
48
- session = ort.InferenceSession(local_model_path, providers=['CPUExecutionProvider'])
49
- logger.info("βœ… ONNX Session successfully loaded and ready!")
50
  except Exception as e:
51
- logger.error(f"❌ Failed to load model during startup: {e}")
52
  raise e
53
 
54
  labels = ("balochi", "english", "pashto", "sindhi", "urdu")
55
  id2label = {i: label for i, label in enumerate(labels)}
56
 
57
  # ==========================================
58
- # 3. Core Inference Logic
59
  # ==========================================
60
  def predict_audio(audio_path):
 
61
  waveform, sr = torchaudio.load(audio_path)
 
62
  if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True)
63
  if waveform.ndim == 1: waveform = waveform.unsqueeze(0)
64
 
65
- target_frames = int(sr * 15)
66
- if waveform.shape[1] > target_frames: waveform = waveform[:, :target_frames]
67
- if sr != 16000: waveform = torchaudio.functional.resample(waveform, sr, 16000)
68
 
69
- peak = waveform.abs().max().clamp(min=1e-6)
70
- waveform = (waveform / peak) - waveform.mean()
 
 
 
71
  waveform = waveform / waveform.std().clamp(min=1e-6)
72
 
73
  length = waveform.shape[1]
74
- mask = torch.zeros(16000 * 15, dtype=torch.long)
75
- if length >= 16000 * 15:
76
- waveform, mask[:] = waveform[:, :16000 * 15], 1
77
- else:
78
  mask[:length] = 1
79
- waveform = F.pad(waveform, (0, 16000 * 15 - length))
 
 
80
 
81
  ort_inputs = {
82
  "input_values": waveform.numpy(),
@@ -84,44 +87,34 @@ def predict_audio(audio_path):
84
  }
85
 
86
  logits = session.run(None, ort_inputs)[0]
87
- exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
88
- probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
89
-
90
  pred_id = np.argmax(probs, axis=1)[0]
 
91
  return id2label[pred_id], float(probs[0][pred_id])
92
 
93
  # ==========================================
94
- # 4. API Endpoints
95
  # ==========================================
96
  @app.post("/predict")
97
- async def predict_language(file: UploadFile = File(...)):
98
- logger.info(f"Received request for file: {file.filename}")
 
99
 
100
- if not file.filename.endswith(('.wav', '.mp3', '.m4a', '.ogg')):
101
- logger.warning(f"Rejected invalid file type: {file.filename}")
102
- raise HTTPException(status_code=400, detail="Invalid audio format. Please upload wav, mp3, m4a, or ogg.")
103
-
104
- temp_audio_path = f"temp_{file.filename}"
105
  try:
106
- # Save file
107
- with open(temp_audio_path, "wb") as buffer:
108
- buffer.write(await file.read())
109
 
110
- # Predict
111
- logger.info(f"Processing inference for {file.filename}...")
112
- lang, confidence = predict_audio(temp_audio_path)
113
- logger.info(f"βœ… Prediction successful: {lang.upper()} ({confidence:.2%})")
114
 
115
- # Cleanup
116
- os.remove(temp_audio_path)
117
-
118
- return {
119
- "success": True,
120
- "language": lang.upper(),
121
- "confidence": round(confidence * 100, 2)
122
- }
123
  except Exception as e:
124
- logger.error(f"❌ Error processing {file.filename}: {str(e)}", exc_info=True)
125
- if os.path.exists(temp_audio_path):
126
- os.remove(temp_audio_path)
127
- raise HTTPException(status_code=500, detail="Internal Server Error")
 
 
 
 
1
  import logging
2
+ import os
3
+ import torch
4
  import torchaudio
5
  import torch.nn.functional as F
 
6
  import numpy as np
7
  import onnxruntime as ort
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException
9
+ from fastapi.middleware.cors import CORSMiddleware
10
  from huggingface_hub import hf_hub_download
 
11
 
12
  # ==========================================
13
+ # 1. Setup Logging
14
  # ==========================================
15
  logging.basicConfig(
16
  level=logging.INFO,
 
19
  )
20
  logger = logging.getLogger("LID_Engine")
21
 
22
+ app = FastAPI(title="Pakistani LID AI Engine (SOTA V3)")
23
+
24
+ # βœ… FIXING CORS: Taake aapka HTML frontend isay hit kar sakay
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
 
33
  # ==========================================
34
+ # 2. Model Initialization
35
  # ==========================================
36
+ logger.info("Initializing SOTA Engine...")
37
  try:
 
38
  os.makedirs("local_model", exist_ok=True)
 
 
 
 
 
 
 
39
 
40
+ # Download weights and structure
41
+ logger.info("Downloading ONNX files...")
42
+ hf_hub_download(repo_id="Hammad712/pakistani-lid-v3-sota", filename="pakistani_lid_v3.onnx.data", local_dir="local_model")
43
+ model_path = hf_hub_download(repo_id="Hammad712/pakistani-lid-v3-sota", filename="pakistani_lid_v3.onnx", local_dir="local_model")
 
 
44
 
45
+ # Load ONNX session
46
+ session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
47
+ logger.info("βœ… Model loaded successfully!")
 
 
48
  except Exception as e:
49
+ logger.error(f"❌ Initialization failed: {e}")
50
  raise e
51
 
52
  labels = ("balochi", "english", "pashto", "sindhi", "urdu")
53
  id2label = {i: label for i, label in enumerate(labels)}
54
 
55
  # ==========================================
56
+ # 3. Inference Logic
57
  # ==========================================
58
  def predict_audio(audio_path):
59
+ # Torchaudio loading with fallback logic
60
  waveform, sr = torchaudio.load(audio_path)
61
+
62
  if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True)
63
  if waveform.ndim == 1: waveform = waveform.unsqueeze(0)
64
 
65
+ # Resample and Preprocess
66
+ if sr != 16000:
67
+ waveform = torchaudio.functional.resample(waveform, sr, 16000)
68
 
69
+ target_frames = 16000 * 15
70
+ if waveform.shape[1] > target_frames:
71
+ waveform = waveform[:, :target_frames]
72
+
73
+ waveform = (waveform / waveform.abs().max().clamp(min=1e-6)) - waveform.mean()
74
  waveform = waveform / waveform.std().clamp(min=1e-6)
75
 
76
  length = waveform.shape[1]
77
+ mask = torch.zeros(target_frames, dtype=torch.long)
78
+ if length < target_frames:
 
 
79
  mask[:length] = 1
80
+ waveform = F.pad(waveform, (0, target_frames - length))
81
+ else:
82
+ mask[:] = 1
83
 
84
  ort_inputs = {
85
  "input_values": waveform.numpy(),
 
87
  }
88
 
89
  logits = session.run(None, ort_inputs)[0]
90
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
 
 
91
  pred_id = np.argmax(probs, axis=1)[0]
92
+
93
  return id2label[pred_id], float(probs[0][pred_id])
94
 
95
  # ==========================================
96
+ # 4. API Endpoint
97
  # ==========================================
98
  @app.post("/predict")
99
+ async def predict(file: UploadFile = File(...)):
100
+ logger.info(f"Inference request: {file.filename}")
101
+ temp_path = f"temp_{file.filename}"
102
 
 
 
 
 
 
103
  try:
104
+ with open(temp_path, "wb") as f:
105
+ f.write(await file.read())
 
106
 
107
+ lang, conf = predict_audio(temp_path)
108
+ os.remove(temp_path)
 
 
109
 
110
+ logger.info(f"Result: {lang} ({conf:.2%})")
111
+ return {"success": True, "language": lang.upper(), "confidence": round(conf * 100, 2)}
112
+
 
 
 
 
 
113
  except Exception as e:
114
+ logger.error(f"Prediction error: {e}")
115
+ if os.path.exists(temp_path): os.remove(temp_path)
116
+ return {"success": False, "error": str(e)}
117
+
118
+ @app.get("/")
119
+ def health_check():
120
+ return {"status": "online", "model": "Pakistani LID V3 SOTA"}