Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from flask import Flask | |
| import os | |
| import tempfile | |
| import logging | |
| import threading | |
| import time | |
# --- Logging -----------------------------------------------------------------
# Configure the root logger at INFO so pipeline progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration -----------------------------------------------------------
# Directory in which model artifacts are looked up.
MODEL_DIR = "."
# File extensions accepted by the audio tab (lower-case, with leading dot).
SUPPORTED_AUDIO_FORMATS = [
    ".mp3",
    ".mp4",
    ".wav",
    ".m4a",
    ".flac",
    ".ogg",
]
def safe_import_modules():
    """Import the optional pipeline stages, tolerating failures.

    Each stage is resolved as ``<module>.<function>``. A stage that cannot be
    loaded (missing module, missing attribute, or any error raised while the
    module executes) is logged and stored as ``None``. This lets the UI still
    start and report which stages are unavailable.

    Returns:
        dict: Maps each function name to the imported callable, or ``None``
        if it could not be imported. Keys are inserted in pipeline order.
    """
    import importlib

    # (module name, function name) for every pipeline stage, in pipeline order.
    stages = (
        ("utils_audio", "convert_to_wav"),
        ("to_cha", "to_cha_from_wav"),
        ("cha_json", "cha_to_json_file"),
        ("output", "predict_from_chajson"),
    )
    modules = {}
    for module_name, func_name in stages:
        try:
            func = getattr(importlib.import_module(module_name), func_name)
        except Exception as e:  # best-effort: any failure just disables the stage
            logger.error(f"β Failed to import {module_name}: {e}")
            modules[func_name] = None
        else:
            modules[func_name] = func
            logger.info(f"β {module_name} imported successfully")
    return modules
# Resolve all pipeline stages once, at import time. Each value is the imported
# callable, or None when that stage failed to import.
MODULES: dict = safe_import_modules()
def check_model_files(model_dir=None):
    """Check whether the model artifacts needed for inference are present.

    Args:
        model_dir: Directory to inspect. When omitted, the module-level
            ``MODEL_DIR`` is used, which preserves the original behaviour.

    Returns:
        tuple[bool, list[str]]: ``(all_present, missing_files)``.
        ``missing_files`` lists the absent file names in the fixed order
        shown below.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    required_files = (
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    missing_files = [
        name
        for name in required_files
        if not os.path.exists(os.path.join(model_dir, name))
    ]
    return not missing_files, missing_files
def _remove_temp_files(paths, keep=None):
    """Delete intermediate pipeline files on a best-effort basis.

    Args:
        paths: Iterable of file paths. ``None`` entries (stages that never
            produced a file) are skipped.
        keep: A path that must never be deleted (the caller's input file).

    Failures are logged as warnings and never raised. This way cleanup cannot
    mask the pipeline's own result or error, and one failed deletion does not
    stop the others.
    """
    keep_abs = os.path.abspath(keep) if keep else None
    for path in paths:
        if not path:
            continue
        try:
            if keep_abs is not None and os.path.abspath(path) == keep_abs:
                continue
            os.unlink(path)
        except Exception as cleanup_error:
            logger.warning(f"Cleanup error: {cleanup_error}")


def run_complete_pipeline(audio_file_path: str) -> dict:
    """Run the full pipeline: audio -> WAV -> CHA -> JSON -> model prediction.

    Args:
        audio_file_path: Path to the uploaded audio file.

    Returns:
        dict: On success, ``{"success": True, "results": ..., "message": ...}``
        where ``results`` is whatever ``predict_from_chajson`` returned.
        On failure, ``{"success": False, "error": ..., "message": ...}``.
        Pipeline errors are reported through this dict rather than raised.
    """
    # Check if all modules are available
    if not all(MODULES.values()):
        missing = [k for k, v in MODULES.items() if v is None]
        return {
            "success": False,
            "error": f"Missing required modules: {missing}",
            "message": "Pipeline modules not available"
        }

    # Track intermediates so they are removed even when a later stage fails.
    wav_path = cha_path = chajson_path = None
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: Convert to WAV
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = MODULES['convert_to_wav'](audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: Generate CHA file using Batchalign
        logger.info("Step 2: Generating CHA file...")
        cha_path = MODULES['to_cha_from_wav'](wav_path, lang="eng")
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: Convert CHA to JSON (the parsed data itself is not needed here)
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, _json_data = MODULES['cha_to_json_file'](cha_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: Run aphasia classification
        logger.info("Step 4: Running aphasia classification...")
        results = MODULES['predict_from_chajson'](MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully"
        }
    except Exception as e:
        logger.exception(f"Pipeline error: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}"
        }
    finally:
        # Never delete the caller's own file, in case a converter hands the
        # input path back unchanged (e.g. possibly for input that is already WAV).
        _remove_temp_files((wav_path, cha_path, chajson_path), keep=audio_file_path)
def _format_audio_results(results):
    """Render the classifier output for the audio tab as Markdown-style text.

    Args:
        results: Value returned by ``predict_from_chajson``. Only the first
            entry of ``results["predictions"]`` is rendered in detail.

    Returns:
        str: The formatted report, or an error/notice string when there is
        nothing to report.

    Raises:
        KeyError: If the first prediction lacks one of the fields read below
            (the caller converts this into a user-facing error message).
    """
    predictions = results.get("predictions") if isinstance(results, dict) else None
    if not predictions:
        return "β No predictions generated. The audio file may not contain analyzable speech."

    first_pred = predictions[0]
    if "error" in first_pred:
        return f"β Classification Error: {first_pred['error']}"

    # Main result
    prediction = first_pred["prediction"]
    predicted_class = prediction["predicted_class"]
    confidence = prediction["confidence_percentage"]
    class_name = first_pred["class_description"]["name"]
    description = first_pred["class_description"]["description"]

    # Additional metrics
    additional_info = first_pred["additional_predictions"]
    severity_level = additional_info["predicted_severity_level"]
    fluency_score = additional_info["fluency_score"]
    fluency_rating = additional_info["fluency_rating"]

    # First three entries of the probability distribution, in the order the
    # model returned them.
    top_3 = list(first_pred["probability_distribution"].items())[:3]
    ranking_lines = "".join(
        f"{rank}. {aphasia_type}: {info['percentage']}\n"
        for rank, (aphasia_type, info) in enumerate(top_3, 1)
    )

    # A missing or null summary falls back to "N/A" values instead of crashing.
    summary = results.get("summary") or {}

    result_text = f"""π§ **APHASIA CLASSIFICATION RESULTS**
π― **Primary Classification:** {predicted_class}
π **Confidence:** {confidence}
π **Type:** {class_name}
π **Additional Metrics:**
β’ Severity Level: {severity_level}/3
β’ Fluency Score: {fluency_score:.3f} ({fluency_rating})
π **Top 3 Probability Rankings:**
"""
    result_text += ranking_lines
    result_text += f"""
π **Clinical Description:**
{description}
π **Processing Summary:**
β’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}
β’ Average confidence: {summary.get('average_confidence', 'N/A')}
β’ Average fluency: {summary.get('average_fluency_score', 'N/A')}
"""
    return result_text


def process_audio_input(audio_file):
    """Validate an uploaded audio file, run the pipeline and format the result.

    Args:
        audio_file: Value from the ``gr.File`` component. Either a path
            string or an object exposing the path as ``.name``.

    Returns:
        str: A human-readable report, or an error message. This function
        never raises; unexpected errors are logged with a traceback and
        returned as text.
    """
    try:
        if audio_file is None:
            return "β Error: No audio file uploaded"

        # Check if pipeline is available
        if not all(MODULES.values()):
            missing_modules = [k for k, v in MODULES.items() if v is None]
            return f"β Error: Audio processing pipeline not available. Missing required modules: {', '.join(missing_modules)}"

        # Depending on the Gradio version, gr.File yields a path string or a
        # tempfile-like wrapper carrying the path in ``.name``.
        file_path = getattr(audio_file, 'name', audio_file)

        from pathlib import Path
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return f"β Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"

        # Run the complete pipeline
        pipeline_result = run_complete_pipeline(file_path)
        if not pipeline_result["success"]:
            return f"β Pipeline Error: {pipeline_result['message']}\n\nDetails: {pipeline_result.get('error', '')}"

        return _format_audio_results(pipeline_result["results"])
    except Exception as e:
        logger.exception(f"Processing error: {str(e)}")
        return f"β Processing Error: {str(e)}\n\nPlease check the logs for more details."
def process_text_input(text_input):
    """Classify a typed transcription directly (fallback to the audio path).

    The text is wrapped in a minimal single-sentence JSON document. The
    participant tokens come from whitespace splitting, with zero placeholder
    POS/grammar ids and zero word durations. That document is written to a
    temporary file, passed to ``predict_from_chajson``, and the file is
    always removed afterwards.

    Args:
        text_input: Free text entered by the user.

    Returns:
        str: A human-readable report, or an error message. This function
        never raises; unexpected errors are logged and returned as text.
    """
    try:
        if not text_input or not text_input.strip():
            return "β Error: Please enter some text for analysis"

        # Check if prediction module is available
        if MODULES['predict_from_chajson'] is None:
            return "β Error: Text analysis not available. Missing prediction module."

        import json

        tokens = text_input.split()
        n_tokens = len(tokens)
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * n_tokens,
                        "word_grammar_ids": [[0, 0, 0] for _ in range(n_tokens)],
                        "word_durations": [0.0] * n_tokens,
                        "utterance_text": text_input,
                    }],
                }],
            }],
            "text_all": text_input,
        }

        temp_json_path = None
        try:
            # Explicit UTF-8: ensure_ascii=False writes raw non-ASCII text,
            # which a locale-dependent default encoding may fail to encode.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.json', delete=False, encoding='utf-8'
            ) as f:
                temp_json_path = f.name
                json.dump(temp_json, f, ensure_ascii=False, indent=2)
            results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)
        finally:
            # Remove the temp file even when writing or prediction fails.
            if temp_json_path is not None:
                try:
                    os.unlink(temp_json_path)
                except OSError as cleanup_error:
                    logger.warning(f"Cleanup error: {cleanup_error}")

        predictions = results.get("predictions") if isinstance(results, dict) else None
        if not predictions:
            return "β No predictions generated from text input"

        first_pred = predictions[0]
        if "error" in first_pred:
            return f"β Classification Error: {first_pred['error']}"

        predicted_class = first_pred["prediction"]["predicted_class"]
        confidence = first_pred["prediction"]["confidence_percentage"]
        description = first_pred["class_description"]["description"]
        severity = first_pred["additional_predictions"]["predicted_severity_level"]
        fluency = first_pred["additional_predictions"]["fluency_rating"]

        return f"""π§ **TEXT ANALYSIS RESULTS**
π― **Predicted:** {predicted_class}
π **Confidence:** {confidence}
π **Severity:** {severity}/3
π£οΈ **Fluency:** {fluency}
π **Description:**
{description}
βΉοΈ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
"""
    except Exception as e:
        logger.exception(f"Text processing error: {str(e)}")
        return f"β Error: {str(e)}"
def create_gradio_app():
    """Build the tabbed Gradio UI (audio upload + free-text analysis).

    The model-file and pipeline-module status is computed once here. It is
    logged and shown under each tab, so users can see why analysis may fail.

    Returns:
        gr.TabbedInterface: The assembled app (not launched).
    """
    # Check system status
    model_available, missing_files = check_model_files()
    pipeline_available = all(MODULES.values())
    status_message = "π’ System Ready" if model_available and pipeline_available else "π΄ System Issues"
    status_details = []
    if not model_available:
        status_details.append(f"Missing model files: {', '.join(missing_files)}")
    if not pipeline_available:
        missing_modules = [k for k, v in MODULES.items() if v is None]
        status_details.append(f"Missing modules: {', '.join(missing_modules)}")

    if status_details:
        logger.warning(f"{status_message}: {'; '.join(status_details)}")
    else:
        logger.info(status_message)
    status_text = "\n".join([f"**{status_message}**", *(f"- {detail}" for detail in status_details)])

    # Create simple interfaces to avoid JSON schema issues
    audio_demo = gr.Interface(
        fn=process_audio_input,
        # Explicit extensions rather than the "audio" category: .mp4 is usually
        # a video MIME type and may otherwise be hidden in the file picker.
        inputs=gr.File(label="Upload Audio File", file_types=list(SUPPORTED_AUDIO_FORMATS)),
        outputs=gr.Textbox(label="Analysis Results", lines=25),
        title="π΅ Audio Analysis",
        description="Upload MP3, MP4, WAV, M4A, FLAC, or OGG files",
        article=status_text,
    )
    text_demo = gr.Interface(
        fn=process_text_input,
        inputs=gr.Textbox(label="Enter Text", lines=5, placeholder="Enter speech transcription..."),
        outputs=gr.Textbox(label="Analysis Results", lines=15),
        title="π Text Analysis",
        description="Enter text for direct analysis (less accurate than audio)",
        article=status_text,
    )

    # Combine interfaces using TabbedInterface
    demo = gr.TabbedInterface(
        [audio_demo, text_demo],
        ["Audio Analysis", "Text Analysis"],
        title="π§ Aphasia Classification System",
        theme=gr.themes.Soft()
    )
    return demo
def create_flask_app():
    """Create a Flask app exposing /health and /info, plus the Gradio app.

    Neither app is started here; the caller decides how to serve them.

    Returns:
        tuple: ``(flask_app, gradio_app)``. ``flask_app`` is a ``Flask``
        instance with the ``/health`` and ``/info`` routes registered.
        ``gradio_app`` is the Gradio UI with its request queue enabled.
    """
    flask_app = Flask(__name__)

    gradio_app = create_gradio_app()
    # Enable Gradio's request queue before the app is launched.
    gradio_app.queue()

    @flask_app.route("/health")
    def health_check():
        """Report model/pipeline availability; HTTP 503 when unhealthy."""
        model_available, missing_files = check_model_files()
        pipeline_available = all(MODULES.values())
        healthy = model_available and pipeline_available
        payload = {
            "status": "healthy" if healthy else "unhealthy",
            "model_available": model_available,
            "pipeline_available": pipeline_available,
            "missing_files": missing_files if not model_available else [],
            "missing_modules": [k for k, v in MODULES.items() if v is None] if not pipeline_available else []
        }
        # A non-2xx status lets probes/load balancers detect an unhealthy instance.
        return payload, 200 if healthy else 503

    @flask_app.route("/info")
    def info():
        """Return static service metadata."""
        return {
            "title": "Aphasia Classification System",
            "description": "AI-powered aphasia type classification from audio",
            "supported_formats": SUPPORTED_AUDIO_FORMATS,
            "endpoints": {
                "/": "Main Gradio interface",
                "/health": "Health check",
                "/info": "System information"
            }
        }

    return flask_app, gradio_app
def run_gradio_on_flask():
    """Build the apps and launch the Gradio server, then block until Ctrl-C.

    Host and port come from the ``HOST``/``PORT`` environment variables
    (defaults ``0.0.0.0``/``7860``). A public share link is requested when a
    known cloud-notebook environment variable is set.

    Raises:
        Exception: Any error from building or launching the app propagates,
            so the caller can fall back to another startup path.
    """
    logger.info("Starting Aphasia Classification System with Flask + Gradio...")

    # The Flask app is not served by this function; only Gradio is started.
    _flask_app, gradio_app = create_flask_app()

    # Detect environment
    port = int(os.environ.get('PORT', 7860))
    host = os.environ.get('HOST', '0.0.0.0')

    # Check if we're in a cloud environment
    is_cloud = any(os.getenv(indicator) for indicator in [
        'SPACE_ID', 'PAPERSPACE_NOTEBOOK_REPO_ID',
        'COLAB_GPU', 'KAGGLE_KERNEL_RUN_TYPE'
    ])
    logger.info(f"Environment - Cloud: {is_cloud}, Host: {host}, Port: {port}")

    # Launch in the calling thread. With prevent_thread_lock=True, launch()
    # returns once the server is running, and startup errors propagate to the
    # caller instead of being swallowed inside a background thread.
    gradio_app.launch(
        server_name=host,
        server_port=port,
        share=is_cloud,  # Auto-enable share in cloud environments
        show_error=True,
        quiet=False,
        prevent_thread_lock=True
    )
    logger.info(f"β Gradio app started on {host}:{port}")

    # Keep the main thread alive while the server runs.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Shutting down...")
        gradio_app.close()
if __name__ == "__main__":
    try:
        run_gradio_on_flask()
    except Exception as e:
        logger.exception(f"Failed to start application: {e}")
        # Fallback to basic Gradio if Flask setup fails. Honour the same
        # HOST/PORT environment variables as the primary path, and use the
        # historical 7860 if PORT is not a valid integer.
        logger.info("Falling back to basic Gradio interface...")
        try:
            fallback_port = int(os.environ.get("PORT", 7860))
        except ValueError:
            fallback_port = 7860
        demo = create_gradio_app()
        demo.launch(
            server_name=os.environ.get("HOST", "0.0.0.0"),
            server_port=fallback_port,
            share=True,
            show_error=True
        )