# NOTE: the original upload carried Hugging Face Spaces page residue here
# ("Spaces: Sleeping") — kept as a comment so the module parses.
| import os | |
| import torch | |
| import torchaudio | |
| import torchvision | |
| import numpy as np | |
| import json | |
| from torch.utils.data import Dataset, DataLoader | |
| import sys | |
| from tqdm import tqdm | |
| # Add parent directory to path to import the preprocess functions | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from preprocess import process_audio_data, process_image_data | |
| # Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file | |
| from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES | |
# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
# Device selection: prefer CUDA, then Apple-silicon MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Define the top-performing models based on the previous evaluation
# (evaluate_backbones.py run); each entry becomes one expert in the MoE.
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
]
| # Define class for the MoE model | |
class WatermelonMoEModel(torch.nn.Module):
    """Mixture-of-Experts ensemble over several image+audio backbone models.

    Each expert is a WatermelonModelModular restored from a checkpoint in
    *model_dir*; the ensemble prediction is the weighted sum of the experts'
    outputs (weights default to uniform, i.e. a plain average).
    """

    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Args:
            model_configs: List of dicts with 'image_backbone' and
                'audio_backbone' keys identifying each expert.
            model_dir: Directory holding '<img>_<audio>_model.pt' checkpoints.
            weights: Optional list of per-expert weights (None => uniform).

        Raises:
            RuntimeError: if no checkpoint could be loaded — the original
                code would hit a ZeroDivisionError computing uniform weights.
        """
        super().__init__()
        # ModuleList (not a plain list) registers the experts as submodules,
        # so .to()/.eval()/state_dict on the ensemble propagate to them.
        self.models = torch.nn.ModuleList()
        self.model_configs = model_configs

        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if not os.path.exists(model_path):
                # Skip missing checkpoints but keep going with the rest.
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue
            print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
            model = WatermelonModelModular(img_backbone, audio_backbone)
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.to(device)
            model.eval()  # experts are inference-only
            self.models.append(model)

        if len(self.models) == 0:
            raise RuntimeError(f"No model checkpoints could be loaded from {model_dir}")

        # Set model weights (uniform by default).
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """Return the weighted sum of all expert outputs (inference only)."""
        outputs = []
        with torch.no_grad():
            for weight, model in zip(self.weights, self.models):
                outputs.append(model(mfcc, image) * weight)
        return torch.sum(torch.stack(outputs), dim=0)
def evaluate_moe_model(data_dir, model_dir="models", weights=None):
    """
    Evaluate the MoE ensemble on the test split of the dataset.

    Args:
        data_dir: Path to the cleaned dataset directory.
        model_dir: Directory containing the expert checkpoints.
        weights: Optional per-expert weights (None => uniform).

    Returns:
        Tuple (avg_test_mae, avg_test_mse), averaged over test batches;
        (inf, inf) if the test loader is empty.
    """
    # Load dataset
    print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # 70/20/10 train/val/test split.
    # NOTE(review): random_split is unseeded here, so this "test" split is not
    # guaranteed to match the split used during training — confirm the training
    # script seeds identically, otherwise test samples may leak from training.
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    _, _, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Use a reasonable batch size
    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize MoE model
    moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
    moe_model.eval()

    # Evaluation metrics
    mae_criterion = torch.nn.L1Loss()
    mse_criterion = torch.nn.MSELoss()
    test_mae = 0.0
    test_mse = 0.0

    print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")

    # Individual model predictions kept for post-hoc analysis.
    model_names = [f"{c['image_backbone']}_{c['audio_backbone']}" for c in TOP_MODELS]
    individual_predictions = {name: [] for name in model_names}
    true_labels = []
    moe_predictions = []

    # Evaluation loop
    test_iterator = tqdm(test_loader, desc="Testing MoE")
    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                # Run each expert exactly once and combine the outputs here;
                # the original called moe_model(...) afterwards, which re-ran
                # every expert a second time per batch.
                expert_outputs = []
                for name, model in zip(model_names, moe_model.models):
                    out = model(mfcc, image)
                    individual_predictions[name].extend(out.view(-1).cpu().numpy())
                    expert_outputs.append(out)
                output = torch.sum(
                    torch.stack([o * w for o, w in zip(expert_outputs, moe_model.weights)]),
                    dim=0,
                )
                moe_predictions.extend(output.view(-1).cpu().numpy())

                # Store true labels (flattened to match predictions)
                label = label.view(-1, 1).float()
                true_labels.extend(label.view(-1).cpu().numpy())

                # Calculate metrics
                mae = mae_criterion(output, label)
                mse = mse_criterion(output, label)
                test_mae += mae.item()
                test_mse += mse.item()
                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})

                # Clean up memory on GPU between batches
                if device.type == 'cuda':
                    del mfcc, image, label, output, mae, mse
                    torch.cuda.empty_cache()
            except Exception as e:
                # Best-effort evaluation: log and skip a failing batch rather
                # than abort the whole run.
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    # Calculate average metrics (guard against an empty loader)
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')

    print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
    print(f"Test MAE: {avg_test_mae:.4f}")
    print(f"Test MSE: {avg_test_mse:.4f}")

    # Compare with individual models
    print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
    print(f"{'Model':<30} {'Test MAE':<15}")
    print("=" * 45)

    # Load previous results
    results_file = "backbone_evaluation_results.json"
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            previous_results = json.load(f)
        # Filter results for our top models
        for config in TOP_MODELS:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]
            for result in previous_results:
                if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
                    # Pad the full combined name to 30 chars so columns line up
                    # with the header (the original padded only the audio half).
                    print(f"{img_backbone + '_' + audio_backbone:<30} {result['test_mae']:<15.4f}")

    print(f"{'MoE (Ensemble)':<30} {avg_test_mae:<15.4f}")

    # Save results and predictions
    results = {
        "moe_test_mae": float(avg_test_mae),
        "moe_test_mse": float(avg_test_mse),
        "true_labels": [float(x) for x in true_labels],
        "moe_predictions": [float(x) for x in moe_predictions],
        "individual_predictions": {key: [float(x) for x in values]
                                   for key, values in individual_predictions.items()},
    }
    with open("moe_evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")

    return avg_test_mae, avg_test_mse
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction") | |
| parser.add_argument( | |
| "--data_dir", | |
| type=str, | |
| default="../cleaned", | |
| help="Path to the cleaned dataset directory" | |
| ) | |
| parser.add_argument( | |
| "--model_dir", | |
| type=str, | |
| default="models", | |
| help="Directory containing model checkpoints" | |
| ) | |
| parser.add_argument( | |
| "--weighting", | |
| type=str, | |
| choices=["uniform", "performance"], | |
| default="uniform", | |
| help="How to weight the models (uniform or based on performance)" | |
| ) | |
| args = parser.parse_args() | |
| # Determine weights based on argument | |
| weights = None | |
| if args.weighting == "performance": | |
| # Weights inversely proportional to the MAE (better models get higher weights) | |
| # These are the MAE values from the provided results | |
| mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer | |
| # Convert to weights (inverse of MAE, normalized) | |
| inverse_mae = [1/mae for mae in mae_values] | |
| total = sum(inverse_mae) | |
| weights = [val/total for val in inverse_mae] | |
| print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}") | |
| else: | |
| print(f"\033[92mINFO\033[0m: Using uniform weights") | |
| # Evaluate the MoE model | |
| evaluate_moe_model(args.data_dir, args.model_dir, weights) |