import json
import os
import sys

import gradio as gr
import numpy as np
import torch
import torchaudio
import torchvision

# Add the parent directory to the path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import functions from preprocess and the model definitions
from preprocess import process_image_data
from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
# Define the top-performing models based on evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"},
]
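# Note: each entry in TOP_MODELS must have a matching checkpoint in model_dir,
# named "{image_backbone}_{audio_backbone}_model.pt" (see WatermelonMoEModel
# below), e.g. "models/efficientnet_b3_transformer_model.pt".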
# Define the MoE model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = []
        self.model_configs = model_configs

        # Load each expert model
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            # Initialize the model
            model = WatermelonModelModular(img_backbone, audio_backbone)

            # Load its checkpoint
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location='cpu'))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

            model.eval()  # Set to evaluation mode
            self.models.append(model)

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models) if self.models else [1.0]

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for the MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
    def to(self, device):
        """
        Override to() to ensure all sub-models are moved to the same device.
        This is needed because the experts are kept in a plain Python list,
        which torch.nn.Module does not traverse automatically.
        """
        for model in self.models:
            model.to(device)
        return super(WatermelonMoEModel, self).to(device)
    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all expert outputs.
        """
        if not self.models:
            print(f"\033[91mERR!\033[0m: No models available for inference!")
            return torch.tensor([0.0], device=mfcc.device)

        outputs = []
        # Get the output of each expert
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                print(f"\033[92mDEBUG\033[0m: Model {i} output: {output}")
                outputs.append(output * self.weights[i])

        # Return the weighted average
        return torch.sum(torch.stack(outputs), dim=0)
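# Illustration of the weighted average (hypothetical numbers): with uniform
# weights [1/3, 1/3, 1/3] and expert outputs of 9.8, 10.1, and 10.4 °Brix,
# the ensemble returns (9.8 + 10.1 + 10.4) / 3 ≈ 10.1 °Brix.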
# Modified version of process_audio_data specifically for the app,
# to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
    """Modified version of process_audio_data for the app that handles different tensor dimensions"""
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Handle different tensor dimensions
        if waveform.dim() == 3:
            print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            # For a 3D tensor, take the first item (batch dimension)
            waveform = waveform[0]
        if waveform.dim() == 2:
            # Use the first channel for stereo audio
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        # Resample to 16 kHz if needed
        resample_rate = 16000
        if sample_rate != resample_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate} Hz to {resample_rate} Hz")
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        # Pad or trim to exactly 3 seconds of audio
        if waveform.size(0) < 3 * resample_rate:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = waveform[: 3 * resample_rate]

        # Apply the MFCC transformation
        print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            },
        )
        mfcc = mfcc_transform(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
        return mfcc
    except Exception as e:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
        print(traceback.format_exc())
        return None
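# Hypothetical usage, assuming a recording saved as "tap.wav":
#   waveform, sr = torchaudio.load("tap.wav")   # waveform: (channels, samples)
#   mfcc = app_process_audio_data(waveform, sr)
# For 3 s at 16 kHz (48,000 samples) with hop_length=128, the MFCC output
# has shape (13, 376): 13 coefficients over ~376 frames.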
# Prediction entry point (uses the GPU when one is available)
def predict_sugar_content(audio, image, model_dir="models", weights=None):
    """Predict watermelon sugar content in °Brix using the MoE model"""
    try:
        # Check CUDA availability inside the prediction function
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\033[92mINFO\033[0m: Using device: {device}")

        # Load the MoE model
        moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
        moe_model = moe_model.to(device)  # Move the entire ensemble to the device
        moe_model.eval()
        print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) >= 2:
            # Gradio's "numpy" audio type yields (sample_rate, data)
            sample_rate, audio_data = audio[0], audio[-1]
        elif isinstance(audio, str):
            audio_data, sample_rate = torchaudio.load(audio)
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        # Convert audio to a tensor if needed
        if isinstance(audio_data, np.ndarray):
            audio_tensor = torch.tensor(audio_data).float()
        else:
            audio_tensor = audio_data.float()

        # Process audio
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        if mfcc is None:
            return "Error: Failed to process audio input"
        # Process image
        if isinstance(image, np.ndarray):
            image_tensor = torch.from_numpy(image).permute(2, 0, 1)  # Convert HxWxC to CxHxW
        elif isinstance(image, str):
            image_tensor = torchvision.io.read_image(image)
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        image_tensor = image_tensor.float()
        processed_image = process_image_data(image_tensor)
        if processed_image is None:
            return "Error: Failed to process image input"

        # Add batch dimension and move to device
        mfcc = mfcc.unsqueeze(0).to(device)
        processed_image = processed_image.unsqueeze(0).to(device)

        # Run inference
        with torch.no_grad():
            brix_value = moe_model(mfcc, processed_image)
            prediction = brix_value.item()
        print(f"\033[92mDEBUG\033[0m: Raw prediction: {prediction}")

        # Clamp the prediction to a reasonable range (6-13 °Brix)
        prediction = max(6.0, min(13.0, prediction))
        print(f"\033[92mDEBUG\033[0m: Bounded prediction: {prediction}")
        # Format the result
        result = f"🍉 Predicted Sugar Content: {prediction:.1f}° Brix 🍉\n\n"

        # Add extra info about the MoE model
        result += "Using Ensemble of Top-3 Models:\n"
        result += "- EfficientNet-B3 + Transformer\n"
        result += "- EfficientNet-B0 + Transformer\n"
        result += "- ResNet-50 + Transformer\n\n"

        # Add the Brix scale visualization
        result += "Sugar Content Scale (in °Brix):\n"
        result += "──────────────────────────────────\n"

        # Brix ranges for the scale display
        scale_ranges = [
            (0, 8, "Low Sugar (< 8° Brix)"),
            (8, 9, "Mild Sweetness (8-9° Brix)"),
            (9, 10, "Medium Sweetness (9-10° Brix)"),
            (10, 11, "Sweet (10-11° Brix)"),
            (11, 13, "Very Sweet (11-13° Brix)"),
        ]
        # Find which category the prediction falls into
        user_category = None
        for min_val, max_val, category_name in scale_ranges:
            if min_val <= prediction < max_val:
                user_category = category_name
                break
        if user_category is None:  # Edge case: prediction equals the upper bound (13)
            user_category = scale_ranges[-1][2]
        # Display the scale with the user's result highlighted
        for min_val, max_val, category_name in scale_ranges:
            if category_name == user_category:
                result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
            else:
                result += f"  {min_val}-{max_val}: {category_name}\n"

        result += "──────────────────────────────────\n\n"

        # Add an assessment of the watermelon's sugar content
        if prediction < 8:
            result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
        elif prediction < 9:
            result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
        elif prediction < 10:
            result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
        elif prediction < 11:
            result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
        else:
            result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."

        return result
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg
def create_app(model_dir="models", weights=None):
    """Create and launch the Gradio interface"""
    # Define the prediction function with the model path baked in
    def predict_fn(audio, image):
        return predict_sugar_content(audio, image, model_dir, weights)

    # Create the Gradio interface
    with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
        gr.Markdown("""
        This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.

        ## What's New
        This version uses a Mixture of Experts (MoE) ensemble that combines the three best-performing models:
        - EfficientNet-B3 + Transformer
        - EfficientNet-B0 + Transformer
        - ResNet-50 + Transformer

        The ensemble approach provides more accurate predictions than any single model!

        ## Instructions
        1. Upload or record the sound of tapping the watermelon
        2. Upload or capture an image of the watermelon
        3. Click 'Predict Sugar Content' to get the sugar content estimation
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sugar Content", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=15)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output,
        )
        gr.Markdown("""
        ## Tips for best results
        - For audio: tap the watermelon with your knuckle and record the sound
        - For image: take a clear photo of the whole watermelon in good lighting

        ## About Brix Measurement
        Brix (°Bx) is a measurement of the sugar content of a solution. For watermelons, higher Brix values indicate sweeter fruit.
        The average ripe watermelon has a Brix value between 9° and 11°.

        ## About the Mixture of Experts Model
        This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
        Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, significantly
        better than any individual model (best individual model: ~0.36 MAE).
        """)

    return interface
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models",
        help="Directory containing the model checkpoints",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output",
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)",
    )
    args = parser.parse_args()

    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Check that the model directory exists
    if not os.path.exists(args.model_dir):
        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
        sys.exit(1)
    # Determine weights based on the --weighting argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to MAE (better models get higher weights).
        # These are the MAE values from the evaluation results, in TOP_MODELS order:
        # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
        mae_values = [0.3635, 0.3765, 0.3959]
        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1 / mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val / total for val in inverse_mae]
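        # Worked example: the inverse MAEs are ≈ [2.751, 2.656, 2.526],
        # their sum is ≈ 7.933, so the normalized weights come out to
        # ≈ [0.347, 0.335, 0.318] - the best model gets the largest share.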
        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")

    # Create and launch the app
    app = create_app(args.model_dir, weights)
    app.launch(share=args.share)
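# Example invocations (assuming this file is saved as app.py):
#   python app.py                                   # uniform weights, local only
#   python app.py --weighting performance --share   # MAE-based weights, public link
#   python app.py --model_dir /path/to/checkpoints --debug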