# Hugging Face Space: KinyaWhisper speech-to-text API (FastAPI)
import os
import tempfile

import torch
import soundfile as sf
from fastapi import FastAPI, File, HTTPException, UploadFile
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# -----------------------------
# App
# -----------------------------
app = FastAPI()
MODEL_ID = "benax-rw/KinyaWhisper"

# -----------------------------
# Auth (gated model): the Hub token must be supplied as a Space secret.
# -----------------------------
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    # Fail fast at startup: without a token the gated repo download below
    # would fail anyway, with a far less actionable error.
    raise RuntimeError(
        "HUGGINGFACE_HUB_TOKEN not found. Add it in Space -> Settings -> Secrets."
    )

# -----------------------------
# Device selection: prefer GPU with fp16, fall back to CPU with fp32.
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Loading KinyaWhisper on {device}")
# -----------------------------
# Load processor + model (gated repo: authenticate with the Space secret)
# -----------------------------
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,        # fp16 on GPU, fp32 on CPU (set above)
    low_cpu_mem_usage=True,   # stream weights to reduce peak RAM during load
    token=HF_TOKEN,
)
model = model.to(device)
model.eval()  # inference only: disable dropout / training-mode layers
# -----------------------------
# Health check
# -----------------------------
@app.get("/")  # BUG FIX: route decorator was missing, so the endpoint was never registered
def root():
    """Health-check endpoint: report service status, model id and device."""
    return {
        "status": "ok",
        "model": MODEL_ID,
        "device": device,
    }
# -----------------------------
# Transcription endpoint
# -----------------------------
@app.post("/transcribe")  # BUG FIX: route decorator was missing, endpoint never registered
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded 16 kHz WAV file with KinyaWhisper.

    Returns {"text": <transcription>}.
    Raises HTTP 400 if the audio is not sampled at 16 kHz.
    """
    # Persist the upload to a temporary file so soundfile can read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(await file.read())
        tmp_path = tmp.name

    try:
        # Decode audio from disk (numpy array + sample rate).
        audio, sr = sf.read(tmp_path)
    finally:
        # BUG FIX: the original leaked one temp file per request.
        os.remove(tmp_path)

    if sr != 16000:
        # BUG FIX: a bare ValueError surfaced as an opaque 500; report a
        # proper client error instead.
        raise HTTPException(status_code=400, detail="Audio must be 16kHz mono WAV")

    # Robustness: downmix multi-channel audio to mono — the feature
    # extractor expects a 1-D signal.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Extract log-mel input features and move them to the model's device/dtype.
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(device, dtype=dtype)

    # Repetition fix: deterministic (greedy) decoding with n-gram blocking
    # and a mild penalty to stop Whisper's looping transcriptions.
    with torch.no_grad():
        generated_ids = model.generate(
            input_features,
            task="transcribe",
            max_new_tokens=256,
            temperature=0.0,         # deterministic
            no_repeat_ngram_size=3,  # block repeated trigrams
            repetition_penalty=1.2,  # penalize repeats
        )

    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return {"text": text.strip()}