import os
import torch
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import tempfile
# -----------------------------
# App
# -----------------------------
app = FastAPI()
MODEL_ID = "benax-rw/KinyaWhisper"
# -----------------------------
# Auth (GATED MODEL)
# -----------------------------
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HUGGINGFACE_HUB_TOKEN not found. Add it in Space → Settings → Secrets."
    )
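# For local development, the same token can be exported before launching the
# app. Minimal sketch (hf_xxx is a placeholder, not a real token):
#   export HUGGINGFACE_HUB_TOKEN=hf_xxx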
# -----------------------------
# Device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"πŸ”₯ Loading KinyaWhisper on {device}")
# -----------------------------
# Load model
# -----------------------------
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN
)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    token=HF_TOKEN
).to(device)
model.eval()
# -----------------------------
# Health check
# -----------------------------
@app.get("/")
def root():
    return {
        "status": "ok",
        "model": MODEL_ID,
        "device": device
    }
# -----------------------------
# Transcription endpoint
# -----------------------------
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    # Save the upload to a temporary file so soundfile can read it
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    # Load audio, then remove the temp file so uploads don't accumulate
    try:
        audio, sr = sf.read(tmp_path)
    finally:
        os.remove(tmp_path)

    # Downmix stereo to mono by averaging channels
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != 16000:
        raise HTTPException(status_code=400, detail="Audio must be a 16 kHz WAV")
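    # Instead of rejecting other sample rates, the audio could be resampled
    # first. Minimal sketch, assuming librosa is installed (it is not a
    # dependency of this app as written):
    #   import librosa
    #   audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    #   sr = 16000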
    # Process audio
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt"
    )
    input_features = inputs.input_features.to(device, dtype=dtype)
    # -----------------------------
    # 🔴 IMPORTANT: REPETITION FIX
    # -----------------------------
    with torch.no_grad():
        generated_ids = model.generate(
            input_features,
            task="transcribe",
            max_new_tokens=256,
            temperature=0.0,         # 🔒 deterministic
            no_repeat_ngram_size=3,  # 🔁 stop loops
            repetition_penalty=1.2   # 🛑 penalize repeats
        )
    text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True
    )[0]
    return {"text": text.strip()}