Ellie5757575757's picture
Update app.py
a97ba2d verified
raw
history blame
14.6 kB
import gradio as gr
import json
import os
import tempfile
import logging
import traceback
from pathlib import Path
# Print the installed Gradio version at startup; UI behaviour differs across
# Gradio releases, so this is the first thing to check in the logs.
print("Gradio version:", gr.__version__)

# Module-wide logger at INFO level so every pipeline step shows up in the logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory holding the classifier weights and tokenizer (the working directory).
MODEL_DIR = "."

# Upload extensions accepted by the audio handler (compared lower-cased).
SUPPORTED_AUDIO_FORMATS = [
    ".mp3",
    ".mp4",
    ".wav",
    ".m4a",
    ".flac",
    ".ogg",
]
def safe_import_modules():
    """Import the four pipeline stages, tolerating individual failures.

    Each stage lives in its own local module. A stage whose import fails for
    any reason (module missing, attribute missing, error raised while the
    module loads) is logged together with its traceback and stored as
    ``None``. The app can then still start and report which stages are
    unavailable.

    Returns:
        dict mapping ``convert_to_wav``, ``to_cha_from_wav``,
        ``cha_to_json_file`` and ``predict_from_chajson`` (in that order) to
        the imported callable, or to ``None`` if its import failed. All four
        keys are always present.
    """
    import importlib

    # (module name, function name) in pipeline order. The result dict is
    # filled in this order, so callers see the stages in the same sequence.
    stages = (
        ("utils_audio", "convert_to_wav"),
        ("to_cha", "to_cha_from_wav"),
        ("cha_json", "cha_to_json_file"),
        ("output", "predict_from_chajson"),
    )

    modules = {}
    for module_name, attr_name in stages:
        try:
            func = getattr(importlib.import_module(module_name), attr_name)
        except Exception as e:  # deliberate: any failure only disables this stage
            # logger.exception keeps the traceback. It usually explains the real
            # cause, e.g. a third-party dependency missing inside the module.
            logger.exception(f"✗ Failed to import {module_name}: {e}")
            modules[attr_name] = None
        else:
            logger.info(f"✓ {module_name} imported successfully")
            modules[attr_name] = func
    return modules
# Import the pipeline stages once, at module load. MODULES maps each stage
# name to its callable, or to None when that stage could not be imported.
# The request handlers and the status panel check it before doing any work.
MODULES = safe_import_modules()
def check_model_files():
    """Check that the classifier's model files are present in MODEL_DIR.

    Returns:
        ``(True, [])`` when every required file exists. Otherwise
        ``(False, missing)``, where ``missing`` lists the absent file names in
        the order they are checked.
    """
    required_files = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]
    missing_files = [
        name
        for name in required_files
        if not os.path.exists(os.path.join(MODEL_DIR, name))
    ]
    if missing_files:
        logger.error(f"Missing model files: {missing_files}")
        return False, missing_files
    logger.info("✓ All required model files found")
    return True, []
def _remove_temp_files(paths, keep=None):
    """Best-effort deletion of intermediate pipeline files; failures are only logged.

    ``None`` entries (stages that never ran) are skipped, and so is any path
    that resolves to ``keep`` (the caller's own input file). Each path is
    handled independently, so one failed deletion does not leave the
    remaining files behind.
    """
    for path in paths:
        if path is None:
            continue
        try:
            # Guard against a stage returning its input unchanged (possible,
            # e.g., if convert_to_wav passes WAV uploads through; not verifiable
            # from here). Otherwise the user's upload would be deleted.
            if keep is not None and os.path.realpath(path) == os.path.realpath(keep):
                continue
            os.unlink(path)
        except Exception as cleanup_error:  # cleanup must never mask the pipeline result
            logger.warning(f"Cleanup error: {cleanup_error}")


def run_complete_pipeline(audio_file_path: str) -> dict:
    """Complete pipeline: Audio → WAV → CHA → JSON → Model Prediction.

    Intermediate WAV/CHA/JSON files are removed whether or not the pipeline
    succeeds. The input file itself is never deleted.

    Args:
        audio_file_path: path to the uploaded audio file.

    Returns:
        On success, ``{"success": True, "results": ..., "message": ...}``,
        where ``results`` is whatever ``predict_from_chajson`` returned.
        Otherwise ``{"success": False, "error": ..., "message": ...}``.
        Exceptions raised by the stages are caught and reported this way
        rather than propagated.
    """
    # Check if all modules are available
    if not all(MODULES.values()):
        missing = [k for k, v in MODULES.items() if v is None]
        return {
            "success": False,
            "error": f"Missing required modules: {missing}",
            "message": "Pipeline modules not available",
        }

    # Paths produced by the stages. They are initialised up front so the
    # ``finally`` clause can clean up whatever was created before a failure.
    wav_path = cha_path = chajson_path = None
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: Convert to WAV
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = MODULES['convert_to_wav'](audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: Generate CHA file using Batchalign
        logger.info("Step 2: Generating CHA file...")
        cha_path = MODULES['to_cha_from_wav'](wav_path, lang="eng")
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: Convert CHA to JSON (the parsed data is unused here; the
        # classifier reads the JSON file from disk).
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, _json_data = MODULES['cha_to_json_file'](cha_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: Run aphasia classification
        logger.info("Step 4: Running aphasia classification...")
        results = MODULES['predict_from_chajson'](MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully",
        }
    except Exception as e:
        logger.exception(f"Pipeline error: {e}")
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}",
        }
    finally:
        _remove_temp_files((wav_path, cha_path, chajson_path), keep=audio_file_path)
def _format_audio_results(results):
    """Render the classifier output for an audio upload as display text.

    Only the first entry of ``results["predictions"]`` is shown in detail. The
    nested keys read here come from whatever ``predict_from_chajson`` returns.
    This function only consumes them; a missing key raises ``KeyError``, which
    the caller reports.

    Returns:
        The formatted report, or a message starting with "❌" when there are
        no predictions or the first prediction carries an ``"error"`` entry.
    """
    if not ("predictions" in results and len(results["predictions"]) > 0):
        return "❌ No predictions generated. The audio file may not contain analyzable speech."

    first_pred = results["predictions"][0]
    if "error" in first_pred:
        return f"❌ Classification Error: {first_pred['error']}"

    # Main result
    predicted_class = first_pred["prediction"]["predicted_class"]
    confidence = first_pred["prediction"]["confidence_percentage"]
    class_name = first_pred["class_description"]["name"]
    description = first_pred["class_description"]["description"]

    # Additional metrics
    additional_info = first_pred["additional_predictions"]
    severity_level = additional_info["predicted_severity_level"]
    fluency_score = additional_info["fluency_score"]
    fluency_rating = additional_info["fluency_rating"]

    # NOTE(review): "top 3" is just the first three entries of the mapping, so
    # this assumes probability_distribution is already ordered from most to
    # least likely. Confirm against the output module.
    top_3 = list(first_pred["probability_distribution"].items())[:3]

    result_text = f"""
🧠 **APHASIA CLASSIFICATION RESULTS**
🎯 **Primary Classification:** {predicted_class}
📊 **Confidence:** {confidence}
📋 **Type:** {class_name}
📈 **Additional Metrics:**
• Severity Level: {severity_level}/3
• Fluency Score: {fluency_score:.3f} ({fluency_rating})
📊 **Top 3 Probability Rankings:**
"""
    result_text += "".join(
        f"{i}. {aphasia_type}: {info['percentage']}\n"
        for i, (aphasia_type, info) in enumerate(top_3, 1)
    )
    result_text += f"""
📝 **Clinical Description:**
{description}
📊 **Processing Summary:**
• Total sentences analyzed: {results.get('total_sentences', 'N/A')}
• Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
• Average fluency: {results.get('summary', {}).get('average_fluency_score', 'N/A')}
"""
    return result_text


def process_audio_input(audio_file):
    """Gradio handler: validate an uploaded audio file, run the pipeline, format the result.

    Args:
        audio_file: value delivered by the ``gr.File`` input. This is either a
            path string or an object exposing the path as ``.name`` (the form
            differs between Gradio versions), or ``None`` when nothing was
            uploaded.

    Returns:
        Text for the results box. Every failure is returned as a message
        starting with "❌" instead of being raised, so the UI always shows
        something.
    """
    try:
        if audio_file is None:
            return "❌ Error: No audio file uploaded"

        # Check if pipeline is available
        if not all(MODULES.values()):
            return "❌ Error: Audio processing pipeline not available. Missing required modules."

        # Check file format
        file_path = audio_file.name if hasattr(audio_file, 'name') else audio_file
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"

        # Run the complete pipeline
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return f"❌ Pipeline Error: {pipeline_result['message']}\n\nDetails: {pipeline_result.get('error', '')}"

        return _format_audio_results(pipeline_result["results"])
    except Exception as e:
        logger.exception(f"Processing error: {e}")
        return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
def process_text_input(text_input):
    """Classify typed text directly, bypassing the audio/CHA stages (fallback option).

    The text is wrapped in a minimal one-sentence, one-utterance JSON document
    and passed to ``predict_from_chajson``. NOTE(review): the structure below
    presumably mirrors what ``cha_to_json_file`` produces; confirm against
    cha_json. Typed text has no POS/grammar ids or durations, so those fields
    are filled with zeros.

    Returns:
        Display text for the result, or a message starting with "❌". Errors
        are returned, not raised.
    """
    try:
        if not text_input or not text_input.strip():
            return "❌ Error: Please enter some text for analysis"

        # Check if prediction module is available
        if MODULES['predict_from_chajson'] is None:
            return "❌ Error: Text analysis not available. Missing prediction module."

        tokens = text_input.split()
        n_tokens = len(tokens)
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * n_tokens,
                        # One independent [0, 0, 0] per token, rather than one
                        # list object repeated n times.
                        "word_grammar_ids": [[0, 0, 0] for _ in range(n_tokens)],
                        "word_durations": [0.0] * n_tokens,
                        "utterance_text": text_input,
                    }],
                }],
            }],
            "text_all": text_input,
        }

        # The temp file is removed in ``finally`` so it does not leak when
        # writing or prediction fails. It is written as UTF-8 explicitly
        # because ensure_ascii=False emits raw non-ASCII characters, which the
        # locale's default encoding may not be able to represent.
        temp_json_path = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False, encoding='utf-8') as f:
                temp_json_path = f.name
                json.dump(temp_json, f, ensure_ascii=False, indent=2)
            results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)
        finally:
            if temp_json_path is not None:
                try:
                    os.unlink(temp_json_path)
                except OSError as cleanup_error:
                    logger.warning(f"Cleanup error: {cleanup_error}")

        if not ("predictions" in results and len(results["predictions"]) > 0):
            return "❌ No predictions generated from text input"

        first_pred = results["predictions"][0]
        # Honour the same per-prediction "error" entry that the audio handler checks.
        if "error" in first_pred:
            return f"❌ Classification Error: {first_pred['error']}"

        predicted_class = first_pred["prediction"]["predicted_class"]
        confidence = first_pred["prediction"]["confidence_percentage"]
        description = first_pred["class_description"]["description"]
        severity = first_pred["additional_predictions"]["predicted_severity_level"]
        fluency = first_pred["additional_predictions"]["fluency_rating"]
        return f"""
🧠 **TEXT ANALYSIS RESULTS**
🎯 **Predicted:** {predicted_class}
📊 **Confidence:** {confidence}
📈 **Severity:** {severity}/3
🗣️ **Fluency:** {fluency}
📝 **Description:**
{description}
ℹ️ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
"""
    except Exception as e:
        logger.exception(f"Text processing error: {e}")
        return f"❌ Error: {str(e)}"
def detect_environment(port=7860, host="127.0.0.1"):
    """Report whether we appear to run in a hosted environment and whether ``host:port`` answers.

    Args:
        port: TCP port to probe; defaults to 7860, the port the app launches on.
        host: address to probe.

    Returns:
        ``(is_cloud, localhost_accessible)``.

        - ``is_cloud`` is True when any of the known hosting environment
          variables below is set to a non-empty value.
        - ``localhost_accessible`` is True only if a TCP connection to
          ``host:port`` succeeds within 1 second, i.e. something is already
          listening there. Called before the app launches, this is False
          unless another process holds the port.
    """
    import socket

    # Check for common cloud environment indicators
    cloud_indicators = [
        'SPACE_ID',  # Hugging Face Spaces
        'PAPERSPACE_NOTEBOOK_REPO_ID',  # Paperspace
        'COLAB_GPU',  # Google Colab
        'KAGGLE_KERNEL_RUN_TYPE',  # Kaggle
        'AWS_LAMBDA_FUNCTION_NAME',  # AWS Lambda
    ]
    is_cloud = any(os.getenv(indicator) for indicator in cloud_indicators)

    # The context manager closes the socket on every path. Only socket
    # errors are treated as "not accessible"; KeyboardInterrupt etc. propagate.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            localhost_accessible = sock.connect_ex((host, port)) == 0
    except OSError:
        localhost_accessible = False
    return is_cloud, localhost_accessible
def create_interface():
    """Build the Gradio Interface for audio upload, including a system-status panel.

    The status panel (shown in the ``article`` area) reports missing model
    files, as found by ``check_model_files``, and pipeline modules whose
    import failed. The article is wrapped in HTML tags, where Markdown
    (``**bold**``) and bare newlines do not render, so the status is built
    with ``<strong>`` and ``<br>``.

    Returns:
        The configured ``gr.Interface``; it is not launched here.
    """
    # Check system status
    model_available, missing_files = check_model_files()
    pipeline_available = all(MODULES.values())

    if model_available and pipeline_available:
        status_lines = ["🟢 <strong>System Status: Ready</strong>"]
    else:
        status_lines = ["🔴 <strong>System Status: Issues Detected</strong>"]
    if not model_available:
        status_lines.append(f"❌ Missing model files: {', '.join(missing_files)}")
    if not pipeline_available:
        missing_modules = [k for k, v in MODULES.items() if v is None]
        status_lines.append(f"❌ Missing modules: {', '.join(missing_modules)}")
    status_html = "<br>".join(status_lines)

    # Derive the displayed format list from the same constant the handler
    # validates against, e.g. ".mp3" -> "MP3".
    formats_label = ", ".join(ext.lstrip(".").upper() for ext in SUPPORTED_AUDIO_FORMATS)

    article = f"""
<div style="margin-top: 20px;">
<h3>System Status</h3>
<p>{status_html}</p>
<h3>About</h3>
<p><strong>Pipeline:</strong> Audio → WAV → CHA → JSON → Classification</p>
<p><strong>Supported formats:</strong> {formats_label}</p>
<p><em>For research and clinical assessment purposes.</em></p>
</div>
"""

    # A plain Interface is used instead of Blocks; the original author notes
    # this avoids JSON schema issues.
    return gr.Interface(
        fn=process_audio_input,
        inputs=gr.File(
            label=f"Upload Audio File ({formats_label})",
            # Explicit extensions rather than the "audio" category. That
            # category filters by audio MIME types and would hide .mp4
            # uploads, which the handler accepts.
            file_types=list(SUPPORTED_AUDIO_FORMATS),
        ),
        outputs=gr.Textbox(
            label="Analysis Results",
            lines=25,
            max_lines=50,
        ),
        title="🧠 Aphasia Classification System",
        description="Upload audio files to analyze speech patterns and classify aphasia types",
        article=article,
    )
if __name__ == "__main__":
    # Script entry point: probe the environment, build the UI and launch it.
    try:
        logger.info("Starting Aphasia Classification System...")

        # Detect environment
        is_cloud, localhost_accessible = detect_environment()
        logger.info(f"Environment - Cloud: {is_cloud}, Localhost accessible: {localhost_accessible}")

        # Create and launch interface
        demo = create_interface()

        # The port probe runs before launch(), so it is True only when some
        # other process already holds the port. It cannot tell whether a
        # browser can reach this server. Previously "not accessible" turned on
        # share=True, which published a public tunnel on ordinary local runs.
        # Share now follows the hosted-environment check alone.
        if localhost_accessible:
            logger.warning("Port 7860 is already in use; launch may fail")

        launch_kwargs = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "show_error": True,
            "quiet": False,
            "share": is_cloud,
        }
        if is_cloud:
            logger.info("Running in cloud environment - enabling share")
        else:
            logger.info("Running locally - share disabled")

        demo.launch(**launch_kwargs)
    except Exception as e:
        logger.exception(f"Failed to launch app: {e}")
        print(f"❌ Application startup failed: {e}")
        # Exit non-zero so process supervisors see the startup failure.
        raise SystemExit(1) from e