Spaces:
Running
Running
| import os | |
| import shutil | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import requests | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, File, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from pymongo import MongoClient | |
# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()

# --- Configuration: every value below can be overridden via environment variables ---
# MongoDB connection string; defaults to a local server.
MONGODB_URI = os.getenv("MONGODB_URI", "mongodb://localhost:27017/")
# Model size passed to whisper.load_model() when transcribing locally.
WHISPER_MODEL_SIZE = os.getenv("WHISPER_MODEL", "small")
# Base URL of a remote transcription service. When non-empty it is used instead of
# the local Whisper model. Whitespace and trailing slashes are removed so that
# endpoint paths such as "/transcribe/" can be appended directly.
WHISPER_API_URL = os.getenv("WHISPER_API_URL", "").strip().rstrip("/")
# Base URL of a remote disfluency-cleaning service; when non-empty it is used
# instead of the local model (same normalization as above).
DISFLUENCY_API_URL = os.getenv("DISFLUENCY_API_URL", "").strip().rstrip("/")
# Timeout in seconds for calls to the remote services. A non-integer value makes
# int() raise ValueError at import time.
REMOTE_API_TIMEOUT = int(os.getenv("REMOTE_API_TIMEOUT", "300"))
# Optional bearer token sent to the remote services (empty means no auth header).
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()

# Scratch directory for uploaded audio, created at import time.
# NOTE(review): this is relative to the current working directory, whereas
# UI_DIR below is relative to this source file — confirm that is intended.
UPLOAD_DIR = Path("uploads")
UPLOAD_DIR.mkdir(exist_ok=True)
# Directory containing the static front-end (index.html is served from here).
UI_DIR = Path(__file__).parent / "ui"

# MongoDB handles used for sign lookups.
client = MongoClient(MONGODB_URI)
db = client["SignApp"]
sign_rules_col = db["sign_rules"]  # queried by its "sign" field
fingerspell_col = db["fingerspelling"]  # queried by its "letter" field

# Lazily initialized local models; populated on first use by the loader functions.
_whisper_model = None
_disfluency_fn = None
def _auth_headers() -> dict[str, str]:
    """Build the HTTP headers for calls to the remote model services.

    Returns:
        ``{"Authorization": "Bearer <HF_TOKEN>"}`` when ``HF_TOKEN`` is set,
        otherwise an empty dict, so the request goes out unauthenticated.
    """
    headers: dict[str, str] = {}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    return headers
def get_whisper():
    """Return the local Whisper model, loading it on the first call.

    The model size comes from ``WHISPER_MODEL_SIZE``. ``whisper`` is imported
    inside this function, so the package is only needed when the function is
    actually called. Later calls return the cached instance.
    """
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model

    import whisper

    _whisper_model = whisper.load_model(WHISPER_MODEL_SIZE)
    return _whisper_model
def get_disfluency_fn():
    """Return the local ``remove_disfluency`` callable, importing it on first use.

    The package-relative import is deferred, so the local model code is only
    loaded when this function runs. The callable is cached in ``_disfluency_fn``.
    """
    global _disfluency_fn
    if _disfluency_fn is not None:
        return _disfluency_fn

    from .disfluency.inference import remove_disfluency as loaded_fn

    _disfluency_fn = loaded_fn
    return _disfluency_fn
def transcribe_audio(file_path: Path) -> dict:
    """Transcribe an audio file, either through the remote API or with local Whisper.

    When ``WHISPER_API_URL`` is set, the file is POSTed as the multipart field
    ``file`` (content type ``audio/webm``) to ``{WHISPER_API_URL}/transcribe/``.
    The JSON reply's ``text`` and ``language`` are returned, defaulting to
    ``""`` and ``"en"`` when absent. Otherwise the local model transcribes the
    file with the language forced to English.

    Args:
        file_path: Path of the audio file on disk.

    Returns:
        ``{"text": ..., "language": ...}``.

    Raises:
        requests.RequestException: the remote call failed or returned an HTTP error status.
    """
    if not WHISPER_API_URL:
        local_result = get_whisper().transcribe(str(file_path), language="en")
        return {"text": local_result["text"], "language": local_result["language"]}

    endpoint = f"{WHISPER_API_URL}/transcribe/"
    with file_path.open("rb") as handle:
        reply = requests.post(
            endpoint,
            headers=_auth_headers(),
            files={"file": (file_path.name, handle, "audio/webm")},
            timeout=REMOTE_API_TIMEOUT,
        )
    reply.raise_for_status()
    payload = reply.json()
    return {"text": payload.get("text", ""), "language": payload.get("language", "en")}
def clean_disfluency(text: str) -> str:
    """Run ``text`` through the disfluency-removal model, remote or local.

    When ``DISFLUENCY_API_URL`` is set, ``{"text": text}`` is POSTed to
    ``{DISFLUENCY_API_URL}/clean/`` and the reply's ``cleaned_text`` is used.
    Otherwise the local ``remove_disfluency`` function is applied.

    Args:
        text: Transcribed or user-supplied text.

    Returns:
        The cleaned text, stripped of surrounding whitespace in both modes.

    Raises:
        requests.RequestException: the remote call failed or returned an HTTP error status.
    """
    if DISFLUENCY_API_URL:
        response = requests.post(
            f"{DISFLUENCY_API_URL}/clean/",
            headers=_auth_headers(),
            json={"text": text},
            timeout=REMOTE_API_TIMEOUT,
        )
        response.raise_for_status()
        # Treat a missing *or null* "cleaned_text" as empty text instead of
        # crashing with AttributeError on None.strip().
        cleaned = response.json().get("cleaned_text") or ""
    else:
        cleaned = get_disfluency_fn()(text)
    # Strip in both modes so callers get the same shape regardless of deployment.
    return (cleaned or "").strip()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: warm up local models before serving requests.

    For Whisper and for the disfluency model, this either loads the local model
    eagerly, so the first request does not pay the lazy-load cost, or prints
    which remote API will be used. Nothing is torn down after ``yield``.

    The ``@asynccontextmanager`` decorator turns this async generator into the
    async-context-manager factory that FastAPI's ``lifespan=`` argument expects.
    """
    if not WHISPER_API_URL:
        print("Loading local Whisper model on startup...")
        get_whisper()
    else:
        print(f"Using remote Whisper API: {WHISPER_API_URL}")
    if not DISFLUENCY_API_URL:
        print("Loading local disfluency model on startup...")
        get_disfluency_fn()
    else:
        print(f"Using remote disfluency API: {DISFLUENCY_API_URL}")
    print("SignApp startup complete.")
    yield
# The ASGI application; `lifespan` preloads local models or reports the remote APIs at startup.
app = FastAPI(title="SignApp", version="0.1.0", lifespan=lifespan)
# Allow cross-origin browser requests from any origin, with any method and header.
# NOTE(review): combining a wildcard origin with allow_credentials=True is very
# permissive. For credentialed requests, Starlette's CORSMiddleware echoes the
# caller's Origin back instead of "*", so any site can make credentialed calls.
# Nothing visible in this file uses cookies or credentials; consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| from .sign_language_text.gloss_converter import convert_to_sign_gloss | |
class TextInput(BaseModel):
    """Request body for the text-only pipeline (``text_to_sign_endpoint``).

    Attributes:
        text: Raw input text. The endpoint strips it and rejects it with a 400 if empty.
    """

    text: str
def build_sign_sequence(gloss_tokens: list[str]) -> list[dict]:
    """Look up each gloss token in MongoDB sign_rules, fall back to fingerspelling.

    Each token is first looked up in ``sign_rules`` by its ``sign`` field; a
    hit produces one ``{"type": "sign", ...}`` entry. On a miss, the token is
    fingerspelled: each character is upper-cased and looked up in
    ``fingerspelling`` by ``letter``. This produces one
    ``{"type": "fingerspell", ...}`` entry per character that has a document.
    Characters without a document are silently skipped.

    Lookups are memoized for the duration of one call, so a repeated gloss or
    letter costs a single database round-trip instead of one per occurrence.

    Args:
        gloss_tokens: Tokens produced by the gloss converter. NOTE(review): the
            value is iterated directly; if a plain string were passed, each
            character would be treated as a token.

    Returns:
        Sign and fingerspell descriptors, in input order.

    Raises:
        KeyError: a matched sign_rules document lacks handshape/location/movement,
            or a fingerspelling document lacks handshape.
    """
    # Per-call memo tables. A cached None records a miss, so it is not re-queried either.
    rule_cache = {}  # gloss -> sign_rules document or None
    letter_cache = {}  # upper-cased letter -> fingerspelling document or None

    sign_sequence = []
    for word in gloss_tokens:
        if word not in rule_cache:
            rule_cache[word] = sign_rules_col.find_one({"sign": word})
        rule = rule_cache[word]
        if rule:
            sign_sequence.append(
                {
                    "type": "sign",
                    "gloss": word,
                    "handshape": rule["handshape"],
                    "location": rule["location"],
                    "movement": rule["movement"],
                    "expression": rule.get("expression", "neutral"),
                }
            )
            continue

        # No sign rule for this gloss: spell it out character by character.
        for char in word:
            letter = char.upper()
            if letter not in letter_cache:
                letter_cache[letter] = fingerspell_col.find_one({"letter": letter})
            finger = letter_cache[letter]
            if finger:
                sign_sequence.append(
                    {
                        "type": "fingerspell",
                        "letter": letter,
                        "handshape": finger["handshape"],
                        "location": "neutral_space",
                        "movement": finger.get("movement") or "none",
                    }
                )
    return sign_sequence
def text_pipeline(text: str) -> dict:
    """Run the shared text stages: disfluency cleanup, then gloss conversion, then sign lookup.

    Returns:
        A dict with ``cleaned_transcription`` (cleaned text),
        ``sign_friendly_text`` (the gloss converter's output) and
        ``sign_sequence`` (the descriptors built from that gloss).
    """
    cleaned = clean_disfluency(text)
    gloss = convert_to_sign_gloss(cleaned)
    return {
        "cleaned_transcription": cleaned,
        "sign_friendly_text": gloss,
        "sign_sequence": build_sign_sequence(gloss),
    }
def health():
    """Report liveness and whether each model runs remotely or locally.

    NOTE(review): no route decorator precedes this function in this copy of the
    file, so as shown it is not registered on ``app``. Confirm the deployed
    file has one.
    """

    def _mode(url):
        # A configured remote URL means that model is served remotely.
        return "remote" if url else "local"

    return {
        "status": "ok",
        "whisper": _mode(WHISPER_API_URL),
        "disfluency": _mode(DISFLUENCY_API_URL),
    }
def voice_to_text_endpoint(file: UploadFile = File(...)):
    """Full pipeline: audio -> transcription -> gloss -> sign sequence.

    The upload is written to a uniquely named temporary file in ``UPLOAD_DIR``,
    transcribed, and passed through ``text_pipeline``. The file is always
    deleted afterwards.

    Returns:
        ``language`` and ``raw_transcription`` plus the keys returned by
        ``text_pipeline``.

    Raises:
        HTTPException: 502 if a remote model service call fails; 500 for any
            other error, with the exception text as the detail.

    NOTE(review): no route decorator precedes this function in this copy of the
    file, so as shown it is not registered on ``app``. Confirm the deployed
    file has one.
    """
    import uuid

    # Never build the on-disk path from the client-supplied filename. A name like
    # "../../x" or an absolute path would escape UPLOAD_DIR (pathlib's "/" drops
    # the left side for absolute paths), and that file would then be deleted in
    # `finally`. Concurrent uploads sharing a name (e.g. the "recording.webm"
    # default) would also overwrite or delete each other's files. Only the
    # extension is kept, for anything downstream that looks at it.
    suffix = Path(file.filename or "recording.webm").suffix
    file_path = UPLOAD_DIR / f"{uuid.uuid4().hex}{suffix}"
    try:
        with file_path.open("wb") as audio_file:
            shutil.copyfileobj(file.file, audio_file)
        transcription_result = transcribe_audio(file_path)
        transcription = transcription_result["text"]
        language = transcription_result["language"]
        result = text_pipeline(transcription)
        return {
            "language": language,
            "raw_transcription": transcription,
            **result,
        }
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=f"Remote model service failed: {exc}") from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        # missing_ok avoids the exists()/unlink() race, e.g. if the open() failed.
        file_path.unlink(missing_ok=True)
def text_to_sign_endpoint(body: TextInput):
    """Text-only pipeline: text -> gloss -> sign sequence.

    Surrounding whitespace is stripped before processing.

    Raises:
        HTTPException: 400 if the text is empty after stripping; 502 if a remote
            model service call fails; 500 for any other error, with the
            exception text as the detail.

    NOTE(review): no route decorator precedes this function in this copy of the
    file, so as shown it is not registered on ``app``. Confirm the deployed
    file has one.
    """
    stripped = body.text.strip()
    if stripped:
        try:
            return text_pipeline(stripped)
        except requests.RequestException as exc:
            raise HTTPException(status_code=502, detail=f"Remote model service failed: {exc}") from exc
        except Exception as exc:
            raise HTTPException(status_code=500, detail=str(exc)) from exc
    raise HTTPException(status_code=400, detail="Text is empty")
def serve_ui():
    """Return the front-end entry page, ``UI_DIR/index.html``.

    NOTE(review): no route decorator precedes this function in this copy of the
    file, so as shown it is not registered on ``app``. Confirm the deployed
    file has one.
    """
    index_page = UI_DIR / "index.html"
    return FileResponse(index_page)
# Serve the static front-end files from UI_DIR at the site root. Starlette
# matches routes in registration order, so this catch-all mount goes last to
# avoid shadowing anything registered before it.
# NOTE(review): StaticFiles is created without html=True, so this mount does not
# map a bare "/" to index.html. Presumably serve_ui is meant to handle that.
app.mount("/", StaticFiles(directory=str(UI_DIR)), name="ui")