from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, pipeline
import torch
import torchaudio.transforms as T
import numpy as np
import json

# Initialize Whisper components globally (these are lightweight).
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)
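# Note: the pipeline created in get_asr_pipeline() below builds its own
# processor internally; the globals above are kept for direct feature
# extraction. A minimal sketch, assuming `audio` is a mono float32 array
# already resampled to 16 kHz:
#
#   feats = processor(audio, sampling_rate=16000, return_tensors="pt")
#   print(feats.input_features.shape)  # padded log-mel features, (1, 80, 3000)
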
# Live transcription update handler.
def update_live_transcription(audio):
    """Real-time transcription updates."""
    print("update_live_transcription called with:", type(audio))
    if not audio or not isinstance(audio, tuple):
        return ""
    try:
        sample_rate, audio_array = audio
        print(f"got audio tuple - sample_rate={sample_rate}, shape={audio_array.shape}")

        def process_audio(audio_array, sample_rate):
            """Pre-process audio for Whisper: mono, 16 kHz, peak-normalized."""
            # Downmix multi-channel audio to mono.
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            # Convert to a float tensor for resampling.
            audio_tensor = torch.FloatTensor(audio_array)
            # Resample to the 16 kHz rate Whisper expects.
            if sample_rate != 16000:
                resampler = T.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)
            # Peak-normalize, guarding against all-zero (silent) chunks.
            peak = torch.max(torch.abs(audio_tensor))
            if peak > 0:
                audio_tensor = audio_tensor / peak
            # Return in the dict format the ASR pipeline expects.
            return {
                "raw": audio_tensor.numpy(),  # key must be "raw"
                "sampling_rate": 16000,       # key must be "sampling_rate"
            }

        features = process_audio(audio_array, sample_rate)
        asr = get_asr_pipeline()
        result = asr(features)
        return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        return ""

def get_asr_pipeline():
    """Lazily load the ASR pipeline with chunked long-form decoding."""
    global transcriber
    if "transcriber" not in globals():
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base.en",
            chunk_length_s=30,
            stride_length_s=5,
            device="cpu",
            torch_dtype=torch.float32,
        )
    return transcriber

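# Quick smoke test for the lazy loader (assumption: a one-second silent
# buffer is enough to exercise the pipeline end to end):
#
#   silence = {"raw": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}
#   print(get_asr_pipeline()(silence))  # -> {'text': '...'}
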
def process_speech(audio_data, symptom_index):
    """Process speech input and convert to text."""
    if not audio_data:
        return []
    if isinstance(audio_data, tuple) and len(audio_data) == 2:
        sample_rate, audio_array = audio_data
        # Audio preprocessing: downmix to mono and peak-normalize,
        # guarding against all-zero (silent) input.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array /= peak
        # Ensure the 16 kHz sampling rate Whisper expects.
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000
        # Format the input dict with the keys the pipeline requires.
        input_features = {
            "raw": audio_array,
            "sampling_rate": sample_rate,
        }
        # Transcribe via the lazily loaded pipeline (the bare `transcriber`
        # global is not guaranteed to exist yet at this point).
        result = get_asr_pipeline()(input_features)
        # Handle different result types.
        if isinstance(result, dict) and "text" in result:
            transcript = result["text"].strip()
        elif isinstance(result, str):
            transcript = result.strip()
        else:
            print(f"Unexpected transcriber result type: {type(result)}")
            return []
        if not transcript:
            print("No transcription generated")
            return []
        # Query the symptom index with the transcribed text.
        diagnosis_query = f"""
        Given these symptoms: '{transcript}'
        Identify the most likely ICD-10 diagnoses and key questions.
        Focus on clinical implications.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)
        return [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": json.dumps({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response),
            })},
        ]
    else:
        print(f"Invalid audio format: {type(audio_data)}")
        return []

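# For reference, a successful call returns a user/assistant message pair in
# the role/content dict format used throughout this file (transcript text
# below is illustrative, not model output):
#
#   [{"role": "user", "content": "persistent cough and mild fever"},
#    {"role": "assistant",
#     "content": '{"diagnoses": [], "confidences": [], "follow_up": "..."}'}]
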
def format_response_for_user(response_dict):
    """Format the assistant's response dictionary into a user-friendly string."""
    diagnoses = response_dict.get("diagnoses", [])
    confidences = response_dict.get("confidences", [])
    follow_up = response_dict.get("follow_up", "")
    result = ""
    if diagnoses:
        result += "Possible Diagnoses:\n"
        for i, diag in enumerate(diagnoses):
            conf = f" ({confidences[i] * 100:.1f}%)" if i < len(confidences) else ""
            result += f"- {diag}{conf}\n"
    if follow_up:
        result += f"\nFollow-up: {follow_up}"
    return result.strip()

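# Example (values are illustrative, not model output):
#
#   format_response_for_user({
#       "diagnoses": ["J06.9 Acute upper respiratory infection"],
#       "confidences": [0.72],
#       "follow_up": "How long have the symptoms lasted?",
#   })
#   # -> "Possible Diagnoses:\n- J06.9 Acute upper respiratory infection (72.0%)\n
#   #     \nFollow-up: How long have the symptoms lasted?"
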
def enhanced_process_speech(audio_path, symptom_index, history, api_key=None, model_tier="small", temp=0.7):
    """Handle streaming speech processing and chat updates."""
    transcriber = get_asr_pipeline()
    if not audio_path:
        return history
    if isinstance(audio_path, tuple) and len(audio_path) == 2:
        sample_rate, audio_array = audio_path
        # Audio preprocessing: downmix to mono and peak-normalize,
        # guarding against all-zero (silent) input.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array /= peak
        # Ensure the 16 kHz sampling rate Whisper expects.
        if sample_rate != 16000:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000
        # Format the input dict exactly as the pipeline requires.
        transcriber_input = {
            "raw": audio_array,
            "sampling_rate": sample_rate,
        }
        # Get the transcription from Whisper.
        result = transcriber(transcriber_input)
        # Extract text from the result.
        transcript = ""
        if isinstance(result, dict):
            transcript = result.get("text", "").strip()
        elif isinstance(result, str):
            transcript = result.strip()
        if not transcript:
            return history
        # Query the symptom index with the transcribed text.
        diagnosis_query = f"""
        Based on these symptoms: '{transcript}'
        Provide relevant ICD-10 codes and diagnostic questions.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)
        # Format and return chat messages.
        return history + [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": format_response_for_user({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response),
            })},
        ]
    # Fall back to the unchanged history on unexpected input rather than
    # returning None, which would clear the chat display.
    return history
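
# Hedged end-to-end wiring sketch (assumes a Gradio 4.x Blocks app and a
# pre-built `symptom_index`, e.g. a llama_index index over ICD-10 notes;
# `gr.Chatbot(type="messages")` accepts the role/content dicts returned above):
#
#   import gradio as gr
#   with gr.Blocks() as demo:
#       chatbot = gr.Chatbot(type="messages")
#       mic = gr.Audio(sources=["microphone"], type="numpy")
#       mic.change(
#           lambda audio, hist: enhanced_process_speech(audio, symptom_index, hist or []),
#           inputs=[mic, chatbot],
#           outputs=chatbot,
#       )
#   demo.launch()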