# Provenance (copied from the Hugging Face file viewer):
# author: cisckids2026 · commit message: "update" · commit 726fa91 (verified)
import os
import tempfile

import librosa
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

# CHANGED: Whisper imports
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# FastAPI application exposing the Whisper transcription endpoints below.
app = FastAPI(title="Whisper ASR API")

# Prefer the GPU whenever torch can see one. Kept as a plain string so it can
# be echoed back verbatim in JSON responses.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned checkpoint on the Hugging Face Hub that supplies the model weights.
MODEL_DIR = "cisckids2026/marungko-API-Whisper"

print("Loading processor and model...")

# IMPORTANT: the processor (feature extractor + tokenizer) is taken from the
# base Whisper checkpoint, not from MODEL_DIR.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

# Module.to() moves the parameters and returns the same module object.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

print("Model loaded successfully on", device)
def load_audio_16k(path: str, *, sr: int = 16000, top_db: float = 20) -> np.ndarray:
    """Load an audio file as mono, silence-trimmed, peak-normalized samples.

    Args:
        path: Path to an audio file that librosa can decode.
        sr: Target sample rate; the file is resampled to this rate. Defaults
            to 16000, the rate transcribe_array tells the processor to assume.
        top_db: Threshold in dB below the peak, passed to
            ``librosa.effects.trim`` to strip leading/trailing silence.

    Returns:
        The samples scaled so that the loudest one has magnitude 1.0. An
        all-zero signal is returned unscaled.

    Raises:
        ValueError: If the decoded file contains no samples.
    """
    audio, _ = librosa.load(path, sr=sr, mono=True)
    if audio.size == 0:
        # Without this check, an empty decode fails later with an opaque numpy
        # "zero-size array" error. Report the real problem instead.
        raise ValueError("Audio file contains no samples.")

    # Strip leading/trailing segments quieter than top_db below the peak.
    audio, _ = librosa.effects.trim(audio, top_db=top_db)

    # Peak-normalize. initial=0.0 keeps the reduction defined even for an
    # empty array, and the > 0 guard avoids dividing an all-zero signal by 0.
    peak = np.max(np.abs(audio), initial=0.0)
    if peak > 0:
        audio = audio / peak
    return audio
def transcribe_array(audio: np.ndarray) -> str:
    """Run the Whisper model on a waveform and return the decoded text.

    Args:
        audio: Waveform samples. The processor is told they are sampled at
            16000 Hz, which matches what load_audio_16k produces.

    Returns:
        The first decoded sequence, with special tokens removed and
        surrounding whitespace stripped.
    """
    # Turn the raw waveform into model input features and move them to the
    # model's device.
    features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    features = features.to(device)

    # Autoregressive decoding. Gradients are never needed at inference time.
    with torch.no_grad():
        token_ids = model.generate(features)

    texts = processor.batch_decode(token_ids, skip_special_tokens=True)
    # A single clip is passed to the processor, so the batch has one entry.
    first_text = texts[0]
    return first_text.strip()
@app.get("/")
def root():
    """Health-check endpoint: confirm the API is up and report the device in use."""
    status = dict(message="Whisper ASR API is running", device=device)
    return status
# File extensions accepted by /transcribe (compared after lower-casing).
_SUPPORTED_SUFFIXES = frozenset({".wav", ".mp3", ".m4a", ".aac", ".flac", ".ogg", ".caf"})


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Whisper model.

    The upload is written to a temporary file, because load_audio_16k decodes
    from a path. The audio is then transcribed with transcribe_array, and the
    temporary file is always removed afterwards.

    Returns:
        JSON with ``status``, the original ``filename`` and the ``transcript``.

    Raises:
        HTTPException: 400 when no filename is given or the extension is not
            supported. 500 for any failure while saving, decoding or
            transcribing; ``detail`` carries the underlying error message.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded.")

    suffix = os.path.splitext(file.filename)[1].lower()
    if suffix not in _SUPPORTED_SUFFIXES:
        raise HTTPException(
            status_code=400,
            detail="Unsupported audio format."
        )

    temp_path = None
    try:
        # delete=False so the file survives the `with` block and can be
        # decoded by path afterwards. It is removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path before writing so a failed write is still cleaned up.
            temp_path = temp_file.name
            temp_file.write(await file.read())

        # Decoding and model inference are blocking, compute-heavy calls. Run
        # them in a worker thread so they don't stall the event loop, and with
        # it every other request, while this file is transcribed.
        audio = await run_in_threadpool(load_audio_16k, temp_path)
        transcript = await run_in_threadpool(transcribe_array, audio)

        return JSONResponse({
            "status": "success",
            "filename": file.filename,
            "transcript": transcript
        })
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients. It is kept
        # for backward compatibility; consider logging e and returning a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)