# app.py — Hugging Face Space: Aphasia Classification System
# (web-page capture chrome — author, commit 3f2b9ca, "raw/history blame", file size —
#  removed so the module parses as Python)
import gradio as gr
from flask import Flask
import os
import tempfile
import logging
import threading
import time
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger, stdlib convention
# Configuration
MODEL_DIR = "."  # directory holding the fine-tuned model artifacts (see check_model_files)
SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]  # accepted upload extensions
def safe_import_modules():
    """Import the four pipeline stages, tolerating individual failures.

    Each stage lives in its own project module; any stage that fails to
    import is logged and recorded as ``None`` so the UI can report exactly
    which pieces are missing instead of crashing at startup.

    Returns:
        dict: maps stage function name ('convert_to_wav', 'to_cha_from_wav',
        'cha_to_json_file', 'predict_from_chajson') to the imported callable,
        or ``None`` when its module could not be imported.
    """
    import importlib

    # Same logger object as the module-level ``logger``.
    log = logging.getLogger(__name__)

    # (module name, function name) per pipeline stage, in pipeline order.
    stages = (
        ("utils_audio", "convert_to_wav"),
        ("to_cha", "to_cha_from_wav"),
        ("cha_json", "cha_to_json_file"),
        ("output", "predict_from_chajson"),
    )
    modules = {}
    for module_name, func_name in stages:
        try:
            module = importlib.import_module(module_name)
            modules[func_name] = getattr(module, func_name)
            log.info(f"βœ“ {module_name} imported successfully")
        except Exception as e:
            # Broad catch is deliberate: a stage may fail for any reason
            # (missing package, model download, syntax error) and the app
            # must still start in degraded mode.
            log.error(f"βœ— Failed to import {module_name}: {e}")
            modules[func_name] = None
    return modules
# Import modules once at startup; maps stage name -> callable (or None if
# that stage's module failed to import).
MODULES = safe_import_modules()
def check_model_files(model_dir=None):
    """Check whether the required model artifacts are present.

    Args:
        model_dir: directory to inspect; defaults to the module-level
            ``MODEL_DIR`` (backward-compatible — existing callers pass
            nothing).

    Returns:
        tuple: ``(all_present: bool, missing: list[str])`` where ``missing``
        lists the file names that were not found.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    required_files = (
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    missing_files = [
        name for name in required_files
        if not os.path.exists(os.path.join(model_dir, name))
    ]
    return not missing_files, missing_files
def run_complete_pipeline(audio_file_path: str) -> dict:
    """Run the full pipeline: audio β†’ WAV β†’ CHA β†’ JSON β†’ classification.

    Args:
        audio_file_path: path to the uploaded audio file.

    Returns:
        dict: ``{"success": True, "results": ..., "message": ...}`` on
        success, or ``{"success": False, "error": ..., "message": ...}``
        when a stage is unavailable or raises.
    """
    # Refuse to run unless every pipeline stage imported successfully.
    unavailable = [name for name, fn in MODULES.items() if fn is None]
    if unavailable:
        return {
            "success": False,
            "error": f"Missing required modules: {unavailable}",
            "message": "Pipeline modules not available"
        }
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Stage 1: normalize the upload to a 16 kHz mono WAV file.
        logger.info("Step 1: Converting audio to WAV...")
        wav_file = MODULES['convert_to_wav'](audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_file}")

        # Stage 2: produce a CHAT (.cha) transcript from the WAV.
        logger.info("Step 2: Generating CHA file...")
        cha_file = MODULES['to_cha_from_wav'](wav_file, lang="eng")
        logger.info(f"CHA generation completed: {cha_file}")

        # Stage 3: convert the transcript to the model's JSON schema
        # (the in-memory JSON object is not needed here).
        logger.info("Step 3: Converting CHA to JSON...")
        json_file, _ = MODULES['cha_to_json_file'](cha_file)
        logger.info(f"JSON conversion completed: {json_file}")

        # Stage 4: run the classifier over the JSON transcript.
        logger.info("Step 4: Running aphasia classification...")
        predictions = MODULES['predict_from_chajson'](MODEL_DIR, json_file, output_file=None)
        logger.info("Classification completed")

        # Best-effort removal of the intermediate files; a failure here is
        # logged but never fails the pipeline.
        try:
            for intermediate in (wav_file, cha_file, json_file):
                os.unlink(intermediate)
        except Exception as cleanup_error:
            logger.warning(f"Cleanup error: {cleanup_error}")

        return {
            "success": True,
            "results": predictions,
            "message": "Pipeline completed successfully"
        }
    except Exception as e:
        logger.error(f"Pipeline error: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}"
        }
def process_audio_input(audio_file) -> str:
    """Process an uploaded audio file and return a formatted result string.

    Validates availability of the pipeline and the file extension, runs
    run_complete_pipeline, then renders the first prediction as a
    markdown-ish report. All failures are returned as "❌ ..." strings
    rather than raised, since the return value feeds a Gradio Textbox.
    """
    try:
        if audio_file is None:
            return "❌ Error: No audio file uploaded"
        # Check if pipeline is available
        if not all(MODULES.values()):
            missing_modules = [k for k, v in MODULES.items() if v is None]
            return f"❌ Error: Audio processing pipeline not available. Missing required modules: {', '.join(missing_modules)}"
        # Check file format. gr.File may hand us a path string or an
        # object with a .name attribute — handle both.
        file_path = audio_file
        if hasattr(audio_file, 'name'):
            file_path = audio_file.name
        from pathlib import Path
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
        # Run the complete pipeline
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return f"❌ Pipeline Error: {pipeline_result['message']}\n\nDetails: {pipeline_result.get('error', '')}"
        # Format results from the first prediction only.
        results = pipeline_result["results"]
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            if "error" in first_pred:
                return f"❌ Classification Error: {first_pred['error']}"
            # Format main result
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            class_name = first_pred["class_description"]["name"]
            description = first_pred["class_description"]["description"]
            # Additional metrics
            additional_info = first_pred["additional_predictions"]
            severity_level = additional_info["predicted_severity_level"]
            fluency_score = additional_info["fluency_score"]
            fluency_rating = additional_info["fluency_rating"]
            # Format probability distribution (top 3).
            # NOTE(review): assumes the dict is already sorted by
            # probability descending — confirm against predict_from_chajson.
            prob_dist = first_pred["probability_distribution"]
            top_3 = list(prob_dist.items())[:3]
            result_text = f"""🧠 **APHASIA CLASSIFICATION RESULTS**
🎯 **Primary Classification:** {predicted_class}
πŸ“Š **Confidence:** {confidence}
πŸ“‹ **Type:** {class_name}
πŸ“ˆ **Additional Metrics:**
β€’ Severity Level: {severity_level}/3
β€’ Fluency Score: {fluency_score:.3f} ({fluency_rating})
πŸ“Š **Top 3 Probability Rankings:**
"""
            for i, (aphasia_type, info) in enumerate(top_3, 1):
                result_text += f"{i}. {aphasia_type}: {info['percentage']}\n"
            result_text += f"""
πŸ“ **Clinical Description:**
{description}
πŸ“Š **Processing Summary:**
β€’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}
β€’ Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
β€’ Average fluency: {results.get('summary', {}).get('average_fluency_score', 'N/A')}
"""
            return result_text
        else:
            return "❌ No predictions generated. The audio file may not contain analyzable speech."
    except Exception as e:
        # Last-resort guard so the UI always gets a string back.
        logger.error(f"Processing error: {str(e)}")
        import traceback
        traceback.print_exc()
        return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
def process_text_input(text_input):
    """Classify aphasia type from raw transcription text (audio-free fallback).

    Wraps the text in a minimal single-sentence JSON document matching the
    pipeline's schema, runs the prediction module on it, and formats the
    first result. Errors are returned as "❌ ..." strings for the Gradio
    Textbox rather than raised.

    Args:
        text_input: transcription text typed by the user.

    Returns:
        str: formatted analysis result or an error message.
    """
    try:
        # Guard: empty or whitespace-only input.
        if not text_input or not text_input.strip():
            return "❌ Error: Please enter some text for analysis"
        # Guard: the prediction stage must have imported successfully.
        if MODULES['predict_from_chajson'] is None:
            return "❌ Error: Text analysis not available. Missing prediction module."
        import json
        # Minimal document: one sentence, one participant (PAR) utterance.
        # POS/grammar ids and durations are zero-filled because typed text
        # has no audio alignment.
        tokens = text_input.split()  # hoisted: original re-split the text four times
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * len(tokens),
                        # NOTE(review): "* len(tokens)" repeats the SAME inner
                        # list object; fine as long as the consumer only reads
                        # it — confirm predict_from_chajson never mutates.
                        "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                        "word_durations": [0.0] * len(tokens),
                        "utterance_text": text_input
                    }]
                }]
            }],
            "text_all": text_input
        }
        # delete=False so the file survives the close for the prediction
        # call; removed in the finally below (the original leaked the temp
        # file whenever prediction raised).
        temp_json_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
                json.dump(temp_json, f, ensure_ascii=False, indent=2)
                temp_json_path = f.name
            # Run prediction
            results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)
        finally:
            if temp_json_path is not None:
                try:
                    os.unlink(temp_json_path)
                except OSError:  # was a bare except; cleanup stays best-effort
                    pass
        # Format the first prediction, if any.
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            description = first_pred["class_description"]["description"]
            severity = first_pred["additional_predictions"]["predicted_severity_level"]
            fluency = first_pred["additional_predictions"]["fluency_rating"]
            return f"""🧠 **TEXT ANALYSIS RESULTS**
🎯 **Predicted:** {predicted_class}
πŸ“Š **Confidence:** {confidence}
πŸ“ˆ **Severity:** {severity}/3
πŸ—£οΈ **Fluency:** {fluency}
πŸ“ **Description:**
{description}
ℹ️ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
"""
        else:
            return "❌ No predictions generated from text input"
    except Exception as e:
        logger.error(f"Text processing error: {str(e)}")
        return f"❌ Error: {str(e)}"
def create_gradio_app():
    """Build the Gradio UI: an audio-analysis tab and a text-analysis tab.

    Returns:
        gr.TabbedInterface: the combined two-tab interface.
    """
    # NOTE(review): the original computed a system-status banner here
    # (check_model_files + MODULES availability) but never displayed or
    # returned it; that dead code was removed. The same information is
    # served by the Flask /health endpoint.
    audio_demo = gr.Interface(
        fn=process_audio_input,
        inputs=gr.File(label="Upload Audio File", file_types=["audio"]),
        outputs=gr.Textbox(label="Analysis Results", lines=25),
        title="🎡 Audio Analysis",
        description="Upload MP3, MP4, WAV, M4A, FLAC, or OGG files"
    )
    text_demo = gr.Interface(
        fn=process_text_input,
        inputs=gr.Textbox(label="Enter Text", lines=5, placeholder="Enter speech transcription..."),
        outputs=gr.Textbox(label="Analysis Results", lines=15),
        title="πŸ“ Text Analysis",
        description="Enter text for direct analysis (less accurate than audio)"
    )
    # Combine the two single-function interfaces into one tabbed app.
    demo = gr.TabbedInterface(
        [audio_demo, text_demo],
        ["Audio Analysis", "Text Analysis"],
        title="🧠 Aphasia Classification System",
        theme=gr.themes.Soft()
    )
    return demo
def create_flask_app():
    """Create the Flask status app and the Gradio UI app.

    The two apps are returned separately and run independently: Gradio
    serves the UI while Flask provides JSON status endpoints. (No actual
    mounting of one onto the other happens here, despite the original
    comments — Gradio is launched on its own port by the caller.)

    Returns:
        tuple: (flask_app, gradio_app).
    """
    flask_app = Flask(__name__)
    gradio_app = create_gradio_app()
    # Enable request queuing for better performance under load.
    gradio_app.queue()
    # (The original also grabbed gradio_app.app — Gradio's underlying
    # FastAPI instance — into an unused local; that dead assignment is gone.)

    @flask_app.route('/health')
    def health_check():
        """Report model-file and pipeline-module availability as JSON."""
        model_available, missing_files = check_model_files()
        pipeline_available = all(MODULES.values())
        return {
            "status": "healthy" if model_available and pipeline_available else "unhealthy",
            "model_available": model_available,
            "pipeline_available": pipeline_available,
            "missing_files": missing_files if not model_available else [],
            "missing_modules": [k for k, v in MODULES.items() if v is None] if not pipeline_available else []
        }

    @flask_app.route('/info')
    def info():
        """Static system metadata and endpoint listing."""
        return {
            "title": "Aphasia Classification System",
            "description": "AI-powered aphasia type classification from audio",
            "supported_formats": SUPPORTED_AUDIO_FORMATS,
            "endpoints": {
                "/": "Main Gradio interface",
                "/health": "Health check",
                "/info": "System information"
            }
        }

    return flask_app, gradio_app
def run_gradio_on_flask():
    """Start the Gradio server in a daemon thread and keep the process alive.

    Builds both apps via create_flask_app, launches Gradio non-blockingly,
    then idles in the main thread until KeyboardInterrupt.
    """
    logger.info("Starting Aphasia Classification System with Flask + Gradio...")

    flask_app, gradio_app = create_flask_app()

    # Bind address: overridable via environment, HF-Spaces-friendly defaults.
    server_port = int(os.environ.get('PORT', 7860))
    server_host = os.environ.get('HOST', '0.0.0.0')

    # Heuristic cloud detection: any of these env vars marks a hosted
    # notebook / Spaces environment where a share link is wanted.
    cloud_indicators = (
        'SPACE_ID', 'PAPERSPACE_NOTEBOOK_REPO_ID',
        'COLAB_GPU', 'KAGGLE_KERNEL_RUN_TYPE',
    )
    in_cloud = any(os.getenv(flag) for flag in cloud_indicators)
    logger.info(f"Environment - Cloud: {in_cloud}, Host: {server_host}, Port: {server_port}")

    def _launch_gradio():
        """Thread target: start the Gradio server without blocking."""
        try:
            gradio_app.launch(
                server_name=server_host,
                server_port=server_port,
                share=in_cloud,  # auto-enable share links in cloud environments
                show_error=True,
                quiet=False,
                prevent_thread_lock=True,  # return control to this thread
            )
        except Exception as e:
            logger.error(f"Failed to start Gradio: {e}")

    # Run the server in the background; daemon so Ctrl-C can exit cleanly.
    threading.Thread(target=_launch_gradio, daemon=True).start()

    time.sleep(2)  # give the server a moment to come up
    logger.info(f"βœ“ Gradio app started on {server_host}:{server_port}")
    logger.info("βœ“ Flask health endpoints available at /health and /info")

    # Idle until interrupted so the daemon thread stays alive.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Shutting down...")
# Script entry point: prefer the Flask + Gradio combo, fall back to a plain
# Gradio launch if that setup fails for any reason.
if __name__ == "__main__":
    try:
        run_gradio_on_flask()
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        import traceback
        traceback.print_exc()
        # Fallback to basic Gradio if Flask setup fails
        logger.info("Falling back to basic Gradio interface...")
        demo = create_gradio_app()
        demo.launch(
            server_name="0.0.0.0",  # listen on all interfaces
            server_port=7860,  # Gradio / HF Spaces default port
            share=True,  # always request a share link in the fallback path
            show_error=True
        )