Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import librosa | |
| import numpy as np | |
| import os | |
| import hashlib | |
| from datetime import datetime | |
| from transformers import pipeline | |
| import soundfile as sf | |
| import torch | |
| from tenacity import retry, stop_after_attempt, wait_fixed | |
| from gtts import gTTS | |
# Loader for the local speech-recognition model.
def load_whisper_model():
    """Build and return the Whisper speech-to-text pipeline.

    Loads ``openai/whisper-tiny.en`` on the CPU from safetensors weights.
    A loading failure is logged and then re-raised to the caller; this
    function itself makes a single attempt and does not retry.

    Returns:
        The object returned by ``pipeline("automatic-speech-recognition", ...)``.

    Raises:
        Exception: Whatever ``pipeline`` raised while loading the model.
    """
    whisper_options = {
        "model": "openai/whisper-tiny.en",
        "device": -1,  # CPU; use device=0 for GPU if available
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        asr_pipeline = pipeline("automatic-speech-recognition", **whisper_options)
        print("Whisper model loaded successfully.")
        return asr_pipeline
    except Exception as load_error:
        print(f"Failed to load Whisper model: {load_error!s}")
        raise
def load_symptom_model():
    """Build and return the symptom text-classification pipeline.

    The ``abhirajeshbhai/symptom-2-disease-net`` model is tried first. If it
    cannot be loaded, the error is logged and ``distilbert-base-uncased`` is
    loaded instead. The caller receives whichever pipeline loaded; the return
    value does not say which one it is.

    Returns:
        The object returned by ``pipeline("text-classification", ...)``.

    Raises:
        Exception: The fallback model's loading error, when both attempts fail.
    """
    primary_options = {
        "model": "abhirajeshbhai/symptom-2-disease-net",
        "device": -1,  # CPU
        "model_kwargs": {"use_safetensors": True},
    }
    try:
        classifier = pipeline("text-classification", **primary_options)
        print("Symptom-2-Disease model loaded successfully.")
        return classifier
    except Exception as primary_error:
        print(f"Failed to load Symptom-2-Disease model: {primary_error!s}")
        # Second attempt with a generic model; kept inside this handler so a
        # fallback failure still carries the primary failure as its context.
        try:
            classifier = pipeline(
                "text-classification",
                model="distilbert-base-uncased",
                device=-1,
            )
            print("Fallback to distilbert-base-uncased model.")
            return classifier
        except Exception as fallback_error:
            print(f"Fallback model failed: {fallback_error!s}")
            raise
# Module-level model handles. Each stays None when its model cannot be
# loaded; the request handlers test for that and report an error instead of
# crashing. NOTE: the load_* functions make a single attempt each, even
# though the log messages below mention retries.
_FALLBACK_MODEL_NAME = "distilbert-base-uncased"

whisper = None
symptom_classifier = None
# True only when a classifier loaded AND it is the generic fallback model.
is_fallback_model = False

try:
    whisper = load_whisper_model()
except Exception as e:
    print(f"Whisper model initialization failed after retries: {str(e)}")

try:
    symptom_classifier = load_symptom_model()
except Exception as e:
    print(f"Symptom model initialization failed after retries: {str(e)}")
    symptom_classifier = None
else:
    # load_symptom_model() falls back silently, so identify the loaded model
    # by the name it was loaded under rather than assuming the primary one.
    _loaded_name = getattr(
        getattr(symptom_classifier, "model", None), "name_or_path", ""
    )
    is_fallback_model = _loaded_name == _FALLBACK_MODEL_NAME
def compute_file_hash(file_path):
    """Return the hexadecimal MD5 digest of the file at *file_path*.

    The file is opened in binary mode and read in 4096-byte pieces, so the
    whole file is never held in memory at once. The digest is used to tell
    uploads apart, not for security.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if piece == b"":
                break
            digest.update(piece)
    return digest.hexdigest()
def transcribe_audio(audio_file):
    """Transcribe audio using local Whisper model.

    The audio is loaded at 16 kHz, checked for minimum length and loudness,
    written to a private temporary WAV file and passed to the module-level
    ``whisper`` pipeline with beam search (5 beams).

    Args:
        audio_file: Path of the audio file to transcribe.

    Returns:
        The transcription text on success. On any problem a human-readable
        message is returned instead of raising: it starts with ``"Error"``
        (model missing, audio too short/quiet, repetitive output, unexpected
        exception) or with ``"Transcription empty"``.
    """
    # Function-scope imports keep this block self-contained.
    import contextlib
    import tempfile

    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        # Load and validate audio
        audio, sr = librosa.load(audio_file, sr=16000)
        # 1600 samples at 16 kHz is 0.1 s; the message recommends a full
        # second so users record something long enough to be useful.
        if len(audio) < 1600:
            return "Error: Audio too short. Please provide audio of at least 1 second."
        if np.max(np.abs(audio)) < 1e-4:  # Too quiet
            return "Error: Audio too quiet. Please provide clear audio describing symptoms in English."

        # Save as WAV for Whisper. mkstemp gives a unique path, so two
        # requests with the same input basename cannot overwrite each other.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(temp_wav, audio, sr)
            # Transcribe with beam search
            with torch.no_grad():
                result = whisper(temp_wav, generate_kwargs={"num_beams": 5})
        finally:
            # Always remove the temp file, including when writing or
            # transcription raises; a failed delete is not worth an error.
            with contextlib.suppress(OSError):
                os.remove(temp_wav)

        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")
        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms in English."

        # Check for repetitive transcription: more than five words of which
        # fewer than half are distinct.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def analyze_symptoms(text):
    """Classify a symptom description with the module-level classifier.

    Args:
        text: Transcribed symptom description.

    Returns:
        A ``(prediction, score)`` tuple. ``prediction`` is the label of the
        first classifier result (suffixed with ``" (using fallback model)"``
        when ``is_fallback_model`` is set) and ``score`` its confidence. When
        no prediction can be made, ``prediction`` is an explanatory message
        and ``score`` is ``0.0``; exceptions are reported the same way rather
        than raised.
    """
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0
    try:
        # Nothing to classify: empty text or a transcription failure message.
        if not text or "Error transcribing" in text:
            return "No valid transcription for analysis.", 0.0

        with torch.no_grad():
            outputs = symptom_classifier(text)

        has_prediction = outputs and isinstance(outputs, list) and len(outputs) > 0
        if not has_prediction:
            return "No health condition predicted", 0.0

        label = outputs[0]["label"]
        confidence = outputs[0]["score"]
        if is_fallback_model:
            print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
            label = f"{label} (using fallback model)"
        print(f"Health Prediction: {label}, Score: {confidence:.4f}")
        return label, confidence
    except Exception as analysis_error:
        return f"Error analyzing symptoms: {analysis_error!s}", 0.0
def generate_voice_feedback(text):
    """Synthesize *text* to an MP3 with gTTS and return the file path.

    Everything from the ``**Debug Info**`` marker onwards (which includes the
    trailing disclaimer in the feedback this app builds) is left out of the
    spoken output. The MP3 is written to ``/tmp`` under a timestamped name.

    Returns:
        The path of the generated MP3, or ``None`` if anything failed (the
        error is printed).
    """
    try:
        # Keep only the part before the debug section.
        spoken_text, _, _ = text.partition("\n\n**Debug Info**")
        speech = gTTS(text=spoken_text, lang='en')
        stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
        mp3_path = f"/tmp/feedback_{stamp}.mp3"
        speech.save(mp3_path)
        return mp3_path
    except Exception as tts_error:
        print(f"Error generating voice feedback: {tts_error!s}")
        return None
def analyze_voice(audio_file):
    """Analyze voice for health indicators and provide text and voice feedback.

    The input is copied to a private temporary file, transcribed, screened
    for medication/treatment requests and classified. The caller's file is
    never moved or deleted, and the temporary copy is removed on every exit
    path.

    Args:
        audio_file: Path of the recorded or uploaded audio, or ``None`` when
            nothing was provided.

    Returns:
        A ``(feedback, voice_file)`` tuple: the feedback text and the value
        returned by ``generate_voice_feedback`` for it (a file path, or
        ``None`` if speech synthesis failed). Errors are reported through the
        feedback text rather than raised.
    """
    # Function-scope imports keep this block self-contained.
    import shutil
    import tempfile

    disclaimer = "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
    # Prefixes of every non-transcription reply transcribe_audio() can return.
    transcription_errors = ("Error:", "Error transcribing", "Transcription empty")
    # Prefixes of every non-prediction reply analyze_symptoms() can return.
    analysis_errors = ("Error:", "Error analyzing", "No valid transcription")

    working_copy = None
    try:
        # Check if audio_file is None
        if audio_file is None:
            feedback = "Error: No audio provided. Please record or upload a valid audio file."
            return feedback, generate_voice_feedback(feedback)

        # Work on a uniquely named private copy so the caller's file (an
        # upload or a bundled sample) is left untouched.
        suffix = os.path.splitext(audio_file)[1]
        fd, working_copy = tempfile.mkstemp(prefix="voice_", suffix=suffix)
        os.close(fd)
        shutil.copyfile(audio_file, working_copy)

        # Log audio file info
        file_hash = compute_file_hash(working_copy)
        print(f"Processing audio file: {working_copy}, Hash: {file_hash}")

        # Load audio to verify format (diagnostic logging only)
        audio, sr = librosa.load(working_copy, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")

        # Transcribe audio; any error reply is returned as-is instead of
        # being fed to the classifier as if it were a symptom description.
        transcription = transcribe_audio(working_copy)
        if transcription.startswith(transcription_errors):
            return transcription, generate_voice_feedback(transcription)

        # Check for medication-related queries
        lowered = transcription.lower()
        if "medicine" in lowered or "treatment" in lowered:
            feedback = "Error: This tool does not provide medication or treatment advice. Please describe symptoms only (e.g., 'I have a fever')."
            feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
            feedback += disclaimer
            return feedback, generate_voice_feedback(feedback)

        # Analyze symptoms; error replies must not be presented as a
        # "possible health condition".
        prediction, score = analyze_symptoms(transcription)
        if prediction.startswith(analysis_errors):
            return prediction, generate_voice_feedback(prediction)

        # Generate feedback
        if prediction == "No health condition predicted":
            feedback = "No significant health indicators detected."
        else:
            feedback = f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor."
        feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', Prediction = {prediction}, Confidence = {score:.4f}, File Hash = {file_hash}"
        feedback += disclaimer
        return feedback, generate_voice_feedback(feedback)
    except Exception as e:
        feedback = f"Error processing audio: {str(e)}"
        return feedback, generate_voice_feedback(feedback)
    finally:
        # Clean up the temporary copy on every path, including early returns.
        if working_copy is not None:
            try:
                os.remove(working_copy)
                print(f"Deleted temporary audio file: {working_copy}")
            except OSError as e:
                print(f"Failed to delete audio file: {str(e)}")
def test_with_sample_audio():
    """Run ``analyze_voice`` over the bundled sample recordings.

    Returns:
        One report entry per sample, joined with newlines: the feedback text
        and voice-file path for samples that exist, or a "Sample not found"
        line for those that do not.
    """
    sample_paths = ("audio_samples/sample.wav", "audio_samples/common_voice_en.wav")
    report = []
    for path in sample_paths:
        if not os.path.exists(path):
            report.append(f"Sample not found: {path}")
            continue
        text, voice = analyze_voice(path)
        report.append(f"Text: {text}\nVoice: {voice}")
    return "\n".join(report)
# Gradio interface: one audio input (handed to analyze_voice as a file path)
# and two outputs, the feedback text and its spoken version.
_voice_input = gr.Audio(type="filepath", label="Record or Upload Voice")
_feedback_outputs = [
    gr.Textbox(label="Health Assessment Feedback"),
    gr.Audio(label="Voice Feedback", type="filepath"),
]
iface = gr.Interface(
    fn=analyze_voice,
    inputs=_voice_input,
    outputs=_feedback_outputs,
    title="Health Voice Analyzer",
    description="Record or upload a voice sample describing symptoms (e.g., 'I have a fever') for preliminary health assessment. Supports English only. Use clear audio (WAV, 16kHz). Do not ask for medication or treatment advice.",
)
def _main():
    """Print the sample-audio report, then serve the Gradio interface.

    The server listens on all network interfaces (0.0.0.0) on port 7860.
    """
    print(test_with_sample_audio())
    iface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()