"""Aphasia classification web app.

Serves a Gradio UI (audio-upload and text-input tabs) for an aphasia-type
classifier, alongside Flask endpoints intended for health/info checks.

Pipeline: audio → WAV → CHA transcript (Batchalign) → JSON → model prediction.
"""

import json
import logging
import os
import tempfile
import threading
import time
import traceback
from pathlib import Path

import gradio as gr
from flask import Flask

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
MODEL_DIR = "."
SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]


def safe_import_modules():
    """Safely import pipeline modules with error handling.

    Each pipeline stage lives in its own project module; any of them may be
    missing in a given deployment. Failed imports are logged and recorded as
    ``None`` so the UI can report precisely which stages are unavailable
    instead of crashing at import time.

    Returns:
        dict: callable name -> callable, or ``None`` when the import failed.
    """
    modules = {}
    # (module name, attribute name) for each pipeline stage, in order.
    stages = [
        ("utils_audio", "convert_to_wav"),
        ("to_cha", "to_cha_from_wav"),
        ("cha_json", "cha_to_json_file"),
        ("output", "predict_from_chajson"),
    ]
    for module_name, attr_name in stages:
        try:
            module = __import__(module_name, fromlist=[attr_name])
            modules[attr_name] = getattr(module, attr_name)
            logger.info(f"✓ {module_name} imported successfully")
        except Exception as e:
            logger.error(f"✗ Failed to import {module_name}: {e}")
            modules[attr_name] = None
    return modules


# Import pipeline modules once at startup; missing stages degrade gracefully.
MODULES = safe_import_modules()


def check_model_files():
    """Check if required model files exist in MODEL_DIR.

    Returns:
        tuple: (all_present: bool, missing_files: list[str])
    """
    required_files = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]
    missing_files = [
        name for name in required_files
        if not os.path.exists(os.path.join(MODEL_DIR, name))
    ]
    return not missing_files, missing_files


def run_complete_pipeline(audio_file_path: str) -> dict:
    """Complete pipeline: Audio → WAV → CHA → JSON → Model Prediction.

    Args:
        audio_file_path: Path to the uploaded audio file.

    Returns:
        dict with ``success`` plus either ``results`` (on success) or
        ``error``/``message`` (on failure). Never raises.
    """
    # Bail out early if any pipeline stage failed to import.
    if not all(MODULES.values()):
        missing = [k for k, v in MODULES.items() if v is None]
        return {
            "success": False,
            "error": f"Missing required modules: {missing}",
            "message": "Pipeline modules not available",
        }

    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: Convert to WAV (16 kHz mono).
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = MODULES['convert_to_wav'](audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: Generate CHA file using Batchalign.
        logger.info("Step 2: Generating CHA file...")
        cha_path = MODULES['to_cha_from_wav'](wav_path, lang="eng")
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: Convert CHA to JSON.
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, json_data = MODULES['cha_to_json_file'](cha_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: Run aphasia classification.
        logger.info("Step 4: Running aphasia classification...")
        results = MODULES['predict_from_chajson'](MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        # Best-effort cleanup of intermediate files; never fail the request
        # over a leftover temp file.
        try:
            os.unlink(wav_path)
            os.unlink(cha_path)
            os.unlink(chajson_path)
        except Exception as cleanup_error:
            logger.warning(f"Cleanup error: {cleanup_error}")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully",
        }
    except Exception as e:
        logger.error(f"Pipeline error: {str(e)}")
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}",
        }


def process_audio_input(audio_file):
    """Process an uploaded audio file and return human-readable results.

    Args:
        audio_file: Gradio file object (or plain path string).

    Returns:
        str: Markdown-formatted result text, or an error message prefixed
        with "❌". Never raises.
    """
    try:
        if audio_file is None:
            return "❌ Error: No audio file uploaded"

        # Check if the full pipeline is available before doing any work.
        if not all(MODULES.values()):
            missing_modules = [k for k, v in MODULES.items() if v is None]
            return (
                "❌ Error: Audio processing pipeline not available. "
                f"Missing required modules: {', '.join(missing_modules)}"
            )

        # Gradio may hand us a file wrapper or a bare path.
        file_path = audio_file.name if hasattr(audio_file, 'name') else audio_file

        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return (
                f"❌ Error: Unsupported file format {file_ext}. "
                f"Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
            )

        # Run the complete pipeline.
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return (
                f"❌ Pipeline Error: {pipeline_result['message']}\n\n"
                f"Details: {pipeline_result.get('error', '')}"
            )

        results = pipeline_result["results"]
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            if "error" in first_pred:
                return f"❌ Classification Error: {first_pred['error']}"

            # Main classification result.
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            class_name = first_pred["class_description"]["name"]
            description = first_pred["class_description"]["description"]

            # Additional metrics.
            additional_info = first_pred["additional_predictions"]
            severity_level = additional_info["predicted_severity_level"]
            fluency_score = additional_info["fluency_score"]
            fluency_rating = additional_info["fluency_rating"]

            # Top-3 of the probability distribution.
            # NOTE(review): this relies on the dict already being sorted by
            # probability — confirm in predict_from_chajson.
            prob_dist = first_pred["probability_distribution"]
            top_3 = list(prob_dist.items())[:3]

            result_text = f"""🧠 **APHASIA CLASSIFICATION RESULTS**

🎯 **Primary Classification:** {predicted_class}
📊 **Confidence:** {confidence}
📋 **Type:** {class_name}

📈 **Additional Metrics:**
• Severity Level: {severity_level}/3
• Fluency Score: {fluency_score:.3f} ({fluency_rating})

📊 **Top 3 Probability Rankings:**
"""
            for i, (aphasia_type, info) in enumerate(top_3, 1):
                result_text += f"{i}. {aphasia_type}: {info['percentage']}\n"

            result_text += f"""
📝 **Clinical Description:**
{description}

📊 **Processing Summary:**
• Total sentences analyzed: {results.get('total_sentences', 'N/A')}
• Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
• Average fluency: {results.get('summary', {}).get('average_fluency_score', 'N/A')}
"""
            return result_text
        else:
            return "❌ No predictions generated. The audio file may not contain analyzable speech."

    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        traceback.print_exc()
        return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."


def process_text_input(text_input):
    """Process text input directly (fallback when no audio is available).

    Wraps the raw text in a minimal CHA-JSON structure with zeroed
    POS/grammar/duration features, so accuracy is lower than audio analysis.

    Args:
        text_input: Raw transcription text.

    Returns:
        str: Markdown-formatted result text or an "❌" error message.
    """
    try:
        if not text_input or not text_input.strip():
            return "❌ Error: Please enter some text for analysis"

        # Only the prediction stage is required for text-only analysis.
        if MODULES['predict_from_chajson'] is None:
            return "❌ Error: Text analysis not available. Missing prediction module."

        # Build a minimal JSON structure for text-only input; linguistic
        # features the model normally gets from audio are zero-filled.
        tokens = text_input.split()
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * len(tokens),
                        "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                        "word_durations": [0.0] * len(tokens),
                        "utterance_text": text_input,
                    }],
                }],
            }],
            "text_all": text_input,
        }

        # Persist to a temporary file, since the predictor reads from disk.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(temp_json, f, ensure_ascii=False, indent=2)
            temp_json_path = f.name

        results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)

        # Best-effort cleanup of the temp file.
        try:
            os.unlink(temp_json_path)
        except OSError:
            pass

        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            description = first_pred["class_description"]["description"]
            severity = first_pred["additional_predictions"]["predicted_severity_level"]
            fluency = first_pred["additional_predictions"]["fluency_rating"]

            return f"""🧠 **TEXT ANALYSIS RESULTS**

🎯 **Predicted:** {predicted_class}
📊 **Confidence:** {confidence}
📈 **Severity:** {severity}/3
🗣️ **Fluency:** {fluency}

📝 **Description:**
{description}

ℹ️ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
"""
        else:
            return "❌ No predictions generated from text input"

    except Exception as e:
        logger.error(f"Text processing error: {str(e)}")
        return f"❌ Error: {str(e)}"


def create_gradio_app():
    """Create the Gradio interface (audio tab + text tab).

    Returns:
        gr.TabbedInterface: the combined demo.
    """
    # Check system status (currently logged/diagnostic only).
    model_available, missing_files = check_model_files()
    pipeline_available = all(MODULES.values())

    status_message = "🟢 System Ready" if model_available and pipeline_available else "🔴 System Issues"
    status_details = []
    if not model_available:
        status_details.append(f"Missing model files: {', '.join(missing_files)}")
    if not pipeline_available:
        missing_modules = [k for k, v in MODULES.items() if v is None]
        status_details.append(f"Missing modules: {', '.join(missing_modules)}")

    # Simple Interface objects avoid JSON-schema issues with complex outputs.
    audio_demo = gr.Interface(
        fn=process_audio_input,
        inputs=gr.File(label="Upload Audio File", file_types=["audio"]),
        outputs=gr.Textbox(label="Analysis Results", lines=25),
        title="🎵 Audio Analysis",
        description="Upload MP3, MP4, WAV, M4A, FLAC, or OGG files",
    )

    text_demo = gr.Interface(
        fn=process_text_input,
        inputs=gr.Textbox(label="Enter Text", lines=5, placeholder="Enter speech transcription..."),
        outputs=gr.Textbox(label="Analysis Results", lines=15),
        title="📝 Text Analysis",
        description="Enter text for direct analysis (less accurate than audio)",
    )

    # Combine both interfaces into tabs.
    demo = gr.TabbedInterface(
        [audio_demo, text_demo],
        ["Audio Analysis", "Text Analysis"],
        title="🧠 Aphasia Classification System",
        theme=gr.themes.Soft(),
    )
    return demo


def create_flask_app():
    """Create a Flask app with health/info endpoints plus the Gradio app.

    NOTE(review): the Flask app is created and routed here, but
    run_gradio_on_flask() only ever launches the Gradio server — the Flask
    endpoints below are never actually served. Mount them on Gradio's
    FastAPI app or run Flask on a second port to make them reachable.

    Returns:
        tuple: (flask_app, gradio_app)
    """
    flask_app = Flask(__name__)
    gradio_app = create_gradio_app()

    # Enable request queuing for better performance under load.
    gradio_app.queue()

    @flask_app.route('/health')
    def health_check():
        """Report model/pipeline availability for monitoring."""
        model_available, missing_files = check_model_files()
        pipeline_available = all(MODULES.values())
        return {
            "status": "healthy" if model_available and pipeline_available else "unhealthy",
            "model_available": model_available,
            "pipeline_available": pipeline_available,
            "missing_files": missing_files if not model_available else [],
            "missing_modules": [k for k, v in MODULES.items() if v is None] if not pipeline_available else [],
        }

    @flask_app.route('/info')
    def info():
        """Static system-information endpoint."""
        return {
            "title": "Aphasia Classification System",
            "description": "AI-powered aphasia type classification from audio",
            "supported_formats": SUPPORTED_AUDIO_FORMATS,
            "endpoints": {
                "/": "Main Gradio interface",
                "/health": "Health check",
                "/info": "System information",
            },
        }

    return flask_app, gradio_app


def run_gradio_on_flask():
    """Launch the Gradio server in a background thread and block forever."""
    logger.info("Starting Aphasia Classification System with Flask + Gradio...")

    flask_app, gradio_app = create_flask_app()

    # Environment detection.
    port = int(os.environ.get('PORT', 7860))
    host = os.environ.get('HOST', '0.0.0.0')

    # Heuristic cloud detection (HF Spaces, Paperspace, Colab, Kaggle).
    is_cloud = any(os.getenv(indicator) for indicator in [
        'SPACE_ID', 'PAPERSPACE_NOTEBOOK_REPO_ID', 'COLAB_GPU', 'KAGGLE_KERNEL_RUN_TYPE',
    ])

    logger.info(f"Environment - Cloud: {is_cloud}, Host: {host}, Port: {port}")

    def run_gradio():
        """Run Gradio in a separate thread."""
        try:
            gradio_app.launch(
                server_name=host,
                server_port=port,
                share=is_cloud,  # Auto-enable share link in cloud environments
                show_error=True,
                quiet=False,
                prevent_thread_lock=True,  # Don't block this worker thread
            )
        except Exception as e:
            logger.error(f"Failed to start Gradio: {e}")

    gradio_thread = threading.Thread(target=run_gradio, daemon=True)
    gradio_thread.start()

    # Give Gradio time to start before reporting readiness.
    time.sleep(2)

    logger.info(f"✓ Gradio app started on {host}:{port}")
    logger.info("✓ Flask health endpoints available at /health and /info")

    # Keep the main thread alive; the server thread is a daemon.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Shutting down...")


if __name__ == "__main__":
    try:
        run_gradio_on_flask()
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        traceback.print_exc()

        # Fallback to basic Gradio if the Flask setup fails.
        logger.info("Falling back to basic Gradio interface...")
        demo = create_gradio_app()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True,
        )