import os
import sys

import numpy as np
import torch
import torchaudio
import torchvision
import gradio as gr
from PIL import Image

# Add the parent directory to the path so the preprocessing functions can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import the model loader from infer_watermelon.py
from infer_watermelon import load_model
# Modified version of process_audio_data specifically for the app, to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
    """Modified version of process_audio_data for the app that handles different tensor dimensions."""
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Handle different tensor dimensions
        if waveform.dim() == 3:
            # For a 3D tensor, take the first item (drop the batch dimension)
            print("\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            waveform = waveform[0]

        if waveform.dim() == 2:
            # Use the first channel for stereo audio
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        # Resample to 16 kHz if needed
        resample_rate = 16000
        if sample_rate != resample_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        # Pad or trim to exactly 3 seconds of audio
        if waveform.size(0) < 3 * resample_rate:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = waveform[: 3 * resample_rate]

        # Apply the MFCC transformation
        print("\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            },
        )
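        # With these settings (and torchaudio's default center=True padding),
        # a 3-second clip at 16 kHz should yield 48000 // 128 + 1 = 376 frames,
        # i.e. an MFCC tensor of shape (13, 376). The exact frame count is an
        # assumption based on torchaudio defaults; the debug print below
        # reports the actual shape.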
        mfcc = mfcc_transform(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")

        return mfcc
    except Exception as e:
        import traceback

        print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
        print(traceback.format_exc())
        return None
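# Example usage (a sketch, kept as comments so nothing runs at import time;
# the names below are illustrative, not part of the pipeline):
#
#   wav = torch.randn(1, 44100)                # (channels, samples) at 44.1 kHz
#   mfcc = app_process_audio_data(wav, 44100)  # resampled to 16 kHz, padded to 3 s
#   print(mfcc.shape)                          # expected: torch.Size([13, 376])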
# For images, reuse the original preprocessing function
from preprocess import process_image_data


def init_model(model_path):
    """Initialize the model for inference."""
    model, device = load_model(model_path)
    return model, device
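# Example (sketch): the checkpoint path below is the CLI default further down,
# not a guaranteed location.
#
#   model, device = init_model("models/watermelon_model_final.pt")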
def predict_sweetness(audio, image, model, device):
    """Predict the sweetness of a watermelon from audio and image input."""
    try:
        # Debug information about the input types
        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) == 2:
            # Standard Gradio format: (sample_rate, audio_data)
            sample_rate, audio_data = audio
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, tuple) and len(audio) > 2:
            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
            sample_rate, audio_data = audio[0], audio[-1]
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, str):
            # Direct path to an audio file
            audio_data, sample_rate = torchaudio.load(audio)
            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        # Create a temporary file path for the image (the audio is processed in memory)
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")

        # Audio handling - process the data directly in memory
        if isinstance(audio_data, np.ndarray):
            # Convert the numpy array to a tensor
            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
            audio_tensor = torch.tensor(audio_data).float()

            # Handle different audio dimensions
            if audio_data.ndim == 1:
                # Single-channel audio: add a channel dimension
                audio_tensor = audio_tensor.unsqueeze(0)
            elif audio_data.ndim == 2:
                # Ensure channels are the first dimension
                if audio_data.shape[0] > audio_data.shape[1]:
                    # More rows than columns: probably (samples, channels)
                    audio_tensor = torch.tensor(audio_data.T).float()
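                    # Illustration (assumed shapes): a 3-second stereo clip at
                    # 16 kHz arrives from Gradio as (48000, 2); transposing
                    # gives the channels-first (2, 48000) layout that
                    # app_process_audio_data expects.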
        else:
            # Already a tensor
            audio_tensor = audio_data.float()

        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")

        # Skip saving/loading and process directly
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
        # Image handling
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
            pil_image = Image.fromarray(image)
            pil_image.save(temp_image_path)
            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
        elif isinstance(image, str):
            # The image is already a path
            temp_image_path = image
            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        # Process the image
        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
        image_tensor = torchvision.io.read_image(temp_image_path)
        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
        image_tensor = image_tensor.float()
        processed_image = process_image_data(image_tensor)
        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")

        # Add a batch dimension for inference
        if mfcc is not None:
            mfcc = mfcc.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")

        if processed_image is not None:
            processed_image = processed_image.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
        # Run inference
        print("\033[92mDEBUG\033[0m: Running inference")
        if mfcc is not None and processed_image is not None:
            with torch.no_grad():
                sweetness = model(mfcc, processed_image)
                print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
        else:
            return "Error: Failed to process inputs. Please check the debug logs."

        # Format the result
        if sweetness is not None:
            result = f"Predicted Sweetness: {sweetness.item():.2f}/13"

            # Add a qualitative description
            if sweetness.item() < 9:
                result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
            elif sweetness.item() < 10:
                result += "\n\nThis watermelon has moderate sweetness."
            elif sweetness.item() < 11:
                result += "\n\nThis watermelon is sweet! A good choice."
            else:
                result += "\n\nThis watermelon is very sweet! Excellent choice!"

            return result
        else:
            return "Error: Could not predict sweetness. Please try again with different inputs."
    except Exception as e:
        import traceback

        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg
def create_app(model_path):
    """Create the Gradio interface."""
    # Initialize the model
    model, device = init_model(model_path)

    # Close the prediction function over the model and device
    def predict_fn(audio, image):
        return predict_sweetness(audio, image, model, device)

    # Build the Gradio interface
    with gr.Blocks(title="Watermelon Sweetness Predictor") as interface:
        gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
        gr.Markdown("""
        This app predicts the sweetness of a watermelon based on its sound and appearance.

        ## Instructions:
        1. Upload or record an audio clip of tapping the watermelon
        2. Upload or capture an image of the watermelon
        3. Click 'Predict Sweetness' to get the prediction
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sweetness", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=6)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output,
        )

        gr.Markdown("""
        ## How it works

        The app uses a deep learning model that combines:
        - Audio analysis using MFCC features and an LSTM neural network
        - Image analysis using a ResNet-50 convolutional neural network

        The model was trained on a dataset of watermelons with known sweetness values.
        """)

    return interface
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
    parser.add_argument(
        "--model_path",
        type=str,
        default="models/watermelon_model_final.pt",
        help="Path to the trained model file",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output",
    )
    args = parser.parse_args()

    if args.debug:
        print("\033[92mINFO\033[0m: Debug mode enabled")

    # Check that the model exists
    if not os.path.exists(args.model_path):
        print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
        sys.exit(1)

    # Create and launch the app
    app = create_app(args.model_path)
    app.launch(share=args.share)
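# Example invocations (the script filename is assumed):
#
#   python app.py
#   python app.py --model_path models/watermelon_model_final.pt --share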