"""Kasanoma ASR: FastAPI service that transcribes Twi/English speech with a
fine-tuned Whisper checkpoint and collects (audio, transcript) pairs into a
local JSONL dataset that is optionally mirrored to a Hugging Face dataset repo.
"""
import json
import os
import shutil
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path

import librosa
import numpy as np
import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import HfApi, create_repo, upload_file
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# ── Env ────────────────────────────────────────────────────────────────────
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO = os.getenv("HF_DATASET_REPO")

# ── Paths ──────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
STATIC_DIR = BASE_DIR / "static"
# Anchored to BASE_DIR (not the CWD) so startup does not depend on where the
# server is launched from; StaticFiles raises at mount time if it is missing.
ASSETS_DIR = BASE_DIR / "assets"
DATASET_DIR = BASE_DIR / "dataset"
AUDIO_DIR = DATASET_DIR / "audio"
MANIFEST = DATASET_DIR / "transcripts.jsonl"

STATIC_DIR.mkdir(exist_ok=True)
ASSETS_DIR.mkdir(exist_ok=True)
DATASET_DIR.mkdir(exist_ok=True)
AUDIO_DIR.mkdir(exist_ok=True)

# ── HuggingFace setup ──────────────────────────────────────────────────────
_hf_api: HfApi | None = None


def get_hf_api() -> HfApi | None:
    """Return a cached HfApi client, or None when uploads are not configured.

    Mirroring needs both HF_TOKEN and HF_DATASET_REPO; if either is unset the
    app keeps working with local storage only. On first use the dataset repo
    is created with ``exist_ok=True``; failures there are logged, not raised.
    """
    global _hf_api
    if _hf_api is None and HF_TOKEN and DATASET_REPO:
        _hf_api = HfApi(token=HF_TOKEN)
        try:
            create_repo(
                repo_id=DATASET_REPO,
                repo_type="dataset",
                exist_ok=True,
                token=HF_TOKEN,
            )
        except Exception as e:
            print(f"[HF] Dataset repo check: {e}")
    return _hf_api


# ── Model (lazy-loaded on first request) ──────────────────────────────────
MODEL_ID = "Kennethdot/kasanoma_whisper"

_MODEL: WhisperForConditionalGeneration | None = None
_PROCESSOR: WhisperProcessor | None = None
_DEVICE: torch.device | None = None
# Endpoints run in FastAPI's threadpool, so the first requests may race to
# load the model; this lock makes the load happen exactly once.
_MODEL_LOCK = threading.Lock()
# Serialises generate() calls so concurrent requests do not multiply GPU/CPU
# memory use (the previous async handlers effectively ran one at a time).
_INFERENCE_LOCK = threading.Lock()


def get_model():
    """Load (once) and return ``(model, processor, device)``.

    Uses CUDA in float16 when available, otherwise CPU in float32.
    """
    global _MODEL, _PROCESSOR, _DEVICE
    if _MODEL is None:
        with _MODEL_LOCK:
            # Double-checked: another thread may have finished loading while
            # this one waited for the lock.
            if _MODEL is None:
                print(f"Loading {MODEL_ID} …")
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = WhisperForConditionalGeneration.from_pretrained(
                    MODEL_ID,
                    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
                ).to(device)
                model.eval()
                processor = WhisperProcessor.from_pretrained(MODEL_ID)
                # _MODEL is tested outside the lock, so publish it last: a
                # reader that sees it set also sees processor and device.
                _PROCESSOR = processor
                _DEVICE = device
                _MODEL = model
                print(f"Model ready on {device}.")
    return _MODEL, _PROCESSOR, _DEVICE


# ── App ────────────────────────────────────────────────────────────────────
app = FastAPI(title="Kasanoma ASR", version="2.1.0")
# Guards reads and appends of MANIFEST (a JSONL file, despite the name).
_csv_lock = threading.Lock()
# Serialises manifest uploads so the last upload always carries the newest
# snapshot, without holding _csv_lock across a network call.
_upload_lock = threading.Lock()


# ── Helpers ────────────────────────────────────────────────────────────────
def audio_duration(path: Path) -> float:
    """Return the clip length in seconds (2 decimals), or 0.0 if undecodable."""
    try:
        y, sr = librosa.load(str(path), sr=None, mono=True)
        return round(len(y) / sr, 2)
    except Exception:
        return 0.0


def save_entry(entry: dict) -> None:
    """Append one JSON line to MANIFEST. Callers must hold ``_csv_lock``."""
    with MANIFEST.open("a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")


def load_manifest() -> list[dict]:
    """Return all manifest entries in file order (oldest first).

    Reads under ``_csv_lock`` so a concurrent append is never seen half
    written. Malformed lines are logged and skipped rather than crashing the
    dataset endpoints. Must not be called while holding ``_csv_lock``.
    """
    try:
        with _csv_lock:
            raw = MANIFEST.read_text(encoding="utf-8")
    except FileNotFoundError:
        return []

    entries = []
    # Split on "\n" only: ensure_ascii=False may leave U+2028/U+2029 inside
    # JSON strings, which str.splitlines() would wrongly treat as breaks.
    for lineno, line in enumerate(raw.split("\n"), start=1):
        line = line.strip()
        if not line:
            continue
        try:
            entries.append(json.loads(line))
        except json.JSONDecodeError as e:
            print(f"[Manifest] Skipping malformed line {lineno}: {e}")
    return entries


def _upload_in_background(audio_path: Path, relative_audio_path: str) -> None:
    """Best-effort mirror of one audio file plus the manifest to the HF repo.

    Runs in a daemon thread; every error is logged and swallowed.
    """
    api = get_hf_api()
    if api is None:
        return
    try:
        upload_file(
            path_or_fileobj=str(audio_path),
            path_in_repo=relative_audio_path,
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        with _upload_lock:
            # Snapshot under the manifest lock, upload outside it, so a slow
            # network call never blocks /save from appending entries.
            with _csv_lock:
                manifest_bytes = MANIFEST.read_bytes()
            upload_file(
                path_or_fileobj=manifest_bytes,
                path_in_repo="transcripts.jsonl",
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
    except Exception as e:
        print(f"[Background upload error] {e}")


def transcribe_path(audio_path: Path) -> tuple[str, str]:
    """Transcribe one audio file; return ``(text, "tw")``.

    Raises:
        ValueError: if the decoded audio contains no samples.
    """
    model, processor, device = get_model()

    # librosa decodes any format (webm, mp4, ogg, wav) via ffmpeg and
    # returns mono float32 already resampled to 16 kHz.
    audio_data, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    if audio_data.size == 0:
        raise ValueError("Audio file contains no samples.")

    # Peak-normalise to [-1, 1]; skip for pure silence to avoid dividing by 0.
    peak = np.max(np.abs(audio_data))
    if peak > 0:
        audio_data = audio_data / peak

    # Use the full processor so we also get an attention_mask.
    # NOTE(review): the Whisper feature extractor pads/truncates to a 30 s
    # window by default, so longer clips are likely cut off — confirm.
    inputs = processor(
        audio_data,
        sampling_rate=16000,
        return_tensors="pt",
        return_attention_mask=True,
    )
    # Match the model's weights dtype (float16 on CUDA, float32 on CPU).
    input_features = inputs.input_features.to(device=device, dtype=model.dtype)
    attention_mask = inputs.attention_mask.to(device)

    with _INFERENCE_LOCK, torch.no_grad():
        generated_ids = model.generate(
            input_features,
            attention_mask=attention_mask,
            task="transcribe",
            # Whisper has no Twi language token; this checkpoint is run with
            # the Yoruba ("yo") token as a stand-in.
            # NOTE(review): confirm this matches the token used when the
            # model was fine-tuned.
            language="yo",
            temperature=0.0,
            forced_decoder_ids=None,  # avoids duplicate logits processor warnings
        )

    transcription = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0].strip()
    return transcription, "tw"


# ── Routes ─────────────────────────────────────────────────────────────────
@app.get("/", include_in_schema=False)
async def root():
    """Serve the single-page frontend from static/index.html."""
    index = STATIC_DIR / "index.html"
    if not index.exists():
        raise HTTPException(404, "Frontend not found. Place index.html in static/")
    return FileResponse(index)


@app.post("/transcribe")
def transcribe(audio: UploadFile = File(...)):
    """Transcribe an uploaded clip without storing it.

    Declared with plain ``def`` so FastAPI runs it in its threadpool and model
    inference does not block the event loop.
    """
    suffix = Path(audio.filename or "audio.webm").suffix or ".webm"
    tmp_path = BASE_DIR / f"_tmp_{uuid.uuid4().hex}{suffix}"
    try:
        with tmp_path.open("wb") as f:
            shutil.copyfileobj(audio.file, f)
        transcription, detected_lang = transcribe_path(tmp_path)
    except Exception as e:
        print(f"[Transcribe error] {e}")
        raise HTTPException(500, detail=str(e)) from e
    finally:
        tmp_path.unlink(missing_ok=True)

    return JSONResponse({
        "transcription": transcription,
        "detected_language": detected_lang,
    })


@app.post("/save")
def save(
    audio: UploadFile = File(...),
    transcription: str = Form(...),
):
    """Store an (audio, transcript) pair locally and mirror it in the background.

    Plain ``def`` so file I/O and audio decoding run in FastAPI's threadpool.
    """
    text = transcription.strip()
    if not text:
        raise HTTPException(422, "Transcription must not be empty.")

    entry_id = uuid.uuid4().hex
    suffix = Path(audio.filename or "audio.webm").suffix or ".webm"
    audio_filename = f"{entry_id}{suffix}"
    audio_path = AUDIO_DIR / audio_filename

    with audio_path.open("wb") as f:
        shutil.copyfileobj(audio.file, f)

    duration = audio_duration(audio_path)

    entry = {
        "id": entry_id,
        # NOTE(review): this is the local path; in the HF repo the file is
        # uploaded to audio/<name> (see relative_audio_path below).
        "audio_file": f"dataset/audio/{audio_filename}",
        "transcription": text,
        "language": "twi_en",
        "duration_s": duration,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    with _csv_lock:
        save_entry(entry)

    relative_audio_path = f"audio/{audio_filename}"
    threading.Thread(
        target=_upload_in_background,
        args=(audio_path, relative_audio_path),
        daemon=True,
    ).start()

    # load_manifest takes _csv_lock itself, so call it after releasing it.
    total = len(load_manifest())
    return JSONResponse({
        "id": entry_id,
        "total_saved": total,
        "duration_s": duration,
    })


@app.get("/dataset/stats")
def dataset_stats():
    """Return the entry count, total audio seconds and total transcript words."""
    entries = load_manifest()
    if not entries:
        return JSONResponse({"total": 0, "total_duration_s": 0, "total_words": 0})

    total_duration = sum(e.get("duration_s", 0) for e in entries)
    total_words = sum(len(e.get("transcription", "").split()) for e in entries)
    return JSONResponse({
        "total": len(entries),
        "total_duration_s": round(total_duration, 1),
        "total_words": total_words,
    })


@app.get("/dataset/entries")
def dataset_entries(limit: int = 50, offset: int = 0):
    """Return one page of entries, newest first, plus the overall total."""
    entries = load_manifest()
    entries.reverse()
    # Negative values would make the slice count from the end; clamp them.
    start = max(offset, 0)
    stop = start + max(limit, 0)
    return JSONResponse({
        "entries": entries[start:stop],
        "total": len(entries),
    })


# ── Static files ───────────────────────────────────────────────────────────
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")