Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import torch | |
| import torchaudio | |
| import torchvision | |
| import argparse | |
| import numpy as np | |
| # Add parent directory to path to import the preprocess functions | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from preprocess import process_audio_data, process_image_data | |
| # Import the model definition | |
| from train_watermelon import WatermelonModel | |
def load_model(model_path):
    """Create a WatermelonModel, restore its saved weights, and set eval mode.

    The compute device is picked in priority order: CUDA if available,
    otherwise Apple MPS if available, otherwise CPU.

    Args:
        model_path: Path to the saved state dict that is passed to
            ``torch.load``.

    Returns:
        A ``(model, device)`` tuple: the model, moved to ``device`` and in
        eval mode, together with the ``torch.device`` that was selected.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print(f"\033[92mINFO\033[0m: Using device: {device}")

    model = WatermelonModel().to(device)
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")

    return model, device
def infer_single_sample(audio_path, image_path, model, device):
    """Predict the sweetness for one audio/image pair.

    Args:
        audio_path: Path of the audio file, read with ``torchaudio.load``.
        image_path: Path of the image file, read with
            ``torchvision.io.read_image``.
        model: Callable invoked as ``model(audio_batch, image_batch)``.
        device: Device the preprocessed inputs are moved to before the call.

    Returns:
        The model output converted with ``.item()``, or ``None`` if any step
        raised; in that case the error is printed instead of propagated.
    """
    try:
        # Audio branch: decode the file, then run the shared preprocessing.
        waveform, sample_rate = torchaudio.load(audio_path)
        audio_features = process_audio_data(waveform, sample_rate).to(device)

        # Image branch: decode, cast to float, then run the shared preprocessing.
        pixels = torchvision.io.read_image(image_path).float()
        image_features = process_image_data(pixels).to(device)

        # Prepend a batch dimension of size one to each input.
        audio_batch = audio_features.unsqueeze(0)
        image_batch = image_features.unsqueeze(0)

        # Gradients are not needed for inference.
        with torch.no_grad():
            prediction = model(audio_batch, image_batch)

        # Kept inside the try so a failing .item() is also reported as None.
        return prediction.item()
    except Exception as e:
        # Best-effort boundary: callers treat None as "this sample failed".
        print(f"\033[91mERR!\033[0m: Error in inference: {e}")
        return None
def infer_from_directory(data_dir, model_path, output_file=None, num_samples=None):
    """Run inference on every sample found under a dataset directory.

    The expected layout is ``<data_dir>/<sweetness>/<id>/<id>.wav`` plus
    ``<id>.jpg`` in the same folder, where ``<sweetness>`` is a directory
    name that parses as a float and is used as the ground-truth label.

    Args:
        data_dir: Root directory of the dataset.
        model_path: Path to the trained model file, passed to ``load_model``.
        output_file: Optional path of a CSV file to write the results to.
            Nothing is written when there are no results.
        num_samples: Optional cap on the number of samples. ``None`` or a
            non-positive value means "all samples".

    Returns:
        A list with one dict per successfully processed sample, with the
        keys ``sample_id``, ``true_sweetness``, ``predicted_sweetness`` and
        ``error`` (absolute difference). Samples whose inference failed are
        omitted.
    """
    model, device = load_model(model_path)

    print(f"\033[92mINFO\033[0m: Reading samples from {data_dir}")
    samples = _collect_samples(data_dir)

    # Limit the number of samples if specified
    if num_samples is not None and num_samples > 0:
        samples = samples[:num_samples]

    total = len(samples)
    print(f"\033[92mINFO\033[0m: Running inference on {total} samples")

    results = []
    for index, (audio_file, image_file, true_sweetness, sample_id) in enumerate(samples, start=1):
        print(f"\033[92mINFO\033[0m: Processing sample {index}/{total}: {sample_id}")
        predicted_sweetness = infer_single_sample(audio_file, image_file, model, device)
        if predicted_sweetness is None:
            # infer_single_sample already printed the failure; skip the sample.
            continue

        error = abs(predicted_sweetness - true_sweetness)
        results.append({
            'sample_id': sample_id,
            'true_sweetness': true_sweetness,
            'predicted_sweetness': predicted_sweetness,
            'error': error,
        })
        print(f" Sample ID: {sample_id}")
        print(f" True sweetness: {true_sweetness:.2f}")
        print(f" Predicted sweetness: {predicted_sweetness:.2f}")
        print(f" Error: {error:.2f}")

    # Calculate mean absolute error
    if results:
        mae = np.mean([result['error'] for result in results])
        print(f"\033[92mINFO\033[0m: Mean Absolute Error: {mae:.4f}")

    # Save results to file if specified
    if output_file and results:
        _write_results_csv(output_file, results)
        print(f"\033[92mINFO\033[0m: Results saved to {output_file}")

    return results


def _collect_samples(data_dir):
    """Return ``(audio_file, image_file, sweetness, sample_id)`` tuples under *data_dir*."""
    samples = []
    # os.listdir returns entries in arbitrary order; sorting makes the
    # num_samples truncation select the same samples on every run.
    for sweetness_dir in sorted(os.listdir(data_dir)):
        try:
            sweetness = float(sweetness_dir)
        except ValueError:
            # Skip entries whose name is not a valid sweetness value.
            continue

        sweetness_path = os.path.join(data_dir, sweetness_dir)
        if not os.path.isdir(sweetness_path):
            continue

        for id_dir in sorted(os.listdir(sweetness_path)):
            id_path = os.path.join(sweetness_path, id_dir)
            if not os.path.isdir(id_path):
                continue
            audio_file = os.path.join(id_path, f"{id_dir}.wav")
            image_file = os.path.join(id_path, f"{id_dir}.jpg")
            # Only keep samples for which both modalities are present.
            if os.path.exists(audio_file) and os.path.exists(image_file):
                samples.append((audio_file, image_file, sweetness, id_dir))
    return samples


def _write_results_csv(output_file, results):
    """Write *results* to *output_file* as CSV with a header row."""
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import csv

    # newline="" is what the csv module expects; lineterminator="\n" keeps
    # the "\n" row endings of the previous hand-written output.
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["sample_id", "true_sweetness", "predicted_sweetness", "error"])
        # csv.writer quotes ids containing commas or quotes, which the
        # previous string formatting did not.
        writer.writerows(
            (
                result['sample_id'],
                f"{result['true_sweetness']:.2f}",
                f"{result['predicted_sweetness']:.2f}",
                f"{result['error']:.2f}",
            )
            for result in results
        )
def main():
    """Command-line entry point for watermelon sweetness inference.

    Two modes are supported:

    * Single sample: when both ``--audio_path`` and ``--image_path`` are
      given, the prediction for that pair is printed. The process exits
      with status 1 if inference failed.
    * Dataset: otherwise, inference is run over ``--data_dir`` via
      ``infer_from_directory``.

    Supplying only one of ``--audio_path`` / ``--image_path`` is a usage
    error (argparse exits with status 2).
    """
    parser = argparse.ArgumentParser(description="Watermelon Sweetness Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model file")
    parser.add_argument("--data_dir", type=str, default="../cleaned", help="Path to the cleaned dataset directory")
    parser.add_argument("--output_file", type=str, help="Path to save inference results (CSV)")
    parser.add_argument("--num_samples", type=int, help="Number of samples to run inference on (default: all)")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file for inference")
    parser.add_argument("--image_path", type=str, help="Path to a single image file for inference")
    args = parser.parse_args()

    # A lone --audio_path or --image_path used to fall through silently to
    # dataset inference, which is never what the caller intended.
    if bool(args.audio_path) != bool(args.image_path):
        parser.error("--audio_path and --image_path must be provided together")

    # Check if single sample inference or dataset inference
    if args.audio_path and args.image_path:
        # Single sample inference
        model, device = load_model(args.model_path)
        sweetness = infer_single_sample(args.audio_path, args.image_path, model, device)
        if sweetness is None:
            # infer_single_sample already printed the error; formatting None
            # with ':.2f' would raise TypeError, so exit with a failure code.
            sys.exit(1)
        print(f"Predicted sweetness: {sweetness:.2f}")
    else:
        # Dataset inference
        infer_from_directory(args.data_dir, args.model_path, args.output_file, args.num_samples)


if __name__ == "__main__":
    main()