|
|
import logging
import os
import sys
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
|
|
|
|
|
|
|
|
# Configure root logging to stdout so container/PaaS platforms capture it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

# Make sibling packages (diagnosis/, ui/, api/, config, ...) importable when
# this file is executed as a script rather than as an installed package.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
|
|
|
# Import the model loader, the Gradio UI factory, and configuration objects.
# These are hard requirements: an ImportError here is logged and re-raised,
# aborting module import (the app cannot run without them).
try:
    from diagnosis.ai_engine.model_loader import (
        get_inference_pipeline
    )
    from ui.gradio_interface import create_gradio_interface
    from config import APIConfig, GradioConfig, default_api_config, default_gradio_config
    logger.info("β Successfully imported model loaders and UI components")
except ImportError as e:
    logger.error(f"β Failed to import required modules: {e}")
    raise
|
|
|
|
|
|
|
|
# FastAPI application instance; Gradio is mounted onto it further below.
app = FastAPI(
    title="Speech Pathology Diagnosis API",
    description="Speech analysis using Wav2Vec2-XLSR-53 for fluency and articulation diagnosis",
    version="2.0.0"
)

# NOTE(review): wildcard `allow_origins=["*"]` combined with
# `allow_credentials=True` is disallowed by the CORS spec (browsers will not
# accept `Access-Control-Allow-Origin: *` on credentialed requests). Confirm
# whether credentialed cross-origin calls are needed; if so, pin the origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level inference pipeline handle; populated by the startup event and
# read by the request handlers below (None until models finish loading).
inference_pipeline = None
|
|
|
|
|
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers; kept as-is to avoid changing the app's startup contract.
@app.on_event("startup")
async def startup_event():
    """Load the AI inference pipeline and wire it into the API routes.

    Runs once when the ASGI server starts. A failure while loading the
    pipeline (or initializing the route modules) is logged but deliberately
    swallowed, so the app still comes up and can serve /health reporting
    the missing models.
    """
    global inference_pipeline
    # Tracks whether the pipeline actually loaded, so the final status log
    # is truthful. BUG FIX: previously "Models loaded successfully!" was
    # logged unconditionally, even after a swallowed pipeline-load failure.
    pipeline_ready = False
    try:
        logger.info("π Startup event: Loading AI models...")

        try:
            inference_pipeline = get_inference_pipeline()
            logger.info("β Inference pipeline loaded")
            pipeline_ready = True

            # Hand the pipeline to the route/streaming modules. Failure here
            # is non-fatal: the legacy endpoints in this module still work.
            try:
                from api.routes import initialize_routes
                from api.streaming import initialize_streaming
                initialize_routes(inference_pipeline)
                initialize_streaming(inference_pipeline)
                logger.info("β API routes initialized with phoneme/error mappers")
            except Exception as e:
                logger.warning(f"β οΈ API routes initialization failed: {e}", exc_info=True)

        except Exception as e:
            # Swallowed on purpose (matches original behavior): the server
            # still starts; handlers return 503 while the pipeline is None.
            logger.error(f"β Failed to load inference pipeline: {e}", exc_info=True)

        if pipeline_ready:
            logger.info("β Models loaded successfully!")
        else:
            logger.warning("β οΈ Startup finished without a loaded inference pipeline")
    except Exception as e:
        logger.error(f"β Failed to load models: {e}", exc_info=True)
        raise
|
|
|
|
|
|
|
|
# Mount the modern diagnosis REST routes. Non-fatal on failure: the legacy
# endpoints defined in this module keep working without the router.
try:
    from api.routes import router as diagnose_router
    app.include_router(diagnose_router)
    logger.info("β Diagnosis router included")
except Exception as e:
    logger.warning(f"β οΈ Failed to include diagnosis router: {e}")
|
|
|
|
|
|
|
|
# Register the session-based streaming WebSocket endpoint. Non-fatal on
# failure: only /ws/diagnose is lost, the rest of the app still serves.
try:
    from api.streaming import handle_streaming_websocket

    @app.websocket("/ws/diagnose")
    async def websocket_diagnose(websocket: WebSocket, session_id: Optional[str] = None):
        # Delegate the entire socket lifecycle (accept/receive/close) to
        # the streaming module's handler.
        await handle_streaming_websocket(websocket, session_id)

    logger.info("β WebSocket endpoint registered")
except Exception as e:
    logger.warning(f"β οΈ Failed to register WebSocket endpoint: {e}")
|
|
|
|
|
|
|
|
# Build and mount the Gradio UI at the app root.
# NOTE(review): mounting at "/" assumes the API routes registered above take
# precedence over the catch-all Gradio mount — verify route ordering.
try:
    gradio_interface = create_gradio_interface(default_gradio_config)
    gr.mount_gradio_app(app, gradio_interface, path="/")
    logger.info("β Gradio interface mounted at /")
except Exception as e:
    logger.error(f"β Failed to create Gradio interface: {e}", exc_info=True)
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """
    Health check endpoint.

    Returns:
        Health status with model loading information:
        - models_loaded.inference_pipeline: True once the startup event
          has finished loading the pipeline.
        - timestamp: current UTC time in ISO-8601 with a "Z" suffix.
    """
    return {
        "status": "healthy",
        "models_loaded": {
            "inference_pipeline": inference_pipeline is not None,
            "model_version": "wav2vec2-xlsr-53-v2"
        },
        # FIX: datetime.utcnow() is deprecated (Python 3.12) and naive.
        # Use an aware UTC timestamp, then drop tzinfo so the rendered
        # string keeps the exact legacy "...Z" format (no "+00:00").
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    }
|
|
|
|
|
@app.post("/api/diagnose")
async def diagnose_speech(
    audio: UploadFile = File(...),
    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping (optional)")
):
    """
    Legacy endpoint for speech diagnosis.

    NOTE: For full phoneme-level error detection with therapy recommendations,
    use POST /diagnose/file?text=<expected_text> instead.
    This endpoint is maintained for backward compatibility.

    Parameters:
        - audio: Audio file (WAV, MP3, FLAC, M4A)
        - text: Optional expected text for phoneme mapping

    Returns:
        Dictionary with diagnosis results (legacy format for backward compatibility)

    Raises:
        HTTPException: 503 while models are still loading, 400 for an
            unsupported extension, 413 for an oversized upload, 500 for
            any unexpected inference failure.
    """
    # Models load asynchronously at startup; refuse requests until ready.
    if not inference_pipeline:
        raise HTTPException(
            status_code=503,
            detail="Inference pipeline not loaded yet. Try again in a moment."
        )

    # Imported lazily (per request) rather than at module import time.
    from api.routes import get_phoneme_mapper, get_error_mapper
    from models.error_taxonomy import ErrorType

    start_time = time.time()
    temp_file = None  # on-disk copy of the upload; removed in `finally`

    try:
        logger.info(f"π₯ Processing legacy diagnosis request: {audio.filename}")

        # --- Validate the upload by file extension ---
        file_ext = Path(audio.filename).suffix.lower()
        allowed_extensions = default_api_config.allowed_extensions
        if file_ext not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type: {file_ext}. Allowed: {allowed_extensions}"
            )

        temp_dir = tempfile.gettempdir()
        os.makedirs(temp_dir, exist_ok=True)

        # Timestamp prefix reduces collisions between concurrent requests.
        temp_file = os.path.join(temp_dir, f"diagnosis_{int(time.time())}_{audio.filename}")
        content = await audio.read()

        # --- Enforce the configured size limit before writing to disk ---
        file_size_mb = len(content) / 1024 / 1024
        if file_size_mb > default_api_config.max_file_size_mb:
            raise HTTPException(
                status_code=413,
                detail=f"File too large: {file_size_mb:.2f}MB. Max: {default_api_config.max_file_size_mb}MB"
            )

        with open(temp_file, "wb") as f:
            f.write(content)

        logger.info(f"π Saved to: {temp_file} ({file_size_mb:.2f} MB)")

        # --- Frame-level inference over the whole file ---
        logger.info("π Running inference pipeline...")
        result = inference_pipeline.predict_phone_level(
            temp_file,
            return_timestamps=True
        )

        # Measured through the end of inference; the optional phoneme
        # mapping below is intentionally excluded from this figure.
        processing_time_ms = (time.time() - start_time) * 1000

        phoneme_mapper = get_phoneme_mapper()
        error_mapper = get_error_mapper()

        # --- Optional phoneme alignment + error extraction ---
        # Only runs when the caller supplied expected text AND both mappers
        # were initialized at startup. Failures here are non-fatal: the base
        # diagnosis is still returned without error details.
        frame_phonemes = []
        errors = []
        if text and phoneme_mapper and error_mapper:
            try:
                # Align the expected text onto the model's frame grid.
                frame_phonemes = phoneme_mapper.map_text_to_frames(
                    text,
                    num_frames=result.num_frames,
                    audio_duration=result.duration
                )

                for i, frame_pred in enumerate(result.frame_predictions):
                    # Mapping may yield fewer phonemes than frames; pad with ''.
                    phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
                    class_id = frame_pred.articulation_class
                    # NOTE(review): stuttered frames appear to use class ids
                    # offset by 4 — confirm against the error taxonomy.
                    if frame_pred.fluency_label == 'stutter':
                        class_id += 4

                    error_detail = error_mapper.map_classifier_output(
                        class_id=class_id,
                        confidence=frame_pred.confidence,
                        phoneme=phoneme if phoneme else 'unknown',
                        fluency_label=frame_pred.fluency_label
                    )

                    # Keep only actual errors, in the legacy response shape.
                    if error_detail.error_type != ErrorType.NORMAL:
                        errors.append({
                            "phoneme": error_detail.phoneme,
                            "time": frame_pred.time,
                            "error_type": error_detail.error_type.value,
                            "wrong_sound": error_detail.wrong_sound,
                            "severity": error_mapper.get_severity_level(error_detail.severity).value,
                            "therapy": error_detail.therapy
                        })
            except Exception as e:
                # Best-effort enrichment only; log and continue.
                logger.warning(f"β οΈ Phoneme/error mapping failed: {e}")

        # --- Aggregate metrics for the legacy response format ---
        aggregate = result.aggregate
        # "fluency_score" is treated here as a stutter probability in [0, 1].
        mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
        fluency_percentage = (1.0 - mean_fluency_stutter) * 100

        fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
        fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0

        # Per-label histogram of articulation predictions.
        articulation_class_counts = {}
        for fp in result.frame_predictions:
            label = fp.articulation_label
            articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1

        dominant_articulation = aggregate.get("articulation_label", "normal")
        avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0

        # --- Assemble the legacy-format response ---
        response = {
            "status": "success",
            "fluency_metrics": {
                "mean_fluency": fluency_percentage / 100.0,
                "fluency_percentage": fluency_percentage,
                "fluent_frames_ratio": fluent_frames_ratio,
                "fluent_frames_percentage": fluent_frames_ratio * 100,
                "stutter_probability": mean_fluency_stutter
            },
            "articulation_results": {
                "total_frames": result.num_frames,
                "frame_duration_ms": int(inference_pipeline.inference_config.hop_size_ms),
                "dominant_class": aggregate.get("articulation_class", 0),
                "dominant_label": dominant_articulation,
                "class_distribution": articulation_class_counts,
                "frame_predictions": [
                    {
                        "time": fp.time,
                        "fluency_prob": fp.fluency_prob,
                        "fluency_label": fp.fluency_label,
                        "articulation_class": fp.articulation_class,
                        "articulation_label": fp.articulation_label,
                        "confidence": fp.confidence,
                        # '' when no text was given or mapping fell short.
                        "phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
                    }
                    for i, fp in enumerate(result.frame_predictions)
                ]
            },
            "confidence": avg_confidence,
            "confidence_percentage": avg_confidence * 100,
            "processing_time_ms": processing_time_ms
        }

        # Error details are attached only when mapping produced any; the
        # list is truncated to the first 10 to bound response size.
        if errors:
            response["error_count"] = len(errors)
            response["errors"] = errors[:10]
            response["problematic_sounds"] = list(set(err["phoneme"] for err in errors if err["phoneme"]))

        logger.info(f"β Legacy diagnosis complete: fluency={response['fluency_metrics']['fluency_percentage']:.1f}%, "
                    f"errors={len(errors) if errors else 0}, "
                    f"time={processing_time_ms:.0f}ms")

        return response

    except HTTPException:
        # Re-raise client errors (400/413/503 above) untouched.
        raise
    except Exception as e:
        logger.error(f"β Error during diagnosis: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Diagnosis failed: {str(e)}")

    finally:
        # Always remove the temporary audio copy, success or failure.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.debug(f"π§Ή Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.websocket("/ws/audio")
async def websocket_audio_stream(websocket: WebSocket):
    """
    WebSocket endpoint for real-time audio streaming.

    Receives raw 16-bit little-endian PCM audio chunks and returns
    per-frame fluency/articulation predictions as JSON messages.
    Any per-chunk error is reported to the client and closes the loop.
    """
    await websocket.accept()
    logger.info("π WebSocket connection established")

    try:
        # FIX: numpy was previously imported inside the receive loop on
        # every chunk; hoisted here with the other one-time setup.
        import numpy as np
        from audio.audio_processor import StreamingAudioBuffer
        from config import default_audio_config

        # Accumulates incoming PCM until a full analysis window is available.
        buffer = StreamingAudioBuffer(
            buffer_duration_ms=1000.0,
            chunk_duration_ms=default_audio_config.chunk_duration_ms,
            sample_rate=default_audio_config.sample_rate
        )

        # Models may still be loading (startup event); fail fast and close.
        if not inference_pipeline:
            await websocket.send_json({
                "error": "Inference pipeline not loaded",
                "status": "error"
            })
            await websocket.close()
            return

        frame_index = 0

        while True:
            try:
                data = await websocket.receive_bytes()

                # int16 PCM -> float32 samples in [-1, 1)
                audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0

                buffer.add_chunk(audio_chunk)

                # Run inference only once a full window has accumulated.
                if buffer.has_enough_data():
                    chunk = buffer.get_chunk()
                    if chunk is not None:
                        result = inference_pipeline.predict_streaming(
                            chunk,
                            frame_index=frame_index,
                            timestamp_ms=frame_index * default_audio_config.chunk_duration_ms
                        )

                        await websocket.send_json({
                            "status": "success",
                            "frame_index": frame_index,
                            "fluency_score": result.fluency_score,
                            "articulation_class": result.articulation_class,
                            "articulation_class_name": result.articulation_class_name,
                            "confidence": result.confidence,
                            "timestamp_ms": result.timestamp_ms
                        })

                        frame_index += 1

            except WebSocketDisconnect:
                logger.info("π WebSocket disconnected")
                break
            except Exception as e:
                # Report the failure to the client, then stop streaming.
                logger.error(f"β WebSocket error: {e}", exc_info=True)
                await websocket.send_json({
                    "error": str(e),
                    "status": "error"
                })
                break

    except Exception as e:
        logger.error(f"β WebSocket setup failed: {e}", exc_info=True)
        try:
            await websocket.send_json({
                "error": str(e),
                "status": "error"
            })
            await websocket.close()
        except Exception:
            # FIX: was a bare `except:`; best-effort only — the socket may
            # already be closed, so send/close failures are ignored.
            pass
|
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: run the app with an embedded uvicorn
    # server (in production an external ASGI runner would typically be used).
    import uvicorn
    # Re-import is harmless: default_api_config is already imported above.
    from config import default_api_config

    logger.info("π Starting Speech Pathology Diagnosis API...")
    logger.info(f" FastAPI: http://{default_api_config.host}:{default_api_config.port}")
    logger.info(f" Gradio UI: http://{default_api_config.host}:{default_gradio_config.port}")
    logger.info(f" WebSocket: ws://{default_api_config.host}:{default_api_config.port}/ws/audio")

    uvicorn.run(
        app,
        host=default_api_config.host,
        port=default_api_config.port,
        log_level="info"
    )
|
|
|