Update main.py
Browse files
main.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
| 4 |
-
from
|
| 5 |
-
from
|
|
|
|
| 6 |
from transformers import (
|
| 7 |
WhisperProcessor,
|
| 8 |
WhisperForConditionalGeneration,
|
|
@@ -30,8 +31,14 @@ logging.getLogger("transformers").setLevel(logging.ERROR)
|
|
| 30 |
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
| 31 |
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
| 32 |
|
| 33 |
-
app =
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# ========== Load Whisper Model (quantized) ==========
|
| 37 |
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
|
|
@@ -131,32 +138,41 @@ def warm_up_models():
|
|
| 131 |
warm_up_models()
|
| 132 |
|
| 133 |
# ========== Flask Route ==========
|
| 134 |
-
@app.
|
| 135 |
-
def transcribe():
|
| 136 |
-
if
|
| 137 |
-
|
| 138 |
|
| 139 |
-
audio_file = request.files['audio']
|
| 140 |
os.makedirs("/tmp/temp_audio", exist_ok=True)
|
| 141 |
-
audio_path = f"/tmp/temp_audio/{
|
| 142 |
-
audio_file.save(audio_path)
|
| 143 |
|
|
|
|
| 144 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
transcription = transcribe_long_audio(audio_path, processor, whisper_model)
|
| 146 |
corrected_text = correct_grammar(transcription, grammar_pipeline)
|
| 147 |
|
| 148 |
-
return
|
| 149 |
"raw_transcription": transcription,
|
| 150 |
"corrected_transcription": corrected_text
|
| 151 |
})
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
-
|
| 155 |
|
| 156 |
finally:
|
| 157 |
-
|
| 158 |
-
os.
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
# ========== Run App ==========
|
| 161 |
if __name__ == '__main__':
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
| 4 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 5 |
+
from fastapi.responses import JSONResponse
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
from transformers import (
|
| 8 |
WhisperProcessor,
|
| 9 |
WhisperForConditionalGeneration,
|
|
|
|
| 31 |
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
| 32 |
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
| 33 |
|
| 34 |
+
# FastAPI application instance with permissive CORS so browser clients
# from any origin can reach the API.
app = FastAPI()

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed for credentialed requests by the Fetch/CORS spec — confirm
# whether credentials are actually needed, or pin explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 42 |
|
| 43 |
# ========== Load Whisper Model (quantized) ==========
|
| 44 |
def load_whisper_model(model_size="small", save_dir="/tmp/models_cache/whisper"):
|
|
|
|
| 138 |
warm_up_models()
|
| 139 |
|
| 140 |
# ========== FastAPI Route ==========
|
| 141 |
+
@app.post('/transcribe')
async def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return raw + corrected text.

    Accepts a multipart upload in the ``audio`` field, saves it to a temp
    location, runs Whisper transcription followed by grammar correction,
    and always removes the temp file afterwards.

    Returns:
        JSONResponse with ``raw_transcription`` and ``corrected_transcription``.

    Raises:
        HTTPException: 400 if no usable file was provided; 500 on any
            processing failure (detail carries the original error message).
    """
    if not audio or not audio.filename:
        # Without a filename we cannot build a valid temp path; fail fast
        # with a client error instead of crashing later.
        raise HTTPException(status_code=400, detail="No audio file provided.")

    os.makedirs("/tmp/temp_audio", exist_ok=True)
    # basename() strips any directory components from the client-supplied
    # filename, preventing path traversal (e.g. "../../etc/passwd").
    safe_name = os.path.basename(audio.filename)
    audio_path = f"/tmp/temp_audio/{safe_name}"

    # Save uploaded file
    try:
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)

        transcription = transcribe_long_audio(audio_path, processor, whisper_model)
        corrected_text = correct_grammar(transcription, grammar_pipeline)

        return JSONResponse({
            "raw_transcription": transcription,
            "corrected_transcription": corrected_text
        })

    except HTTPException:
        # Don't wrap deliberate HTTP errors in a generic 500.
        raise
    except Exception as e:
        # Chain the cause so the real traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Best-effort cleanup of the temp file; an OSError here must never
        # mask the real response or error.
        try:
            if os.path.exists(audio_path):
                os.remove(audio_path)
        except OSError:
            pass
|
| 172 |
|
| 173 |
# ========== Run App ==========
|
| 174 |
if __name__ == '__main__':
    # Serve the FastAPI app with Uvicorn, listening on all interfaces.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")