HackerMOne committed on
Commit
695193c
·
verified ·
1 Parent(s): 88d16a3

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +27 -0
  2. app.py +134 -0
  3. requirements.txt +10 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.9 as the base image
2
+ FROM python:3.9
3
+
4
+ # Set the working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies (FFmpeg is required for audio processing)
8
+ RUN apt-get update && apt-get install -y ffmpeg
9
+
10
+ # Copy requirements and install Python dependencies
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy the rest of the application code
15
+ COPY . .
16
+
17
+ # Create a cache directory for Hugging Face models and set permissions
18
+ # This prevents permission errors when the model tries to download
19
+ RUN mkdir -p /app/cache && chmod 777 /app/cache
20
+ ENV TRANSFORMERS_CACHE=/app/cache
21
+ ENV HF_HOME=/app/cache
22
+
23
+ # Expose the port (Hugging Face Spaces uses 7860)
24
+ EXPOSE 7860
25
+
26
+ # Command to run the application using Uvicorn
27
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import torch
import librosa
import numpy as np
from fastapi import FastAPI, File, UploadFile, Form
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from typing import Optional
import Levenshtein

app = FastAPI()

# --- CONFIGURATION ---
# Using the 300m model for a balance of speed and Indian language support.
MODEL_ID = "facebook/mms-300m"

# NOTE: the model is downloaded/loaded at import time, so worker startup
# blocks until it is ready; a load failure aborts the process (raise below),
# which is intentional — the service is useless without the model.
print(f"🔄 Loading AI Model: {MODEL_ID}...")
try:
    # processor handles feature extraction + tokenization; model is the CTC head.
    # Both are module-level globals shared by every request handler.
    processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise e

# Language Code Mapping (Must match your Django app's expectations)
# Maps the human-readable names the client sends to MMS ISO 639-3 adapter codes.
LANG_MAP = {
    'hindi': 'hin', 'tamil': 'tam', 'telugu': 'tel', 'marathi': 'mar',
    'bengali': 'ben', 'gujarati': 'guj', 'kannada': 'kan', 'malayalam': 'mal',
    'punjabi': 'pan', 'urdu': 'urd', 'assamese': 'asm', 'odia': 'ory',
    'english': 'eng'
}
33
@app.get("/")
def home():
    """Service banner: reports that the engine is up and which model it serves."""
    banner = {
        "status": "running",
        "service": "SLAQ AI Engine",
        "model": MODEL_ID,
    }
    return banner
36
+
37
@app.get("/health")
def health():
    """Liveness probe endpoint (used by container orchestrators / Spaces)."""
    return dict(status="healthy")
40
+
41
@app.post("/analyze")
async def analyze_audio(
    audio: UploadFile = File(...),
    transcript: Optional[str] = Form(""),
    language: Optional[str] = Form("eng")
):
    """Transcribe an uploaded audio clip and score it against an optional target.

    Form fields:
        audio: recording to analyze (any format librosa/FFmpeg can decode).
        transcript: optional expected text; when non-empty, mismatch metrics
            are computed against it.
        language: human-readable language name, resolved through LANG_MAP
            (falls back to English).

    Returns:
        dict with the transcription, mismatch metrics, severity and confidence,
        or a 500 JSONResponse containing {"error": ...} on failure.
    """
    print(f"📥 Received analysis request. Language: {language}")

    # SECURITY FIX: the original wrote to f"temp_{audio.filename}", trusting a
    # client-supplied name — that allows path traversal (e.g. "../x") and lets
    # concurrent requests with the same filename clobber each other. mkstemp
    # creates a unique, private file; keep the extension as a decode hint.
    import tempfile
    suffix = os.path.splitext(audio.filename or "")[1]
    fd, temp_filename = tempfile.mkstemp(prefix="slaq_", suffix=suffix)

    try:
        # 1. Save uploaded file temporarily
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(await audio.read())

        # 2. Load and resample audio (16kHz required for Wav2Vec2)
        speech, _sr = librosa.load(temp_filename, sr=16000)

        # 3. Configure Language Adapter
        target_lang = LANG_MAP.get(str(language).lower(), 'eng')

        # NOTE(review): set_target_lang/load_adapter mutate the shared global
        # model, so concurrent requests in different languages can race —
        # confirm the deployment runs a single uvicorn worker.
        try:
            # MMS requires loading the specific language adapter
            processor.tokenizer.set_target_lang(target_lang)
            model.load_adapter(target_lang)
        except Exception as e:
            print(f"⚠️ Language adapter error for '{target_lang}': {e}. Falling back to English.")
            target_lang = 'eng'
            processor.tokenizer.set_target_lang('eng')
            model.load_adapter('eng')

        # 4. Run Inference (The AI part)
        inputs = processor(speech, sampling_rate=16000, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Decode the output to text (greedy CTC decode)
        predicted_ids = torch.argmax(logits, dim=-1)
        actual_transcript = processor.batch_decode(predicted_ids)[0]
        print(f"📝 Transcribed: {actual_transcript[:50]}...")

        # 5. Calculate Metrics
        # Mean of the per-frame max softmax probability — a rough decode confidence.
        confidence = float(torch.mean(torch.nn.functional.softmax(logits, dim=-1).max(dim=-1).values))

        mismatched_chars = []
        mismatch_pct = 0.0

        # Calculate mismatch only if a target transcript was provided
        if transcript:
            dist = Levenshtein.distance(actual_transcript, transcript)
            max_len = max(len(transcript), 1)  # guard against division by zero
            mismatch_pct = (dist / max_len) * 100

            # Simple character mismatch finding: collect target characters the
            # model replaced or dropped entirely.
            import difflib
            matcher = difflib.SequenceMatcher(None, actual_transcript, transcript)
            for tag, i1, i2, j1, j2 in matcher.get_opcodes():
                if tag in ['replace', 'insert']:
                    mismatched_chars.extend(list(transcript[j1:j2]))

        # Determine Severity based on mismatch percentage (thresholds 10/25/45)
        severity = "none"
        if mismatch_pct > 10: severity = "mild"
        if mismatch_pct > 25: severity = "moderate"
        if mismatch_pct > 45: severity = "severe"

        # 6. Format Response (stutter fields are placeholders — not computed yet)
        response_data = {
            "actual_transcript": actual_transcript,
            "target_transcript": transcript or "",
            "mismatched_chars": mismatched_chars,
            "mismatch_percentage": round(mismatch_pct, 2),
            "ctc_loss_score": 0.0,
            "stutter_timestamps": [],
            "total_stutter_duration": 0.0,
            "stutter_frequency": 0.0,
            "severity": severity,
            "confidence_score": round(confidence, 2),
            "model_version": MODEL_ID,
            "language_detected": target_lang
        }

        return response_data

    except Exception as e:
        import traceback
        traceback.print_exc()
        # BUG FIX: the original `return {"error": str(e)}, 500` does NOT set the
        # HTTP status in FastAPI — the tuple is serialized as a 200 JSON array.
        # JSONResponse produces an actual HTTP 500 with the error payload.
        from fastapi.responses import JSONResponse
        return JSONResponse(status_code=500, content={"error": str(e)})

    finally:
        # Cleanup: Delete the temporary file
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn==0.24.0
3
+ python-multipart==0.0.6
4
+ torch==2.1.0
5
+ transformers==4.35.2
6
+ librosa==0.10.1
7
+ numpy==1.26.2
8
+ scipy==1.11.4
9
+ soundfile==0.12.1
10
+ python-Levenshtein==0.23.0