| import argparse | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
| from download_ixi import DEFAULT_DESTINATION | |
def _ixi_id_from_path(path: str) -> int:
    """Return the integer subject ID encoded in a prediction's source path."""
    # Normalise Windows separators so the file name is isolated on any OS;
    # splitting only on '/' would slice the whole path for backslash paths.
    filename = os.path.basename(str(path).replace('\\', '/'))
    # Assumes file names start with a 3-character prefix followed by a
    # 3-digit ID (e.g. 'IXI002-...'), as the original [3:6] slice implied.
    return int(filename[3:6])


def evaluate_ixi_predictions(
    labels: str,
    predictions: str
) -> float:
    """Compare IXI age predictions with ground-truth ages and plot them.

    Reads the labels spreadsheet and the predictions CSV, matches rows on the
    subject ID parsed from each prediction's ``source`` path, prints the mean
    absolute error (MAE) and shows a true-vs-predicted scatter plot.

    Args:
        labels: Path to a spreadsheet (read with ``pandas.read_excel``) that
            has at least ``IXI_ID`` and ``AGE`` columns.
        predictions: Path to a CSV that has at least ``source`` (path of the
            scanned file) and ``age`` (predicted age) columns.

    Returns:
        The MAE over predictions that have both a true and a predicted age.
        Callers that ignore the return value are unaffected.

    Raises:
        ValueError: If no prediction can be matched to a ground-truth age.
    """
    labels_df = pd.read_excel(labels)
    predictions_df = pd.read_csv(predictions)

    predictions_df['IXI_ID'] = predictions_df['source'].apply(_ixi_id_from_path)
    predictions_df['age_prediction'] = predictions_df['age']
    merged = pd.merge(
        predictions_df[['IXI_ID', 'age_prediction']],
        labels_df[['IXI_ID', 'AGE']],
        on='IXI_ID',
        how='left',
    )
    # NOTE(review): if the labels sheet lists an IXI_ID more than once, the
    # left merge duplicates that prediction and it is counted twice in the
    # MAE -- confirm whether labels should be de-duplicated first.

    # Rows without a matching label (or with a missing age) cannot contribute
    # to the error; exclude them explicitly and say so instead of relying on
    # NaN-skipping inside the mean.
    matched = merged.dropna(subset=['AGE', 'age_prediction'])
    n_skipped = len(merged) - len(matched)
    if n_skipped:
        print(f'Skipping {n_skipped} prediction(s) without both a true and a predicted age')
    if matched.empty:
        raise ValueError('No predictions could be matched to a ground-truth age')

    mae = float(np.mean(np.abs(matched['AGE'] - matched['age_prediction'])))
    print(f'MAE: {mae:.2f}')

    plt.scatter(matched['AGE'], matched['age_prediction'])
    plt.xlabel('True age')
    plt.ylabel('Predicted age')
    plt.title('Age prediction')
    plt.show()
    return mae
if __name__ == '__main__':
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name printed in usage lines), so the description must be a keyword.
    parser = argparse.ArgumentParser(
        description='Evaluates predictions for the IXI dataset'
    )
    parser.add_argument(
        '-l', '--labels',
        default=os.path.join(DEFAULT_DESTINATION, 'IXI.xls'),
        help='Path to XLS/XLSX containing labels (default: %(default)s)'
    )
    parser.add_argument(
        '-p', '--predictions',
        default=os.path.join(
            DEFAULT_DESTINATION,
            'outputs',
            'predictions.csv'
        ),
        help='Path to CSV containing predictions (default: %(default)s)'
    )
    args = parser.parse_args()
    evaluate_ixi_predictions(
        labels=args.labels,
        predictions=args.predictions
    )