Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import os | |
| import hashlib | |
| from datetime import datetime | |
| import soundfile as sf | |
| import torch | |
| from tenacity import retry, stop_after_attempt, wait_fixed | |
| from transformers import pipeline | |
# Model loaders. Each builds a local Hugging Face pipeline on CPU; the
# module-level initialisation code decides what happens when loading fails.
def load_whisper_model():
    """Build the multilingual Whisper-tiny speech-recognition pipeline.

    Returns:
        A ``transformers`` automatic-speech-recognition pipeline running on
        CPU (``device=-1``) with safetensors weights.

    Raises:
        Exception: whatever ``pipeline`` raised; the error is printed first.
    """
    asr_options = {
        "model": "openai/whisper-tiny",  # multilingual checkpoint
        "device": -1,  # CPU; pass device=0 to use the first GPU instead
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        asr = pipeline("automatic-speech-recognition", **asr_options)
    except Exception as exc:
        print(f"Failed to load Whisper model: {exc}")
        raise
    print("Whisper model loaded successfully.")
    return asr
def load_symptom_model():
    """Build the symptom-to-disease text-classification pipeline.

    Tries ``abhirajeshbhai/symptom-2-disease-net`` first. If that fails, falls
    back to ``distilbert-base-uncased`` and records the fallback in the
    module-level ``is_fallback_model`` flag, which ``analyze_symptoms`` reads
    to mark its predictions as coming from the fallback model.

    NOTE(review): ``distilbert-base-uncased`` is a base checkpoint, not a
    symptom classifier; its classification head is presumably freshly
    initialised, so fallback labels (e.g. ``LABEL_0``) likely carry no
    medical meaning. Consider failing instead of falling back.

    Returns:
        A ``transformers`` text-classification pipeline running on CPU.

    Raises:
        Exception: the fallback model's loading error when both loads fail.
    """
    global is_fallback_model

    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,  # CPU
            model_kwargs={"use_safetensors": True},
        )
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")
    else:
        is_fallback_model = False
        print("Symptom-2-Disease model loaded successfully.")
        return model

    # Primary model unavailable: fall back to a generic model.
    try:
        model = pipeline(
            "text-classification",
            model="distilbert-base-uncased",
            device=-1,
        )
    except Exception as fallback_e:
        print(f"Fallback model failed: {str(fallback_e)}")
        raise
    # Record the fallback so analyze_symptoms labels its output accordingly;
    # without this the "(using fallback model)" marker could never appear.
    is_fallback_model = True
    print("Fallback to distilbert-base-uncased model.")
    return model
# Module-level model handles, populated once at import time. Each stays None
# if loading ultimately fails; the request handlers check for that.
whisper = None
symptom_classifier = None
is_fallback_model = False  # set by load_symptom_model when it falls back

# Loading can fail transiently (e.g. network errors while fetching weights),
# so each loader is retried via tenacity before giving up. reraise=True makes
# the loader's own exception propagate instead of tenacity's RetryError.
_load_with_retries = retry(
    stop=stop_after_attempt(3),
    wait=wait_fixed(2),
    reraise=True,
)

try:
    whisper = _load_with_retries(load_whisper_model)()
except Exception as e:
    print(f"Whisper model initialization failed after retries: {str(e)}")

try:
    symptom_classifier = _load_with_retries(load_symptom_model)()
except Exception as e:
    print(f"Symptom model initialization failed after retries: {str(e)}")
    # symptom_classifier is still None here; analyze_symptoms returns early in
    # that case, so this flag has no effect on results.
    is_fallback_model = True
def compute_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    Used to fingerprint uploaded audio in logs and debug output (not for any
    security purpose). The file is streamed in 4096-byte chunks so large
    recordings are never read into memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def transcribe_audio(audio_file, language="en"):
    """Transcribe a speech recording with the module-level Whisper pipeline.

    The file is decoded and resampled to 16 kHz with librosa, checked for a
    minimum length and loudness, and handed to Whisper as an in-memory array.

    Args:
        audio_file: Path to an audio file librosa can decode.
        language: Whisper language code forwarded to generation (e.g. "en").

    Returns:
        The stripped transcription on success. Problems are reported as
        human-readable strings instead of exceptions: failures start with
        "Error" and an empty result starts with "Transcription empty", so
        callers should test for those prefixes.
    """
    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        audio, sr = librosa.load(audio_file, sr=16000)
        # Reject clips shorter than 0.1 s (1600 samples at 16 kHz). The message
        # states the real threshold; NOTE(review): raise the threshold instead
        # if a 1-second minimum was intended.
        if len(audio) < sr // 10:
            return "Error: Audio too short. Please provide audio of at least 0.1 seconds."
        if np.max(np.abs(audio)) < 1e-4:  # effectively silent
            return "Error: Audio too quiet. Please provide clear audio describing symptoms."

        # Pass the already-decoded samples straight to the pipeline instead of
        # writing /tmp/<name>.wav: that file could collide between requests
        # sharing an upload name and was leaked whenever transcription raised.
        with torch.no_grad():
            result = whisper(
                {"raw": audio, "sampling_rate": sr},
                generate_kwargs={"num_beams": 5, "language": language},
            )
        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")

        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms."
        # More than half of the words being duplicates suggests a degenerate,
        # looping transcription rather than real speech.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def analyze_symptoms(text):
    """Classify a symptom description with the module-level classifier.

    Args:
        text: Free-text symptom description, normally (part of) a Whisper
            transcription.

    Returns:
        A ``(label, score)`` tuple. On success ``label`` is the top predicted
        condition (suffixed with "(using fallback model)" when the fallback
        classifier is loaded) and ``score`` its confidence. Otherwise
        ``label`` is a status message and ``score`` is 0.0; classifier
        failures start with "Error analyzing".
    """
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0
    try:
        cleaned = text.strip() if text else ""
        # transcribe_audio reports problems as strings starting with "Error"
        # or "Transcription empty" (not only "Error transcribing ..."); never
        # classify such status messages, or blank text, as symptoms.
        if (
            not cleaned
            or "Error transcribing" in cleaned
            or cleaned.startswith(("Error", "Transcription empty"))
        ):
            return "No valid transcription for analysis.", 0.0

        with torch.no_grad():
            result = symptom_classifier(cleaned)
        if isinstance(result, list) and result:
            top = result[0]
            prediction, score = top["label"], top["score"]
            if is_fallback_model:
                print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
                prediction = f"{prediction} (using fallback model)"
            print(f"Health Prediction: {prediction}, Score: {score:.4f}")
            return prediction, score
        return "No health condition predicted", 0.0
    except Exception as e:
        return f"Error analyzing symptoms: {str(e)}", 0.0
def handle_health_query(query, language="en"):
    """Answer a typed health question with a generic, non-prescriptive reply.

    Args:
        query: The user's question text; may be None or blank when submitted
            from the UI.
        language: Selected UI language code. Currently unused; kept so the
            signature matches the ``[query_input, language]`` Gradio inputs
            and the call from ``analyze_voice``.

    Returns:
        A prompt for blank input, a refusal for medication/treatment
        questions, or a placeholder response echoing the query.
    """
    # Whitespace-only input (e.g. a stray newline from the textbox) is not a
    # real question.
    if not query or not query.strip():
        return "Please provide a valid health query."
    # TODO: placeholder for real Q&A logic (could integrate a model like BERT for Q&A).
    restricted_terms = ("medicine", "treatment", "drug", "prescription")
    lowered = query.lower()
    # Substring matching is deliberately conservative: it also catches forms
    # such as "drugs" or "treatments".
    if any(term in lowered for term in restricted_terms):
        return "This tool does not provide medication or treatment advice. Please ask about symptoms or general health information (e.g., 'What are symptoms of asthma?')."
    return f"Response to query '{query}': For accurate health information, consult a healthcare provider."
def _split_symptoms_and_query(transcription):
    """Split *transcription* into ``(symptom_text, query_text)`` at the earliest restricted term."""
    restricted_terms = ("medicine", "treatment", "drug", "prescription")
    lowered = transcription.lower()
    # Use the earliest occurrence in the text of ANY term, not the first term
    # in list order, so e.g. "... treatment or medicine" splits at "treatment".
    positions = [
        pos for pos in (lowered.find(term) for term in restricted_terms) if pos != -1
    ]
    if not positions:
        return transcription, None
    split_index = min(positions)
    return transcription[:split_index].strip(), transcription[split_index:].strip()


def analyze_voice(audio_file, language="en"):
    """Run the voice pipeline: transcribe, classify symptoms, answer queries.

    The upload is copied to a private temporary file (Gradio's own copy is
    left untouched), fingerprinted for logging, and transcribed. The
    transcription is then split into a symptom part, which is classified, and
    an optional medication/treatment query, which gets a canned reply.

    Args:
        audio_file: Path from ``gr.Audio(type="filepath")``; None or empty
            when the user submits without recording or uploading.
        language: Whisper language code ("en", "es", "hi", "zh").

    Returns:
        Feedback text for the output textbox. Failures are returned as
        strings starting with "Error".
    """
    import shutil
    import tempfile

    if not audio_file:
        return "Error: No audio provided. Please record or upload a voice sample."

    work_path = None
    try:
        # Work on a uniquely named private copy. Renaming Gradio's upload into
        # a hardcoded /tmp/gradio failed when that directory was missing or on
        # another filesystem, and removed Gradio's own copy of the file.
        suffix = os.path.splitext(audio_file)[1]  # keep extension for decoding
        fd, work_path = tempfile.mkstemp(prefix="voice_", suffix=suffix)
        os.close(fd)
        shutil.copyfile(audio_file, work_path)

        file_hash = compute_file_hash(work_path)
        print(f"Processing audio file: {work_path}, Hash: {file_hash}")

        # Decode once up front only to log basic signal statistics.
        audio, sr = librosa.load(work_path, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")

        transcription = transcribe_audio(work_path, language)
        # transcribe_audio reports every problem as a string starting with
        # "Error" or "Transcription empty"; surface it instead of analysing it.
        if transcription.startswith(("Error", "Transcription empty")):
            return transcription

        symptom_text, query_text = _split_symptoms_and_query(transcription)

        if symptom_text:
            prediction, score = analyze_symptoms(symptom_text)
            if prediction.startswith("Error"):
                symptom_line = prediction
            elif prediction in ("No health condition predicted", "No valid transcription for analysis."):
                symptom_line = "No significant health indicators detected."
            else:
                symptom_line = f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor."
        else:
            symptom_line = "No symptoms detected in the audio."

        parts = [symptom_line + "\n"]
        if query_text:
            parts.append(f"\nQuery detected: '{query_text}'\n")
            parts.append(handle_health_query(query_text, language) + "\n")
        parts.append(f"\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}")
        parts.append("\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice.")
        return "".join(parts)
    except Exception as e:
        return f"Error processing audio: {str(e)}"
    finally:
        # Remove our private copy on every path, including early returns and
        # errors; a failed delete is logged, never raised.
        if work_path is not None:
            try:
                os.remove(work_path)
                print(f"Deleted temporary audio file: {work_path}")
            except OSError as e:
                print(f"Failed to delete audio file: {str(e)}")
# Gradio interface
def create_gradio_interface():
    """Assemble the Gradio Blocks UI and wire its two buttons.

    Layout, top to bottom: intro Markdown, language dropdown, audio input,
    question textbox, feedback textbox, and a row holding the "Analyze Voice"
    and "Submit Query" buttons. Both buttons write into the same feedback
    textbox and also receive the selected language.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    intro_markdown = (
        "\n"
        "# Health Voice Analyzer\n"
        "Record or upload a voice sample describing symptoms in English, Spanish, Hindi, or Mandarin (e.g., 'I have a fever').\n"
        "Ask health questions in the text box below (e.g., 'What are symptoms of asthma?').\n"
        "**Note**: Do not ask for medication or treatment advice; focus on symptoms or general health questions.\n"
        "**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice.\n"
        "**Text-to-Speech**: Available in the web frontend (Salesforce Sites) using the browser's Web Speech API.\n"
    )

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown(intro_markdown)

        with gr.Row():
            lang_choice = gr.Dropdown(
                choices=["en", "es", "hi", "zh"],
                label="Select Language",
                value="en",
            )
        with gr.Row():
            voice_in = gr.Audio(type="filepath", label="Record or Upload Voice")
        with gr.Row():
            question_in = gr.Textbox(label="Ask a Health Question (e.g., 'What are symptoms of asthma?')")
        with gr.Row():
            feedback_out = gr.Textbox(label="Health Assessment Feedback")
        with gr.Row():
            voice_btn = gr.Button("Analyze Voice")
            question_btn = gr.Button("Submit Query")

        # (button, handler, inputs); every handler writes to feedback_out.
        wiring = (
            (voice_btn, analyze_voice, [voice_in, lang_choice]),
            (question_btn, handle_health_query, [question_in, lang_choice]),
        )
        for button, handler, handler_inputs in wiring:
            button.click(fn=handler, inputs=handler_inputs, outputs=feedback_out)
    return app
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all network
    # interfaces at port 7860.
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )