sam-packer commited on
Commit ·
ea06981
1
Parent(s): 56785a6
Initial commit
Browse files- .gitignore +73 -0
- README.md +586 -1
- pyproject.toml +49 -0
- src/brewmatch/__init__.py +3 -0
- src/brewmatch/api/__init__.py +26 -0
- src/brewmatch/api/app.py +343 -0
- src/brewmatch/api/schemas.py +197 -0
- src/brewmatch/config.py +73 -0
- src/brewmatch/data/__init__.py +13 -0
- src/brewmatch/data/dataset.py +204 -0
- src/brewmatch/data/download.py +88 -0
- src/brewmatch/data/preprocess.py +311 -0
- src/brewmatch/device.py +74 -0
- src/brewmatch/evaluate.py +222 -0
- src/brewmatch/evaluation/__init__.py +33 -0
- src/brewmatch/evaluation/error_analysis.py +586 -0
- src/brewmatch/evaluation/metrics.py +401 -0
- src/brewmatch/experiment.py +492 -0
- src/brewmatch/models/__init__.py +13 -0
- src/brewmatch/models/base.py +142 -0
- src/brewmatch/models/baseline.py +167 -0
- src/brewmatch/models/classical.py +212 -0
- src/brewmatch/models/neural.py +431 -0
- src/brewmatch/train.py +479 -0
- src/brewmatch/tuning.py +550 -0
- src/brewmatch/utils.py +181 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.venv/
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# uv
|
| 29 |
+
.python-version
|
| 30 |
+
|
| 31 |
+
# IDE
|
| 32 |
+
.idea/
|
| 33 |
+
.vscode/
|
| 34 |
+
*.swp
|
| 35 |
+
*.swo
|
| 36 |
+
*~
|
| 37 |
+
|
| 38 |
+
# Jupyter
|
| 39 |
+
.ipynb_checkpoints/
|
| 40 |
+
|
| 41 |
+
# Data (downloaded and processed)
|
| 42 |
+
data/raw/
|
| 43 |
+
data/processed/
|
| 44 |
+
|
| 45 |
+
# Model checkpoints
|
| 46 |
+
models/checkpoints/
|
| 47 |
+
|
| 48 |
+
# Experiment outputs
|
| 49 |
+
experiments/
|
| 50 |
+
|
| 51 |
+
# OS
|
| 52 |
+
.DS_Store
|
| 53 |
+
Thumbs.db
|
| 54 |
+
|
| 55 |
+
# Logs
|
| 56 |
+
*.log
|
| 57 |
+
logs/
|
| 58 |
+
|
| 59 |
+
# Environment variables / secrets
|
| 60 |
+
.env
|
| 61 |
+
.env.*
|
| 62 |
+
*.pem
|
| 63 |
+
kaggle.json
|
| 64 |
+
|
| 65 |
+
# Testing
|
| 66 |
+
.pytest_cache/
|
| 67 |
+
.coverage
|
| 68 |
+
htmlcov/
|
| 69 |
+
.tox/
|
| 70 |
+
.nox/
|
| 71 |
+
|
| 72 |
+
# mypy
|
| 73 |
+
.mypy_cache/
|
README.md
CHANGED
|
@@ -1 +1,586 @@
|
|
| 1 |
-
# BrewMatch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BrewMatch
|
| 2 |
+
|
| 3 |
+
A machine learning-powered coffee recommendation system that matches users with coffee beans based on their taste
|
| 4 |
+
preferences. Built for the Computer Vision module project, this system implements three distinct modeling approaches and
|
| 5 |
+
includes a production-ready Flask API.
|
| 6 |
+
|
| 7 |
+
## Table of Contents
|
| 8 |
+
|
| 9 |
+
- [Overview](#overview)
|
| 10 |
+
- [Installation](#installation)
|
| 11 |
+
- [Quick Start](#quick-start)
|
| 12 |
+
- [Project Structure](#project-structure)
|
| 13 |
+
- [Data Pipeline](#data-pipeline)
|
| 14 |
+
- [Models](#models)
|
| 15 |
+
- [Evaluation](#evaluation)
|
| 16 |
+
- [Experiment: Sensitivity Analysis](#experiment-sensitivity-analysis)
|
| 17 |
+
- [API Reference](#api-reference)
|
| 18 |
+
- [Deployment](#deployment)
|
| 19 |
+
|
| 20 |
+
## Overview
|
| 21 |
+
|
| 22 |
+
BrewMatch recommends coffee beans by learning taste profile similarities from the Coffee Quality Institute (CQI)dataset.
|
| 23 |
+
Given a user's preferred taste characteristics (aroma, flavor, acidity, body, etc.), the system finds coffees with
|
| 24 |
+
matching profiles.
|
| 25 |
+
|
| 26 |
+
### Key Features
|
| 27 |
+
|
| 28 |
+
- **Three modeling approaches**: Naive baseline, classical ML (KNN), and deep learning (neural embeddings)
|
| 29 |
+
- **Comprehensive evaluation**: Precision@K, Recall@K, NDCG@K, MSE, MAE
|
| 30 |
+
- **Error analysis**: Identifies mispredictions, patterns, and mitigation strategies
|
| 31 |
+
- **Sensitivity analysis experiment**: Measures performance vs. training set size
|
| 32 |
+
- **Production-ready API**: Flask REST API with validation and error handling
|
| 33 |
+
|
| 34 |
+
### Taste Profile Features
|
| 35 |
+
|
| 36 |
+
The system uses 9 sensory evaluation scores (0-10 scale):
|
| 37 |
+
|
| 38 |
+
| Feature | Description |
|
| 39 |
+
|------------|--------------------------------------------------------|
|
| 40 |
+
| Aroma | Scent/fragrance of the coffee |
|
| 41 |
+
| Flavor | Overall taste including sweetness, bitterness, acidity |
|
| 42 |
+
| Aftertaste | Lingering taste after swallowing |
|
| 43 |
+
| Acidity | Brightness and liveliness of taste |
|
| 44 |
+
| Body | Thickness/viscosity of the coffee |
|
| 45 |
+
| Balance | How well flavor components work together |
|
| 46 |
+
| Uniformity | Consistency from cup to cup |
|
| 47 |
+
| Clean Cup | Absence of off-flavors or defects |
|
| 48 |
+
| Sweetness | Caramel-like, fruity, or floral notes |
|
| 49 |
+
|
| 50 |
+
## Installation
|
| 51 |
+
|
| 52 |
+
### Prerequisites
|
| 53 |
+
|
| 54 |
+
- Python 3.13+
|
| 55 |
+
- [uv](https://docs.astral.sh/uv/) package manager
|
| 56 |
+
- GPU (optional): NVIDIA CUDA or Apple Silicon MPS for faster training
|
| 57 |
+
- Kaggle account (for dataset download)
|
| 58 |
+
|
| 59 |
+
### Setup
|
| 60 |
+
|
| 61 |
+
1. **Clone the repository**
|
| 62 |
+
```bash
|
| 63 |
+
git clone https://github.com/MrinalGoel643/BrewMatch.git
|
| 64 |
+
cd BrewMatch
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
2. **Install dependencies**
|
| 68 |
+
```bash
|
| 69 |
+
# CPU-only or Apple Silicon (MPS)
|
| 70 |
+
uv sync
|
| 71 |
+
|
| 72 |
+
# With NVIDIA CUDA support
|
| 73 |
+
uv sync --extra cuda
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
3. **Configure Kaggle credentials**
|
| 77 |
+
|
| 78 |
+
Create `~/.kaggle/kaggle.json` with your API credentials:
|
| 79 |
+
```json
|
| 80 |
+
{"username": "your_username", "key": "your_api_key"}
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Get your API key from [Kaggle Account Settings](https://www.kaggle.com/settings/account).
|
| 84 |
+
|
| 85 |
+
## Quick Start
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
# 1. Download the CQI coffee dataset
|
| 89 |
+
uv run download
|
| 90 |
+
|
| 91 |
+
# 2. Preprocess the data
|
| 92 |
+
uv run preprocess
|
| 93 |
+
|
| 94 |
+
# 3. Train all models (with default hyperparameters)
|
| 95 |
+
uv run train
|
| 96 |
+
|
| 97 |
+
# 4. OR: Tune hyperparameters first, then train (recommended)
|
| 98 |
+
uv run train --tune
|
| 99 |
+
|
| 100 |
+
# 5. Evaluate model performance
|
| 101 |
+
uv run evaluate --error-analysis
|
| 102 |
+
|
| 103 |
+
# 6. Start the API server
|
| 104 |
+
uv run serve
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## Project Structure
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
brewmatch/
|
| 111 |
+
├── pyproject.toml # Project config and dependencies
|
| 112 |
+
├── README.md # This file
|
| 113 |
+
├── data/
|
| 114 |
+
│ ├── raw/ # Downloaded CSV files
|
| 115 |
+
│ └── processed/ # Train/test parquet + scaler
|
| 116 |
+
├── models/
|
| 117 |
+
│ └── checkpoints/ # Saved model files
|
| 118 |
+
├── experiments/ # Experiment results and plots
|
| 119 |
+
└── src/brewmatch/
|
| 120 |
+
├── __init__.py
|
| 121 |
+
├── config.py # Configuration settings
|
| 122 |
+
├── device.py # Device detection (CUDA/MPS/CPU)
|
| 123 |
+
├── utils.py # Utility functions
|
| 124 |
+
├── train.py # Training script (includes Optuna tuning)
|
| 125 |
+
├── evaluate.py # Evaluation script
|
| 126 |
+
├── experiment.py # Sensitivity analysis
|
| 127 |
+
├── data/
|
| 128 |
+
│ ├── __init__.py
|
| 129 |
+
│ ├── download.py # Kaggle dataset downloader
|
| 130 |
+
│ ├── preprocess.py # Data cleaning and splitting
|
| 131 |
+
│ └── dataset.py # PyTorch Dataset classes
|
| 132 |
+
├── models/
|
| 133 |
+
│ ├── __init__.py
|
| 134 |
+
│ ├── base.py # Abstract base class
|
| 135 |
+
│ ├── baseline.py # Naive baseline recommender
|
| 136 |
+
│ ├── classical.py # KNN/cosine similarity
|
| 137 |
+
│ └── neural.py # Neural embedding model
|
| 138 |
+
├── evaluation/
|
| 139 |
+
│ ├── __init__.py
|
| 140 |
+
│ ├── metrics.py # Ranking and regression metrics
|
| 141 |
+
│ └── error_analysis.py # Error pattern detection
|
| 142 |
+
└── api/
|
| 143 |
+
├── __init__.py
|
| 144 |
+
├── app.py # Flask application
|
| 145 |
+
└── schemas.py # Request validation
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Data Pipeline
|
| 149 |
+
|
| 150 |
+
### Download Dataset
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
uv run download [--force]
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Downloads the [CQI Coffee Quality Database](https://www.kaggle.com/datasets/volpatto/coffee-quality-database-from-cqi) from Kaggle
|
| 157 |
+
to `data/raw/`. This dataset contains ~1,340 coffee samples (Arabica + Robusta) with sensory evaluations.
|
| 158 |
+
|
| 159 |
+
| Option | Description |
|
| 160 |
+
|-----------|---------------------------------|
|
| 161 |
+
| `--force` | Re-download even if data exists |
|
| 162 |
+
|
| 163 |
+
### Preprocess Data
|
| 164 |
+
|
| 165 |
+
```bash
|
| 166 |
+
uv run preprocess [--test-size 0.2] [--seed 42]
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
Processes raw data and creates train/test splits:
|
| 170 |
+
|
| 171 |
+
1. Loads CSV files from `data/raw/`
|
| 172 |
+
2. Selects taste features and metadata columns
|
| 173 |
+
3. Drops rows with missing quality scores
|
| 174 |
+
4. Normalizes features using StandardScaler (fit on train only)
|
| 175 |
+
5. Splits data 80/20 train/test
|
| 176 |
+
6. Saves to `data/processed/`:
|
| 177 |
+
- `train.parquet` - Training data
|
| 178 |
+
- `test.parquet` - Test data
|
| 179 |
+
- `scaler.pkl` - Fitted scaler for inference
|
| 180 |
+
|
| 181 |
+
| Option | Description |
|
| 182 |
+
|---------------|-----------------------------------------------|
|
| 183 |
+
| `--test-size` | Fraction for test set (default: 0.2) |
|
| 184 |
+
| `--seed` | Random seed for reproducibility (default: 42) |
|
| 185 |
+
|
| 186 |
+
## Models
|
| 187 |
+
|
| 188 |
+
### 1. Naive Baseline (`NaiveBaselineRecommender`)
|
| 189 |
+
|
| 190 |
+
Establishes a performance floor using simple heuristics.
|
| 191 |
+
|
| 192 |
+
**Strategies:**
|
| 193 |
+
|
| 194 |
+
- `mean`: Recommends coffees closest to the global mean taste profile (ignores user preferences)
|
| 195 |
+
- `weighted_random`: Random sampling weighted by Total Cup Points
|
| 196 |
+
|
| 197 |
+
**When to use:** Sanity check; any useful model should beat this.
|
| 198 |
+
|
| 199 |
+
### 2. Classical ML (`ClassicalMLRecommender`)
|
| 200 |
+
|
| 201 |
+
Uses traditional similarity-based methods.
|
| 202 |
+
|
| 203 |
+
**Methods:**
|
| 204 |
+
|
| 205 |
+
- `knn`: K-Nearest Neighbors with Euclidean distance (sklearn NearestNeighbors)
|
| 206 |
+
- `cosine`: Cosine similarity ranking
|
| 207 |
+
|
| 208 |
+
**Features:**
|
| 209 |
+
|
| 210 |
+
- Optional feature normalization via StandardScaler
|
| 211 |
+
- Configurable number of neighbors
|
| 212 |
+
|
| 213 |
+
**When to use:** Fast inference, interpretable results, works well with small datasets.
|
| 214 |
+
|
| 215 |
+
### 3. Neural Network (`NeuralRecommender`)
|
| 216 |
+
|
| 217 |
+
Learns taste embeddings via contrastive learning.
|
| 218 |
+
|
| 219 |
+
**Architecture:**
|
| 220 |
+
|
| 221 |
+
- MLP encoder with residual connections
|
| 222 |
+
- Maps 9 taste features to 32-dimensional embedding space
|
| 223 |
+
- L2-normalized embeddings for cosine similarity
|
| 224 |
+
|
| 225 |
+
**Training:**
|
| 226 |
+
|
| 227 |
+
- Triplet loss with margin
|
| 228 |
+
- AdamW optimizer with cosine annealing
|
| 229 |
+
- Automatic positive/negative mining based on taste distance
|
| 230 |
+
|
| 231 |
+
**When to use:** Best performance with sufficient data; captures non-linear relationships.
|
| 232 |
+
|
| 233 |
+
### Training Models
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
uv run train [--models baseline classical neural] [--device cuda]
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
| Option | Description |
|
| 240 |
+
|------------|--------------------------------------------------------------|
|
| 241 |
+
| `--models` | Models to train: `baseline`, `classical`, `neural`, or `all` |
|
| 242 |
+
| `--device` | PyTorch device: `cuda` or `cpu` (auto-detected) |
|
| 243 |
+
|
| 244 |
+
Models are saved to `models/checkpoints/`:
|
| 245 |
+
|
| 246 |
+
- `baseline.pkl` - Pickled baseline model
|
| 247 |
+
- `classical.pkl` - Pickled KNN model
|
| 248 |
+
- `neural.pt` - PyTorch neural model
|
| 249 |
+
|
| 250 |
+
## Hyperparameter Tuning
|
| 251 |
+
|
| 252 |
+
BrewMatch includes automated hyperparameter optimization using [Optuna](https://optuna.org/), a Bayesian optimization framework with tree-structured Parzen estimators (TPE). Tuning is integrated into the training script.
|
| 253 |
+
|
| 254 |
+
### Training Workflow
|
| 255 |
+
|
| 256 |
+
```bash
|
| 257 |
+
# First run: uses default hyperparameters
|
| 258 |
+
uv run train
|
| 259 |
+
|
| 260 |
+
# Run with Optuna tuning (saves best params for future runs)
|
| 261 |
+
uv run train --tune
|
| 262 |
+
|
| 263 |
+
# Subsequent runs: automatically uses previously tuned hyperparameters
|
| 264 |
+
uv run train
|
| 265 |
+
|
| 266 |
+
# Re-tune anytime with --tune flag
|
| 267 |
+
uv run train --tune --neural-trials 100
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
| Option | Description |
|
| 271 |
+
|---------------------|------------------------------------------------------------|
|
| 272 |
+
| `--tune` | Run Optuna tuning before training |
|
| 273 |
+
| `--models` | Models to train/tune: `baseline`, `classical`, `neural`, `all` |
|
| 274 |
+
| `--neural-trials` | Number of Optuna trials for neural network (default: 50) |
|
| 275 |
+
| `--classical-trials`| Number of Optuna trials for classical ML (default: 30) |
|
| 276 |
+
| `--cv-folds` | Cross-validation folds for tuning (default: 3) |
|
| 277 |
+
| `--device` | PyTorch device: `cuda`, `mps`, or `cpu` (auto-detected) |
|
| 278 |
+
|
| 279 |
+
### Tuned Hyperparameters
|
| 280 |
+
|
| 281 |
+
**Neural Network:**
|
| 282 |
+
- `embedding_dim`: Embedding space dimension (16-128)
|
| 283 |
+
- `hidden_dim`: Hidden layer size (32-256)
|
| 284 |
+
- `learning_rate`: Adam learning rate (1e-4 to 1e-2, log scale)
|
| 285 |
+
- `margin`: Triplet loss margin (0.1-1.0)
|
| 286 |
+
- `batch_size`: Training batch size (16, 32, 64, 128)
|
| 287 |
+
|
| 288 |
+
**Classical ML:**
|
| 289 |
+
- `method`: Similarity method (`knn` or `cosine`)
|
| 290 |
+
- `n_neighbors`: Number of neighbors for KNN (5-100)
|
| 291 |
+
- `normalize`: Feature normalization (True/False)
|
| 292 |
+
|
| 293 |
+
### Outputs
|
| 294 |
+
|
| 295 |
+
Tuned hyperparameters are saved to `models/checkpoints/hyperparameters.json` and automatically loaded on subsequent training runs
|
| 296 |
+
|
| 297 |
+
## Evaluation
|
| 298 |
+
|
| 299 |
+
### Metrics
|
| 300 |
+
|
| 301 |
+
| Metric | Description |
|
| 302 |
+
|-----------------|----------------------------------------------------------------------|
|
| 303 |
+
| **Precision@K** | Proportion of top-K recommendations that are relevant |
|
| 304 |
+
| **Recall@K** | Proportion of relevant items found in top-K |
|
| 305 |
+
| **NDCG@K** | Normalized Discounted Cumulative Gain (rewards early relevant items) |
|
| 306 |
+
| **MSE** | Mean Squared Error of taste profile predictions |
|
| 307 |
+
| **MAE** | Mean Absolute Error of taste profile predictions |
|
| 308 |
+
|
| 309 |
+
**Relevance definition:** A coffee is relevant if it shares the same country AND processing method as the query, OR has
|
| 310 |
+
cosine similarity >= 0.95.
|
| 311 |
+
|
| 312 |
+
### Running Evaluation
|
| 313 |
+
|
| 314 |
+
```bash
|
| 315 |
+
uv run evaluate [--models all] [--error-analysis] [--output results.json]
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
| Option | Description |
|
| 319 |
+
|--------------------|-----------------------------------------------------------------|
|
| 320 |
+
| `--models` | Models to evaluate: `baseline`, `classical`, `neural`, or `all` |
|
| 321 |
+
| `--error-analysis` | Generate detailed error analysis |
|
| 322 |
+
| `--output` | Save results to JSON file |
|
| 323 |
+
|
| 324 |
+
### Error Analysis
|
| 325 |
+
|
| 326 |
+
The error analysis module identifies:
|
| 327 |
+
|
| 328 |
+
1. **5 Worst Mispredictions** with root cause analysis:
|
| 329 |
+
- Origin mismatch
|
| 330 |
+
- Processing method mismatch
|
| 331 |
+
- Large taste profile deviations
|
| 332 |
+
|
| 333 |
+
2. **Common Error Patterns**:
|
| 334 |
+
- Failures by country of origin
|
| 335 |
+
- Failures by processing method
|
| 336 |
+
- Cross-origin confusion (e.g., confusing Ethiopia with Kenya)
|
| 337 |
+
- Taste profile edge cases (high acidity, low body)
|
| 338 |
+
|
| 339 |
+
3. **Mitigation Strategies**:
|
| 340 |
+
- Origin-aware embeddings
|
| 341 |
+
- Processing method features
|
| 342 |
+
- Contrastive learning for confused origins
|
| 343 |
+
- Re-ranking stages
|
| 344 |
+
|
| 345 |
+
## Experiment: Sensitivity Analysis
|
| 346 |
+
|
| 347 |
+
Investigates how model performance varies with training set size.
|
| 348 |
+
|
| 349 |
+
### Hypothesis
|
| 350 |
+
|
| 351 |
+
Deep learning models benefit more from additional data, while classical models plateau earlier.
|
| 352 |
+
|
| 353 |
+
### Running the Experiment
|
| 354 |
+
|
| 355 |
+
```bash
|
| 356 |
+
uv run experiment [--fractions 0.1 0.2 ... 1.0] [--trials 3] [--device cuda]
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
| Option | Description |
|
| 360 |
+
|----------------|----------------------------------------------------------|
|
| 361 |
+
| `--fractions` | Training set fractions to test (default: 0.1 to 1.0) |
|
| 362 |
+
| `--trials` | Trials per fraction for variance estimation (default: 3) |
|
| 363 |
+
| `--device` | PyTorch device |
|
| 364 |
+
| `--output-dir` | Directory for results (default: `experiments/`) |
|
| 365 |
+
|
| 366 |
+
### Outputs
|
| 367 |
+
|
| 368 |
+
- `raw_results.json` - Per-trial metrics
|
| 369 |
+
- `aggregated_results.csv` - Mean and std per model/fraction
|
| 370 |
+
- `sensitivity_analysis.png` - Performance vs. training size plot
|
| 371 |
+
- `sensitivity_analysis_multi.png` - Multi-metric comparison
|
| 372 |
+
- `experiment_report.txt` - Text summary with findings
|
| 373 |
+
|
| 374 |
+
## API Reference
|
| 375 |
+
|
| 376 |
+
### Starting the Server
|
| 377 |
+
|
| 378 |
+
```bash
|
| 379 |
+
uv run serve
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
Or with environment variables:
|
| 383 |
+
|
| 384 |
+
```bash
|
| 385 |
+
FLASK_HOST=0.0.0.0 FLASK_PORT=8000 FLASK_DEBUG=true uv run serve
|
| 386 |
+
```
|
| 387 |
+
|
| 388 |
+
### Endpoints
|
| 389 |
+
|
| 390 |
+
#### Health Check
|
| 391 |
+
|
| 392 |
+
```http
|
| 393 |
+
GET /health
|
| 394 |
+
```
|
| 395 |
+
|
| 396 |
+
**Response:**
|
| 397 |
+
|
| 398 |
+
```json
|
| 399 |
+
{
|
| 400 |
+
"status": "healthy",
|
| 401 |
+
"models_loaded": 3,
|
| 402 |
+
"available_models": [
|
| 403 |
+
"baseline",
|
| 404 |
+
"classical",
|
| 405 |
+
"neural"
|
| 406 |
+
]
|
| 407 |
+
}
|
| 408 |
+
```
|
| 409 |
+
|
| 410 |
+
#### List Models
|
| 411 |
+
|
| 412 |
+
```http
|
| 413 |
+
GET /api/models
|
| 414 |
+
```
|
| 415 |
+
|
| 416 |
+
**Response:**
|
| 417 |
+
|
| 418 |
+
```json
|
| 419 |
+
{
|
| 420 |
+
"models": [
|
| 421 |
+
{
|
| 422 |
+
"name": "baseline",
|
| 423 |
+
"available": true,
|
| 424 |
+
"is_fitted": true
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"name": "classical",
|
| 428 |
+
"available": true,
|
| 429 |
+
"is_fitted": true
|
| 430 |
+
},
|
| 431 |
+
{
|
| 432 |
+
"name": "neural",
|
| 433 |
+
"available": true,
|
| 434 |
+
"is_fitted": true
|
| 435 |
+
}
|
| 436 |
+
]
|
| 437 |
+
}
|
| 438 |
+
```
|
| 439 |
+
|
| 440 |
+
#### Get Recommendations
|
| 441 |
+
|
| 442 |
+
```http
|
| 443 |
+
POST /api/recommend
|
| 444 |
+
Content-Type: application/json
|
| 445 |
+
|
| 446 |
+
{
|
| 447 |
+
"preferences": {
|
| 448 |
+
"aroma": 8.0,
|
| 449 |
+
"flavor": 7.5,
|
| 450 |
+
"aftertaste": 7.0,
|
| 451 |
+
"acidity": 7.5,
|
| 452 |
+
"body": 8.0,
|
| 453 |
+
"balance": 7.5,
|
| 454 |
+
"uniformity": 10.0,
|
| 455 |
+
"clean_cup": 10.0,
|
| 456 |
+
"sweetness": 10.0
|
| 457 |
+
},
|
| 458 |
+
"model": "neural",
|
| 459 |
+
"k": 5
|
| 460 |
+
}
|
| 461 |
+
```
|
| 462 |
+
|
| 463 |
+
**Response:**
|
| 464 |
+
|
| 465 |
+
```json
|
| 466 |
+
{
|
| 467 |
+
"recommendations": [
|
| 468 |
+
{
|
| 469 |
+
"id": 42,
|
| 470 |
+
"similarity": 0.95,
|
| 471 |
+
"scores": {
|
| 472 |
+
"aroma": 7.92,
|
| 473 |
+
"flavor": 7.58
|
| 474 |
+
},
|
| 475 |
+
"country": "Ethiopia",
|
| 476 |
+
"metadata": {}
|
| 477 |
+
}
|
| 478 |
+
],
|
| 479 |
+
"model_used": "neural",
|
| 480 |
+
"k": 5
|
| 481 |
+
}
|
| 482 |
+
```
|
| 483 |
+
|
| 484 |
+
| Field | Type | Description |
|
| 485 |
+
|---------------|---------|--------------------------------------------------------------------|
|
| 486 |
+
| `preferences` | object | Required. All 9 taste features (0-10 scale) |
|
| 487 |
+
| `model` | string | Optional. `baseline`, `classical`, or `neural` (default: `neural`) |
|
| 488 |
+
| `k` | integer | Optional. Number of recommendations (1-100, default: 5) |
|
| 489 |
+
|
| 490 |
+
#### Get Coffee Details
|
| 491 |
+
|
| 492 |
+
```http
|
| 493 |
+
GET /api/coffee/{id}
|
| 494 |
+
```
|
| 495 |
+
|
| 496 |
+
**Response:**
|
| 497 |
+
|
| 498 |
+
```json
|
| 499 |
+
{
|
| 500 |
+
"id": 42,
|
| 501 |
+
"metadata": {
|
| 502 |
+
"Country.of.Origin": "Ethiopia",
|
| 503 |
+
"Processing.Method": "Washed / Wet"
|
| 504 |
+
},
|
| 505 |
+
"taste_profile": {
|
| 506 |
+
"aroma": 7.92
|
| 507 |
+
}
|
| 508 |
+
}
|
| 509 |
+
```
|
| 510 |
+
|
| 511 |
+
#### Get Statistics
|
| 512 |
+
|
| 513 |
+
```http
|
| 514 |
+
GET /api/stats
|
| 515 |
+
```
|
| 516 |
+
|
| 517 |
+
**Response:**
|
| 518 |
+
|
| 519 |
+
```json
|
| 520 |
+
{
|
| 521 |
+
"total_coffees": 1200,
|
| 522 |
+
"models": {
|
| 523 |
+
"baseline": {
|
| 524 |
+
"is_fitted": true,
|
| 525 |
+
"training_samples": 960
|
| 526 |
+
},
|
| 527 |
+
"classical": {
|
| 528 |
+
"is_fitted": true,
|
| 529 |
+
"training_samples": 960
|
| 530 |
+
},
|
| 531 |
+
"neural": {
|
| 532 |
+
"is_fitted": true,
|
| 533 |
+
"training_samples": 960
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
```
|
| 538 |
+
|
| 539 |
+
### Error Responses
|
| 540 |
+
|
| 541 |
+
| Status | Description |
|
| 542 |
+
|--------|-------------------------------------------|
|
| 543 |
+
| 400 | Validation error (missing/invalid fields) |
|
| 544 |
+
| 404 | Resource not found |
|
| 545 |
+
| 503 | No models loaded |
|
| 546 |
+
| 500 | Internal server error |
|
| 547 |
+
|
| 548 |
+
## Deployment
|
| 549 |
+
|
| 550 |
+
### Production with Gunicorn
|
| 551 |
+
|
| 552 |
+
```bash
|
| 553 |
+
uv run gunicorn "brewmatch.api.app:create_app()" \
|
| 554 |
+
--bind 0.0.0.0:8000 \
|
| 555 |
+
--workers 4 \
|
| 556 |
+
--timeout 120
|
| 557 |
+
```
|
| 558 |
+
|
| 559 |
+
### Docker
|
| 560 |
+
|
| 561 |
+
```dockerfile
|
| 562 |
+
FROM python:3.13-slim
|
| 563 |
+
|
| 564 |
+
WORKDIR /app
|
| 565 |
+
COPY . .
|
| 566 |
+
|
| 567 |
+
RUN pip install uv && uv sync --frozen
|
| 568 |
+
|
| 569 |
+
# Download and preprocess data, train models
|
| 570 |
+
RUN uv run download && uv run preprocess && uv run train
|
| 571 |
+
|
| 572 |
+
EXPOSE 8000
|
| 573 |
+
CMD ["uv", "run", "gunicorn", "brewmatch.api.app:create_app()", "--bind", "0.0.0.0:8000"]
|
| 574 |
+
```
|
| 575 |
+
|
| 576 |
+
### Environment Variables
|
| 577 |
+
|
| 578 |
+
| Variable | Description | Default |
|
| 579 |
+
|---------------|---------------------|-------------|
|
| 580 |
+
| `FLASK_HOST` | Server bind address | `127.0.0.1` |
|
| 581 |
+
| `FLASK_PORT` | Server port | `5000` |
|
| 582 |
+
| `FLASK_DEBUG` | Enable debug mode | `false` |
|
| 583 |
+
|
| 584 |
+
---
|
| 585 |
+
|
| 586 |
+
**Dataset:** [Coffee Quality Database (CQI)](https://www.kaggle.com/datasets/volpatto/coffee-quality-database-from-cqi) by Diego Volpatto
|
pyproject.toml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "brewmatch"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Coffee recommendation system using ML - recommends coffee beans based on taste preferences"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.13"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"flask>=3.1.3",
|
| 9 |
+
"flask-cors>=6.0.2",
|
| 10 |
+
"gunicorn>=25.3.0",
|
| 11 |
+
"kagglehub>=1.0.0",
|
| 12 |
+
"matplotlib>=3.10.8",
|
| 13 |
+
"numpy>=2.4.3",
|
| 14 |
+
"optuna>=4.8.0",
|
| 15 |
+
"pandas>=3.0.2",
|
| 16 |
+
"pyarrow>=23.0.1",
|
| 17 |
+
"scikit-learn>=1.8.0",
|
| 18 |
+
"seaborn>=0.13.2",
|
| 19 |
+
"tabulate>=0.10.0",
|
| 20 |
+
"torch>=2.11.0",
|
| 21 |
+
"tqdm>=4.66.5",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.optional-dependencies]
|
| 25 |
+
cuda = ["torch>=2.11.0"]
|
| 26 |
+
|
| 27 |
+
[project.scripts]
|
| 28 |
+
download = "brewmatch.data.download:main"
|
| 29 |
+
preprocess = "brewmatch.data.preprocess:main"
|
| 30 |
+
train = "brewmatch.train:main"
|
| 31 |
+
evaluate = "brewmatch.evaluate:main"
|
| 32 |
+
experiment = "brewmatch.experiment:main"
|
| 33 |
+
serve = "brewmatch.api.app:main"
|
| 34 |
+
|
| 35 |
+
[build-system]
|
| 36 |
+
requires = ["hatchling"]
|
| 37 |
+
build-backend = "hatchling.build"
|
| 38 |
+
|
| 39 |
+
[tool.hatch.build.targets.wheel]
|
| 40 |
+
packages = ["src/brewmatch"]
|
| 41 |
+
|
| 42 |
+
[[tool.uv.index]]
|
| 43 |
+
name = "pytorch-cu130"
|
| 44 |
+
url = "https://download.pytorch.org/whl/cu130"
|
| 45 |
+
|
| 46 |
+
[tool.uv.sources]
|
| 47 |
+
torch = [
|
| 48 |
+
{ index = "pytorch-cu130", extra = "cuda" },
|
| 49 |
+
]
|
src/brewmatch/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BrewMatch - Coffee Recommendation System using Machine Learning."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
src/brewmatch/api/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flask API for coffee recommendations."""
|
| 2 |
+
|
| 3 |
+
from .app import create_app, main
|
| 4 |
+
from .schemas import (
|
| 5 |
+
TASTE_FEATURES,
|
| 6 |
+
VALID_MODELS,
|
| 7 |
+
ValidationError,
|
| 8 |
+
validate_preferences,
|
| 9 |
+
validate_model_name,
|
| 10 |
+
validate_k,
|
| 11 |
+
validate_coffee_id,
|
| 12 |
+
validate_recommend_request,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"create_app",
|
| 17 |
+
"main",
|
| 18 |
+
"TASTE_FEATURES",
|
| 19 |
+
"VALID_MODELS",
|
| 20 |
+
"ValidationError",
|
| 21 |
+
"validate_preferences",
|
| 22 |
+
"validate_model_name",
|
| 23 |
+
"validate_k",
|
| 24 |
+
"validate_coffee_id",
|
| 25 |
+
"validate_recommend_request",
|
| 26 |
+
]
|
src/brewmatch/api/app.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Flask API for BrewMatch coffee recommendations."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from flask import Flask, jsonify, request
|
| 10 |
+
from flask_cors import CORS
|
| 11 |
+
|
| 12 |
+
from brewmatch.models import (
|
| 13 |
+
ClassicalMLRecommender,
|
| 14 |
+
NaiveBaselineRecommender,
|
| 15 |
+
NeuralRecommender,
|
| 16 |
+
)
|
| 17 |
+
from brewmatch.models.base import BaseRecommender
|
| 18 |
+
|
| 19 |
+
from .schemas import (
|
| 20 |
+
TASTE_FEATURES,
|
| 21 |
+
VALID_MODELS,
|
| 22 |
+
ValidationError,
|
| 23 |
+
validate_coffee_id,
|
| 24 |
+
validate_recommend_request,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Configure logging
|
| 28 |
+
logging.basicConfig(
|
| 29 |
+
level=logging.INFO,
|
| 30 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 31 |
+
)
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# Model type mapping
|
| 35 |
+
MODEL_CLASSES: dict[str, type[BaseRecommender]] = {
|
| 36 |
+
"baseline": NaiveBaselineRecommender,
|
| 37 |
+
"classical": ClassicalMLRecommender,
|
| 38 |
+
"neural": NeuralRecommender,
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# File extensions for each model type
|
| 42 |
+
MODEL_EXTENSIONS: dict[str, str] = {
|
| 43 |
+
"baseline": ".pkl",
|
| 44 |
+
"classical": ".pkl",
|
| 45 |
+
"neural": ".pt",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_models(checkpoint_dir: Path) -> dict[str, BaseRecommender]:
|
| 50 |
+
"""Load all available models from the checkpoint directory.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
checkpoint_dir: Path to the directory containing model checkpoints.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Dictionary mapping model names to loaded model instances.
|
| 57 |
+
"""
|
| 58 |
+
models: dict[str, BaseRecommender] = {}
|
| 59 |
+
|
| 60 |
+
if not checkpoint_dir.exists():
|
| 61 |
+
logger.warning(f"Checkpoint directory does not exist: {checkpoint_dir}")
|
| 62 |
+
return models
|
| 63 |
+
|
| 64 |
+
for model_name, model_class in MODEL_CLASSES.items():
|
| 65 |
+
extension = MODEL_EXTENSIONS[model_name]
|
| 66 |
+
model_path = checkpoint_dir / f"{model_name}{extension}"
|
| 67 |
+
|
| 68 |
+
if model_path.exists():
|
| 69 |
+
try:
|
| 70 |
+
logger.info(f"Loading {model_name} model from {model_path}")
|
| 71 |
+
models[model_name] = model_class.load(model_path)
|
| 72 |
+
logger.info(f"Successfully loaded {model_name} model")
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.error(f"Failed to load {model_name} model: {e}")
|
| 75 |
+
else:
|
| 76 |
+
logger.info(f"No checkpoint found for {model_name} at {model_path}")
|
| 77 |
+
|
| 78 |
+
return models
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def create_app(config: dict[str, Any] | None = None) -> Flask:
|
| 82 |
+
"""Create and configure the Flask application.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
config: Optional configuration dictionary. Supported keys:
|
| 86 |
+
- CHECKPOINT_DIR: Path to model checkpoints directory.
|
| 87 |
+
- TESTING: Enable testing mode.
|
| 88 |
+
- DEBUG: Enable debug mode.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Configured Flask application instance.
|
| 92 |
+
"""
|
| 93 |
+
app = Flask(__name__)
|
| 94 |
+
|
| 95 |
+
# Apply configuration
|
| 96 |
+
if config:
|
| 97 |
+
app.config.update(config)
|
| 98 |
+
|
| 99 |
+
# Enable CORS for all routes
|
| 100 |
+
CORS(app)
|
| 101 |
+
|
| 102 |
+
# Determine checkpoint directory
|
| 103 |
+
checkpoint_dir = app.config.get("CHECKPOINT_DIR")
|
| 104 |
+
if checkpoint_dir:
|
| 105 |
+
checkpoint_dir = Path(checkpoint_dir)
|
| 106 |
+
else:
|
| 107 |
+
# Default to models/checkpoints relative to project root
|
| 108 |
+
# Path: app.py -> api -> brewmatch -> src -> project_root
|
| 109 |
+
checkpoint_dir = Path(__file__).parent.parent.parent.parent / "models" / "checkpoints"
|
| 110 |
+
|
| 111 |
+
# Load models on startup
|
| 112 |
+
app.models: dict[str, BaseRecommender] = load_models(checkpoint_dir)
|
| 113 |
+
|
| 114 |
+
# Store coffee data reference (populated when first model is loaded)
|
| 115 |
+
app.coffee_data: dict[int, dict[str, Any]] = {}
|
| 116 |
+
|
| 117 |
+
# Build coffee data index from loaded models
|
| 118 |
+
if app.models:
|
| 119 |
+
first_model = next(iter(app.models.values()))
|
| 120 |
+
if hasattr(first_model, "_metadata") and first_model._metadata is not None:
|
| 121 |
+
for idx in range(len(first_model._metadata)):
|
| 122 |
+
row = first_model._metadata.iloc[idx]
|
| 123 |
+
app.coffee_data[idx] = {
|
| 124 |
+
"id": idx,
|
| 125 |
+
"metadata": row.to_dict(),
|
| 126 |
+
}
|
| 127 |
+
# Add taste profile if available
|
| 128 |
+
if hasattr(first_model, "_X") and first_model._X is not None:
|
| 129 |
+
taste_profile = first_model._X[idx]
|
| 130 |
+
app.coffee_data[idx]["taste_profile"] = {
|
| 131 |
+
feature.lower().replace(" ", "_"): float(taste_profile[i])
|
| 132 |
+
for i, feature in enumerate(BaseRecommender.TASTE_FEATURES)
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
@app.errorhandler(ValidationError)
|
| 136 |
+
def handle_validation_error(error: ValidationError):
|
| 137 |
+
"""Handle validation errors with proper JSON response."""
|
| 138 |
+
response = {"error": error.message}
|
| 139 |
+
if error.field:
|
| 140 |
+
response["field"] = error.field
|
| 141 |
+
return jsonify(response), 400
|
| 142 |
+
|
| 143 |
+
@app.errorhandler(404)
|
| 144 |
+
def handle_not_found(error):
|
| 145 |
+
"""Handle 404 errors."""
|
| 146 |
+
return jsonify({"error": "Resource not found"}), 404
|
| 147 |
+
|
| 148 |
+
@app.errorhandler(500)
|
| 149 |
+
def handle_internal_error(error):
|
| 150 |
+
"""Handle internal server errors."""
|
| 151 |
+
logger.exception("Internal server error")
|
| 152 |
+
return jsonify({"error": "Internal server error"}), 500
|
| 153 |
+
|
| 154 |
+
@app.route("/health", methods=["GET"])
|
| 155 |
+
def health_check():
|
| 156 |
+
"""Health check endpoint.
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
JSON response with status and loaded models count.
|
| 160 |
+
"""
|
| 161 |
+
return jsonify({
|
| 162 |
+
"status": "healthy",
|
| 163 |
+
"models_loaded": len(app.models),
|
| 164 |
+
"available_models": list(app.models.keys()),
|
| 165 |
+
})
|
| 166 |
+
|
| 167 |
+
@app.route("/api/models", methods=["GET"])
|
| 168 |
+
def list_models():
|
| 169 |
+
"""List available recommendation models.
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
JSON response with list of available models and their status.
|
| 173 |
+
"""
|
| 174 |
+
models_info = []
|
| 175 |
+
for model_name in VALID_MODELS:
|
| 176 |
+
model_info = {
|
| 177 |
+
"name": model_name,
|
| 178 |
+
"available": model_name in app.models,
|
| 179 |
+
}
|
| 180 |
+
if model_name in app.models:
|
| 181 |
+
model = app.models[model_name]
|
| 182 |
+
model_info["is_fitted"] = model.is_fitted
|
| 183 |
+
models_info.append(model_info)
|
| 184 |
+
|
| 185 |
+
return jsonify({"models": models_info})
|
| 186 |
+
|
| 187 |
+
@app.route("/api/recommend", methods=["POST"])
|
| 188 |
+
def get_recommendations():
|
| 189 |
+
"""Get coffee recommendations based on taste preferences.
|
| 190 |
+
|
| 191 |
+
Request body:
|
| 192 |
+
{
|
| 193 |
+
"preferences": {
|
| 194 |
+
"aroma": 8.0,
|
| 195 |
+
"flavor": 7.5,
|
| 196 |
+
"aftertaste": 7.0,
|
| 197 |
+
"acidity": 7.5,
|
| 198 |
+
"body": 8.0,
|
| 199 |
+
"balance": 7.5,
|
| 200 |
+
"uniformity": 10.0,
|
| 201 |
+
"clean_cup": 10.0,
|
| 202 |
+
"sweetness": 10.0
|
| 203 |
+
},
|
| 204 |
+
"model": "neural",
|
| 205 |
+
"k": 5
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
JSON response with list of recommended coffees.
|
| 210 |
+
"""
|
| 211 |
+
data = request.get_json(silent=True)
|
| 212 |
+
validated = validate_recommend_request(data)
|
| 213 |
+
|
| 214 |
+
model_name = validated["model"]
|
| 215 |
+
preferences = validated["preferences"]
|
| 216 |
+
k = validated["k"]
|
| 217 |
+
|
| 218 |
+
# Check if requested model is available
|
| 219 |
+
if model_name not in app.models:
|
| 220 |
+
available = list(app.models.keys())
|
| 221 |
+
if not available:
|
| 222 |
+
return jsonify({
|
| 223 |
+
"error": "No models are currently loaded",
|
| 224 |
+
}), 503
|
| 225 |
+
return jsonify({
|
| 226 |
+
"error": f"Model '{model_name}' is not available",
|
| 227 |
+
"available_models": available,
|
| 228 |
+
}), 400
|
| 229 |
+
|
| 230 |
+
model = app.models[model_name]
|
| 231 |
+
|
| 232 |
+
# Convert preferences dict to numpy array in correct order
|
| 233 |
+
# Map API field names to model feature names
|
| 234 |
+
feature_mapping = {
|
| 235 |
+
"aroma": "Aroma",
|
| 236 |
+
"flavor": "Flavor",
|
| 237 |
+
"aftertaste": "Aftertaste",
|
| 238 |
+
"acidity": "Acidity",
|
| 239 |
+
"body": "Body",
|
| 240 |
+
"balance": "Balance",
|
| 241 |
+
"uniformity": "Uniformity",
|
| 242 |
+
"clean_cup": "Clean Cup",
|
| 243 |
+
"sweetness": "Sweetness",
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
preferences_array = np.array([
|
| 247 |
+
preferences[feature.lower().replace(" ", "_")]
|
| 248 |
+
for feature in BaseRecommender.TASTE_FEATURES
|
| 249 |
+
], dtype=np.float32)
|
| 250 |
+
|
| 251 |
+
try:
|
| 252 |
+
recommendations = model.recommend(preferences_array, k=k)
|
| 253 |
+
except Exception as e:
|
| 254 |
+
logger.exception("Error generating recommendations")
|
| 255 |
+
return jsonify({"error": f"Failed to generate recommendations: {str(e)}"}), 500
|
| 256 |
+
|
| 257 |
+
# Format response
|
| 258 |
+
formatted_recommendations = []
|
| 259 |
+
for rec in recommendations:
|
| 260 |
+
formatted_rec = {
|
| 261 |
+
"id": rec["index"],
|
| 262 |
+
"similarity": rec["score"],
|
| 263 |
+
"scores": {
|
| 264 |
+
key.lower().replace(" ", "_"): value
|
| 265 |
+
for key, value in rec["taste_profile"].items()
|
| 266 |
+
},
|
| 267 |
+
}
|
| 268 |
+
# Add metadata fields at top level for convenience
|
| 269 |
+
if rec.get("metadata"):
|
| 270 |
+
formatted_rec["country"] = rec["metadata"].get("Country of Origin", "Unknown")
|
| 271 |
+
formatted_rec["metadata"] = rec["metadata"]
|
| 272 |
+
|
| 273 |
+
formatted_recommendations.append(formatted_rec)
|
| 274 |
+
|
| 275 |
+
return jsonify({
|
| 276 |
+
"recommendations": formatted_recommendations,
|
| 277 |
+
"model_used": model_name,
|
| 278 |
+
"k": k,
|
| 279 |
+
})
|
| 280 |
+
|
| 281 |
+
@app.route("/api/coffee/<int:coffee_id>", methods=["GET"])
|
| 282 |
+
def get_coffee(coffee_id: int):
|
| 283 |
+
"""Get details for a specific coffee by ID.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
coffee_id: The ID of the coffee to retrieve.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
JSON response with coffee details.
|
| 290 |
+
"""
|
| 291 |
+
validated_id = validate_coffee_id(coffee_id)
|
| 292 |
+
|
| 293 |
+
if validated_id not in app.coffee_data:
|
| 294 |
+
return jsonify({"error": f"Coffee with id {validated_id} not found"}), 404
|
| 295 |
+
|
| 296 |
+
coffee = app.coffee_data[validated_id]
|
| 297 |
+
return jsonify(coffee)
|
| 298 |
+
|
| 299 |
+
@app.route("/api/stats", methods=["GET"])
|
| 300 |
+
def get_stats():
|
| 301 |
+
"""Get model performance statistics.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
JSON response with statistics about loaded models and data.
|
| 305 |
+
"""
|
| 306 |
+
stats: dict[str, Any] = {
|
| 307 |
+
"total_coffees": len(app.coffee_data),
|
| 308 |
+
"models": {},
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
for model_name, model in app.models.items():
|
| 312 |
+
model_stats: dict[str, Any] = {
|
| 313 |
+
"is_fitted": model.is_fitted,
|
| 314 |
+
}
|
| 315 |
+
if hasattr(model, "_X") and model._X is not None:
|
| 316 |
+
model_stats["training_samples"] = len(model._X)
|
| 317 |
+
if hasattr(model, "_metadata") and model._metadata is not None:
|
| 318 |
+
model_stats["metadata_columns"] = list(model._metadata.columns)
|
| 319 |
+
stats["models"][model_name] = model_stats
|
| 320 |
+
|
| 321 |
+
return jsonify(stats)
|
| 322 |
+
|
| 323 |
+
return app
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def main() -> None:
|
| 327 |
+
"""Entry point for running the Flask development server.
|
| 328 |
+
|
| 329 |
+
This function is called by `uv run serve`. For production deployments,
|
| 330 |
+
use a WSGI server like gunicorn instead.
|
| 331 |
+
"""
|
| 332 |
+
host = os.environ.get("FLASK_HOST", "127.0.0.1")
|
| 333 |
+
port = int(os.environ.get("FLASK_PORT", "5000"))
|
| 334 |
+
debug = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
|
| 335 |
+
|
| 336 |
+
logger.info(f"Starting BrewMatch API server on {host}:{port}")
|
| 337 |
+
|
| 338 |
+
app = create_app()
|
| 339 |
+
app.run(host=host, port=port, debug=debug)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
if __name__ == "__main__":
|
| 343 |
+
main()
|
src/brewmatch/api/schemas.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Request/response validation schemas for the BrewMatch API."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
# The 9 taste preference features
|
| 6 |
+
TASTE_FEATURES = [
|
| 7 |
+
"aroma",
|
| 8 |
+
"flavor",
|
| 9 |
+
"aftertaste",
|
| 10 |
+
"acidity",
|
| 11 |
+
"body",
|
| 12 |
+
"balance",
|
| 13 |
+
"uniformity",
|
| 14 |
+
"clean_cup",
|
| 15 |
+
"sweetness",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
# Valid model names
|
| 19 |
+
VALID_MODELS = ["baseline", "classical", "neural"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ValidationError(Exception):
|
| 23 |
+
"""Raised when request validation fails."""
|
| 24 |
+
|
| 25 |
+
def __init__(self, message: str, field: str | None = None) -> None:
|
| 26 |
+
self.message = message
|
| 27 |
+
self.field = field
|
| 28 |
+
super().__init__(message)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def validate_preferences(preferences: dict[str, Any] | None) -> dict[str, float]:
|
| 32 |
+
"""Validate taste preferences from API request.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
preferences: Dictionary of taste preferences with feature names as keys
|
| 36 |
+
and scores as values. All 9 features must be present.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Validated preferences dictionary with float values.
|
| 40 |
+
|
| 41 |
+
Raises:
|
| 42 |
+
ValidationError: If preferences are invalid.
|
| 43 |
+
"""
|
| 44 |
+
if preferences is None:
|
| 45 |
+
raise ValidationError("preferences is required", field="preferences")
|
| 46 |
+
|
| 47 |
+
if not isinstance(preferences, dict):
|
| 48 |
+
raise ValidationError(
|
| 49 |
+
"preferences must be an object", field="preferences"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
validated: dict[str, float] = {}
|
| 53 |
+
missing_fields = []
|
| 54 |
+
|
| 55 |
+
for feature in TASTE_FEATURES:
|
| 56 |
+
if feature not in preferences:
|
| 57 |
+
missing_fields.append(feature)
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
value = preferences[feature]
|
| 61 |
+
|
| 62 |
+
if value is None:
|
| 63 |
+
raise ValidationError(
|
| 64 |
+
f"{feature} cannot be null", field=f"preferences.{feature}"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
float_value = float(value)
|
| 69 |
+
except (TypeError, ValueError):
|
| 70 |
+
raise ValidationError(
|
| 71 |
+
f"{feature} must be a number", field=f"preferences.{feature}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
if float_value < 0.0 or float_value > 10.0:
|
| 75 |
+
raise ValidationError(
|
| 76 |
+
f"{feature} must be between 0 and 10", field=f"preferences.{feature}"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
validated[feature] = float_value
|
| 80 |
+
|
| 81 |
+
if missing_fields:
|
| 82 |
+
raise ValidationError(
|
| 83 |
+
f"Missing required fields: {', '.join(missing_fields)}",
|
| 84 |
+
field="preferences",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
return validated
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def validate_model_name(model: str | None) -> str:
|
| 91 |
+
"""Validate model name from API request.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
model: Name of the model to use. Must be one of 'baseline', 'classical',
|
| 95 |
+
or 'neural'. Defaults to 'neural' if not provided.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Validated model name.
|
| 99 |
+
|
| 100 |
+
Raises:
|
| 101 |
+
ValidationError: If model name is invalid.
|
| 102 |
+
"""
|
| 103 |
+
if model is None:
|
| 104 |
+
return "neural"
|
| 105 |
+
|
| 106 |
+
if not isinstance(model, str):
|
| 107 |
+
raise ValidationError("model must be a string", field="model")
|
| 108 |
+
|
| 109 |
+
model = model.lower().strip()
|
| 110 |
+
|
| 111 |
+
if model not in VALID_MODELS:
|
| 112 |
+
raise ValidationError(
|
| 113 |
+
f"model must be one of: {', '.join(VALID_MODELS)}", field="model"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
return model
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def validate_k(k: Any | None) -> int:
|
| 120 |
+
"""Validate k (number of recommendations) from API request.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
k: Number of recommendations to return. Must be a positive integer
|
| 124 |
+
between 1 and 100. Defaults to 5 if not provided.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Validated k value.
|
| 128 |
+
|
| 129 |
+
Raises:
|
| 130 |
+
ValidationError: If k is invalid.
|
| 131 |
+
"""
|
| 132 |
+
if k is None:
|
| 133 |
+
return 5
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
k_int = int(k)
|
| 137 |
+
except (TypeError, ValueError):
|
| 138 |
+
raise ValidationError("k must be an integer", field="k")
|
| 139 |
+
|
| 140 |
+
if k_int < 1:
|
| 141 |
+
raise ValidationError("k must be at least 1", field="k")
|
| 142 |
+
|
| 143 |
+
if k_int > 100:
|
| 144 |
+
raise ValidationError("k must be at most 100", field="k")
|
| 145 |
+
|
| 146 |
+
return k_int
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def validate_coffee_id(coffee_id: Any) -> int:
|
| 150 |
+
"""Validate coffee ID from API request.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
coffee_id: ID of the coffee to retrieve.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Validated coffee ID as integer.
|
| 157 |
+
|
| 158 |
+
Raises:
|
| 159 |
+
ValidationError: If coffee ID is invalid.
|
| 160 |
+
"""
|
| 161 |
+
if coffee_id is None:
|
| 162 |
+
raise ValidationError("coffee id is required", field="id")
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
id_int = int(coffee_id)
|
| 166 |
+
except (TypeError, ValueError):
|
| 167 |
+
raise ValidationError("coffee id must be an integer", field="id")
|
| 168 |
+
|
| 169 |
+
if id_int < 0:
|
| 170 |
+
raise ValidationError("coffee id must be non-negative", field="id")
|
| 171 |
+
|
| 172 |
+
return id_int
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def validate_recommend_request(data: dict[str, Any] | None) -> dict[str, Any]:
|
| 176 |
+
"""Validate the full recommendation request body.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
data: Request body dictionary containing preferences, model, and k.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Validated request data with all fields populated.
|
| 183 |
+
|
| 184 |
+
Raises:
|
| 185 |
+
ValidationError: If request data is invalid.
|
| 186 |
+
"""
|
| 187 |
+
if data is None:
|
| 188 |
+
raise ValidationError("Request body is required")
|
| 189 |
+
|
| 190 |
+
if not isinstance(data, dict):
|
| 191 |
+
raise ValidationError("Request body must be a JSON object")
|
| 192 |
+
|
| 193 |
+
return {
|
| 194 |
+
"preferences": validate_preferences(data.get("preferences")),
|
| 195 |
+
"model": validate_model_name(data.get("model")),
|
| 196 |
+
"k": validate_k(data.get("k")),
|
| 197 |
+
}
|
src/brewmatch/config.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration settings for BrewMatch."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Project paths
|
| 6 |
+
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
| 7 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 8 |
+
RAW_DATA_DIR = DATA_DIR / "raw"
|
| 9 |
+
PROCESSED_DATA_DIR = DATA_DIR / "processed"
|
| 10 |
+
MODELS_DIR = PROJECT_ROOT / "models"
|
| 11 |
+
CHECKPOINTS_DIR = MODELS_DIR / "checkpoints"
|
| 12 |
+
|
| 13 |
+
# Ensure directories exist
|
| 14 |
+
for dir_path in [RAW_DATA_DIR, PROCESSED_DATA_DIR, CHECKPOINTS_DIR]:
|
| 15 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
| 16 |
+
|
| 17 |
+
# Dataset settings
|
| 18 |
+
KAGGLE_DATASET = "fatihb/coffee-quality-data-cqi"
|
| 19 |
+
RANDOM_SEED = 42
|
| 20 |
+
TEST_SIZE = 0.2
|
| 21 |
+
VAL_SIZE = 0.1
|
| 22 |
+
|
| 23 |
+
# Feature columns (taste profile) - using actual CSV column names
|
| 24 |
+
TASTE_FEATURES = [
|
| 25 |
+
"Aroma",
|
| 26 |
+
"Flavor",
|
| 27 |
+
"Aftertaste",
|
| 28 |
+
"Acidity",
|
| 29 |
+
"Body",
|
| 30 |
+
"Balance",
|
| 31 |
+
"Uniformity",
|
| 32 |
+
"Clean Cup",
|
| 33 |
+
"Sweetness",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
# Metadata columns to preserve
|
| 37 |
+
METADATA_COLUMNS = [
|
| 38 |
+
"Country of Origin",
|
| 39 |
+
"Region",
|
| 40 |
+
"Processing Method",
|
| 41 |
+
"Variety",
|
| 42 |
+
"Color",
|
| 43 |
+
"Total Cup Points",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
# Model hyperparameters
|
| 47 |
+
BASELINE_CONFIG = {
|
| 48 |
+
"strategy": "mean_similarity", # or "quality_weighted_random"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
CLASSICAL_CONFIG = {
|
| 52 |
+
"n_neighbors": 10,
|
| 53 |
+
"metric": "cosine",
|
| 54 |
+
"algorithm": "brute",
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
NEURAL_CONFIG = {
|
| 58 |
+
"embedding_dim": 32,
|
| 59 |
+
"hidden_dims": [64, 32],
|
| 60 |
+
"learning_rate": 0.001,
|
| 61 |
+
"batch_size": 32,
|
| 62 |
+
"epochs": 100,
|
| 63 |
+
"margin": 0.5, # for triplet loss
|
| 64 |
+
"patience": 10, # early stopping
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
# Evaluation settings
|
| 68 |
+
K_VALUES = [1, 3, 5, 10]
|
| 69 |
+
|
| 70 |
+
# API settings
|
| 71 |
+
API_HOST = "0.0.0.0"
|
| 72 |
+
API_PORT = 5000
|
| 73 |
+
DEBUG = False
|
src/brewmatch/data/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data loading and preprocessing modules."""
|
| 2 |
+
|
| 3 |
+
from .download import download_data
|
| 4 |
+
from .preprocess import preprocess_data, load_processed_data
|
| 5 |
+
from .dataset import CoffeeDataset, create_dataloaders
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"download_data",
|
| 9 |
+
"preprocess_data",
|
| 10 |
+
"load_processed_data",
|
| 11 |
+
"CoffeeDataset",
|
| 12 |
+
"create_dataloaders",
|
| 13 |
+
]
|
src/brewmatch/data/dataset.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch dataset and dataloaders for coffee quality data."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import DataLoader, Dataset
|
| 10 |
+
|
| 11 |
+
from .preprocess import TASTE_FEATURES, TARGET_COLUMN, load_processed_data
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CoffeeDataset(Dataset):
|
| 15 |
+
"""
|
| 16 |
+
PyTorch Dataset for coffee quality data.
|
| 17 |
+
|
| 18 |
+
Provides (features, target) pairs where features are the 9 taste
|
| 19 |
+
profile scores and target is the total cup points.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
features: Tensor of shape (n_samples, 9) with normalized taste features.
|
| 23 |
+
targets: Tensor of shape (n_samples,) with total cup points.
|
| 24 |
+
metadata: Optional DataFrame with metadata columns.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
features: np.ndarray | torch.Tensor,
|
| 30 |
+
targets: np.ndarray | torch.Tensor,
|
| 31 |
+
metadata: pd.DataFrame | None = None,
|
| 32 |
+
) -> None:
|
| 33 |
+
"""
|
| 34 |
+
Initialize the dataset.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
features: Array of shape (n_samples, n_features) with input features.
|
| 38 |
+
targets: Array of shape (n_samples,) with target values.
|
| 39 |
+
metadata: Optional DataFrame with metadata (not used in training).
|
| 40 |
+
"""
|
| 41 |
+
if isinstance(features, np.ndarray):
|
| 42 |
+
features = torch.from_numpy(features).float()
|
| 43 |
+
if isinstance(targets, np.ndarray):
|
| 44 |
+
targets = torch.from_numpy(targets).float()
|
| 45 |
+
|
| 46 |
+
self.features = features
|
| 47 |
+
self.targets = targets
|
| 48 |
+
self.metadata = metadata
|
| 49 |
+
|
| 50 |
+
def __len__(self) -> int:
|
| 51 |
+
"""Return the number of samples."""
|
| 52 |
+
return len(self.features)
|
| 53 |
+
|
| 54 |
+
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 55 |
+
"""
|
| 56 |
+
Get a sample by index.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
idx: Sample index.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Tuple of (features, target) tensors.
|
| 63 |
+
"""
|
| 64 |
+
return self.features[idx], self.targets[idx]
|
| 65 |
+
|
| 66 |
+
@classmethod
|
| 67 |
+
def from_dataframe(
|
| 68 |
+
cls,
|
| 69 |
+
df: pd.DataFrame,
|
| 70 |
+
feature_cols: list[str] | None = None,
|
| 71 |
+
target_col: str | None = None,
|
| 72 |
+
) -> "CoffeeDataset":
|
| 73 |
+
"""
|
| 74 |
+
Create a dataset from a pandas DataFrame.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
df: DataFrame with features and target.
|
| 78 |
+
feature_cols: List of feature column names (default: TASTE_FEATURES).
|
| 79 |
+
target_col: Target column name (default: TARGET_COLUMN).
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
CoffeeDataset instance.
|
| 83 |
+
"""
|
| 84 |
+
if feature_cols is None:
|
| 85 |
+
feature_cols = TASTE_FEATURES
|
| 86 |
+
if target_col is None:
|
| 87 |
+
target_col = TARGET_COLUMN
|
| 88 |
+
|
| 89 |
+
features = df[feature_cols].values
|
| 90 |
+
targets = df[target_col].values
|
| 91 |
+
|
| 92 |
+
# Get metadata columns (everything that's not a feature or target)
|
| 93 |
+
metadata_cols = [c for c in df.columns if c not in feature_cols and c != target_col]
|
| 94 |
+
metadata = df[metadata_cols] if metadata_cols else None
|
| 95 |
+
|
| 96 |
+
return cls(features, targets, metadata)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def create_dataloaders(
|
| 100 |
+
batch_size: int = 32,
|
| 101 |
+
val_split: float = 0.1,
|
| 102 |
+
num_workers: int = 0,
|
| 103 |
+
random_state: int = 42,
|
| 104 |
+
) -> dict[str, Any]:
|
| 105 |
+
"""
|
| 106 |
+
Create train, validation, and test DataLoaders.
|
| 107 |
+
|
| 108 |
+
Splits the training data into train/validation sets, keeps test set separate.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
batch_size: Batch size for all loaders (default: 32).
|
| 112 |
+
val_split: Fraction of training data for validation (default: 0.1).
|
| 113 |
+
num_workers: Number of workers for data loading (default: 0).
|
| 114 |
+
random_state: Random seed for train/val split (default: 42).
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Dictionary containing:
|
| 118 |
+
- train_loader: DataLoader for training
|
| 119 |
+
- val_loader: DataLoader for validation
|
| 120 |
+
- test_loader: DataLoader for testing
|
| 121 |
+
- train_dataset: Training CoffeeDataset
|
| 122 |
+
- val_dataset: Validation CoffeeDataset
|
| 123 |
+
- test_dataset: Test CoffeeDataset
|
| 124 |
+
- n_features: Number of input features (9)
|
| 125 |
+
- scaler: The fitted StandardScaler
|
| 126 |
+
|
| 127 |
+
Raises:
|
| 128 |
+
FileNotFoundError: If processed data doesn't exist.
|
| 129 |
+
"""
|
| 130 |
+
# Load processed data
|
| 131 |
+
data = load_processed_data()
|
| 132 |
+
train_df = data["train_df"]
|
| 133 |
+
test_df = data["test_df"]
|
| 134 |
+
scaler = data["scaler"]
|
| 135 |
+
feature_cols = data["taste_features"]
|
| 136 |
+
target_col = data["target_column"]
|
| 137 |
+
|
| 138 |
+
# Split training data into train/val
|
| 139 |
+
n_train = len(train_df)
|
| 140 |
+
n_val = int(n_train * val_split)
|
| 141 |
+
|
| 142 |
+
# Shuffle with fixed seed
|
| 143 |
+
rng = np.random.default_rng(random_state)
|
| 144 |
+
indices = rng.permutation(n_train)
|
| 145 |
+
|
| 146 |
+
val_indices = indices[:n_val]
|
| 147 |
+
train_indices = indices[n_val:]
|
| 148 |
+
|
| 149 |
+
train_subset_df = train_df.iloc[train_indices].reset_index(drop=True)
|
| 150 |
+
val_subset_df = train_df.iloc[val_indices].reset_index(drop=True)
|
| 151 |
+
|
| 152 |
+
# Create datasets
|
| 153 |
+
train_dataset = CoffeeDataset.from_dataframe(
|
| 154 |
+
train_subset_df, feature_cols, target_col
|
| 155 |
+
)
|
| 156 |
+
val_dataset = CoffeeDataset.from_dataframe(
|
| 157 |
+
val_subset_df, feature_cols, target_col
|
| 158 |
+
)
|
| 159 |
+
test_dataset = CoffeeDataset.from_dataframe(
|
| 160 |
+
test_df, feature_cols, target_col
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Create dataloaders
|
| 164 |
+
train_loader = DataLoader(
|
| 165 |
+
train_dataset,
|
| 166 |
+
batch_size=batch_size,
|
| 167 |
+
shuffle=True,
|
| 168 |
+
num_workers=num_workers,
|
| 169 |
+
pin_memory=torch.cuda.is_available(),
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
val_loader = DataLoader(
|
| 173 |
+
val_dataset,
|
| 174 |
+
batch_size=batch_size,
|
| 175 |
+
shuffle=False,
|
| 176 |
+
num_workers=num_workers,
|
| 177 |
+
pin_memory=torch.cuda.is_available(),
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
test_loader = DataLoader(
|
| 181 |
+
test_dataset,
|
| 182 |
+
batch_size=batch_size,
|
| 183 |
+
shuffle=False,
|
| 184 |
+
num_workers=num_workers,
|
| 185 |
+
pin_memory=torch.cuda.is_available(),
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
print(f"Created dataloaders:")
|
| 189 |
+
print(f" Train: {len(train_dataset)} samples, {len(train_loader)} batches")
|
| 190 |
+
print(f" Val: {len(val_dataset)} samples, {len(val_loader)} batches")
|
| 191 |
+
print(f" Test: {len(test_dataset)} samples, {len(test_loader)} batches")
|
| 192 |
+
print(f" Batch size: {batch_size}")
|
| 193 |
+
print(f" Features: {len(feature_cols)}")
|
| 194 |
+
|
| 195 |
+
return {
|
| 196 |
+
"train_loader": train_loader,
|
| 197 |
+
"val_loader": val_loader,
|
| 198 |
+
"test_loader": test_loader,
|
| 199 |
+
"train_dataset": train_dataset,
|
| 200 |
+
"val_dataset": val_dataset,
|
| 201 |
+
"test_dataset": test_dataset,
|
| 202 |
+
"n_features": len(feature_cols),
|
| 203 |
+
"scaler": scaler,
|
| 204 |
+
}
|
src/brewmatch/data/download.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download the CQI coffee quality dataset from Kaggle."""
|
| 2 |
+
|
| 3 |
+
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import kagglehub
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_project_root() -> Path:
|
| 10 |
+
"""Get the project root directory (where pyproject.toml is located)."""
|
| 11 |
+
current = Path(__file__).resolve()
|
| 12 |
+
for parent in current.parents:
|
| 13 |
+
if (parent / "pyproject.toml").exists():
|
| 14 |
+
return parent
|
| 15 |
+
raise RuntimeError("Could not find project root (no pyproject.toml found)")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def download_data(force: bool = False) -> Path:
|
| 19 |
+
"""
|
| 20 |
+
Download the CQI coffee quality dataset from Kaggle.
|
| 21 |
+
|
| 22 |
+
Uses kagglehub to download the dataset and copies files to data/raw/.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
force: If True, re-download even if data already exists.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Path to the raw data directory containing the downloaded files.
|
| 29 |
+
|
| 30 |
+
Raises:
|
| 31 |
+
RuntimeError: If download fails or no CSV files are found.
|
| 32 |
+
"""
|
| 33 |
+
project_root = get_project_root()
|
| 34 |
+
raw_dir = project_root / "data" / "raw"
|
| 35 |
+
raw_dir.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
# Check if data already exists
|
| 38 |
+
existing_csvs = list(raw_dir.glob("*.csv"))
|
| 39 |
+
if existing_csvs and not force:
|
| 40 |
+
print(f"Data already exists in {raw_dir} ({len(existing_csvs)} CSV files)")
|
| 41 |
+
print("Use force=True to re-download")
|
| 42 |
+
return raw_dir
|
| 43 |
+
|
| 44 |
+
print("Downloading CQI coffee quality dataset from Kaggle...")
|
| 45 |
+
print("Dataset: volpatto/coffee-quality-database-from-cqi")
|
| 46 |
+
|
| 47 |
+
# kagglehub downloads to its cache directory
|
| 48 |
+
# Using volpatto's dataset which has both Arabica (~1300) and Robusta (~28) samples
|
| 49 |
+
cache_path = kagglehub.dataset_download("volpatto/coffee-quality-database-from-cqi")
|
| 50 |
+
cache_path = Path(cache_path)
|
| 51 |
+
|
| 52 |
+
print(f"Downloaded to cache: {cache_path}")
|
| 53 |
+
|
| 54 |
+
# Find all CSV files in the downloaded data
|
| 55 |
+
csv_files = list(cache_path.glob("**/*.csv"))
|
| 56 |
+
if not csv_files:
|
| 57 |
+
raise RuntimeError(f"No CSV files found in downloaded data at {cache_path}")
|
| 58 |
+
|
| 59 |
+
# Copy CSV files to raw directory
|
| 60 |
+
print(f"Copying {len(csv_files)} CSV file(s) to {raw_dir}")
|
| 61 |
+
for csv_file in csv_files:
|
| 62 |
+
dest = raw_dir / csv_file.name
|
| 63 |
+
shutil.copy2(csv_file, dest)
|
| 64 |
+
print(f" - {csv_file.name}")
|
| 65 |
+
|
| 66 |
+
print(f"Data saved to {raw_dir}")
|
| 67 |
+
return raw_dir
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main() -> None:
|
| 71 |
+
"""Entry point for `uv run download`."""
|
| 72 |
+
import argparse
|
| 73 |
+
|
| 74 |
+
parser = argparse.ArgumentParser(
|
| 75 |
+
description="Download the CQI coffee quality dataset from Kaggle"
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--force",
|
| 79 |
+
action="store_true",
|
| 80 |
+
help="Re-download even if data already exists",
|
| 81 |
+
)
|
| 82 |
+
args = parser.parse_args()
|
| 83 |
+
|
| 84 |
+
download_data(force=args.force)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
main()
|
src/brewmatch/data/preprocess.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocess the CQI coffee quality dataset."""
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
from sklearn.preprocessing import StandardScaler
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Taste profile features (9 total) - using actual CSV column names
|
| 14 |
+
TASTE_FEATURES = [
|
| 15 |
+
"Aroma",
|
| 16 |
+
"Flavor",
|
| 17 |
+
"Aftertaste",
|
| 18 |
+
"Acidity",
|
| 19 |
+
"Body",
|
| 20 |
+
"Balance",
|
| 21 |
+
"Uniformity",
|
| 22 |
+
"Clean Cup",
|
| 23 |
+
"Sweetness",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
# Target column
|
| 27 |
+
TARGET_COLUMN = "Total Cup Points"
|
| 28 |
+
|
| 29 |
+
# Metadata columns to preserve
|
| 30 |
+
METADATA_COLUMNS = [
|
| 31 |
+
"Country of Origin",
|
| 32 |
+
"Processing Method",
|
| 33 |
+
"Variety",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_project_root() -> Path:
|
| 38 |
+
"""Get the project root directory (where pyproject.toml is located)."""
|
| 39 |
+
current = Path(__file__).resolve()
|
| 40 |
+
for parent in current.parents:
|
| 41 |
+
if (parent / "pyproject.toml").exists():
|
| 42 |
+
return parent
|
| 43 |
+
raise RuntimeError("Could not find project root (no pyproject.toml found)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def normalize_column_names(df: pd.DataFrame) -> pd.DataFrame:
|
| 47 |
+
"""Normalize column names to use spaces instead of dots/underscores."""
|
| 48 |
+
# Map common column name variations
|
| 49 |
+
column_mapping = {
|
| 50 |
+
"Country.of.Origin": "Country of Origin",
|
| 51 |
+
"Processing.Method": "Processing Method",
|
| 52 |
+
"Clean.Cup": "Clean Cup",
|
| 53 |
+
"Total.Cup.Points": "Total Cup Points",
|
| 54 |
+
"Cupper.Points": "Cupper Points",
|
| 55 |
+
"Category.One.Defects": "Category One Defects",
|
| 56 |
+
"Category.Two.Defects": "Category Two Defects",
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
# Apply mapping
|
| 60 |
+
df = df.rename(columns=column_mapping)
|
| 61 |
+
return df
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_raw_data() -> pd.DataFrame:
|
| 65 |
+
"""
|
| 66 |
+
Load raw CSV data from data/raw/.
|
| 67 |
+
|
| 68 |
+
Prefers merged_data_cleaned.csv if available (has most samples).
|
| 69 |
+
Falls back to combining all CSV files.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Combined DataFrame from raw CSV files.
|
| 73 |
+
|
| 74 |
+
Raises:
|
| 75 |
+
FileNotFoundError: If no CSV files found in raw directory.
|
| 76 |
+
"""
|
| 77 |
+
project_root = get_project_root()
|
| 78 |
+
raw_dir = project_root / "data" / "raw"
|
| 79 |
+
|
| 80 |
+
csv_files = list(raw_dir.glob("*.csv"))
|
| 81 |
+
if not csv_files:
|
| 82 |
+
raise FileNotFoundError(
|
| 83 |
+
f"No CSV files found in {raw_dir}. Run `uv run download` first."
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Prefer the merged dataset if available (has most samples)
|
| 87 |
+
merged_file = raw_dir / "merged_data_cleaned.csv"
|
| 88 |
+
if merged_file.exists():
|
| 89 |
+
print(f"Loading merged dataset: {merged_file.name}")
|
| 90 |
+
df = pd.read_csv(merged_file)
|
| 91 |
+
df = normalize_column_names(df)
|
| 92 |
+
return df
|
| 93 |
+
|
| 94 |
+
print(f"Found {len(csv_files)} CSV file(s) in {raw_dir}")
|
| 95 |
+
|
| 96 |
+
dfs = []
|
| 97 |
+
for csv_file in csv_files:
|
| 98 |
+
print(f" Loading {csv_file.name}...")
|
| 99 |
+
df = pd.read_csv(csv_file)
|
| 100 |
+
df = normalize_column_names(df)
|
| 101 |
+
dfs.append(df)
|
| 102 |
+
|
| 103 |
+
if len(dfs) == 1:
|
| 104 |
+
return dfs[0]
|
| 105 |
+
|
| 106 |
+
# Combine multiple CSVs
|
| 107 |
+
combined = pd.concat(dfs, ignore_index=True)
|
| 108 |
+
print(f"Combined {len(dfs)} files into {len(combined)} rows")
|
| 109 |
+
return combined
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def preprocess_data(
|
| 113 |
+
test_size: float = 0.2,
|
| 114 |
+
random_state: int = 42,
|
| 115 |
+
) -> dict[str, Any]:
|
| 116 |
+
"""
|
| 117 |
+
Preprocess the coffee quality dataset.
|
| 118 |
+
|
| 119 |
+
Steps:
|
| 120 |
+
1. Load raw data
|
| 121 |
+
2. Select relevant columns
|
| 122 |
+
3. Drop rows with missing quality scores
|
| 123 |
+
4. Normalize numeric features
|
| 124 |
+
5. Split into train/test sets
|
| 125 |
+
6. Save processed data and scaler
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
test_size: Fraction of data for test set (default 0.2).
|
| 129 |
+
random_state: Random seed for reproducibility.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Dictionary with paths to saved files:
|
| 133 |
+
- train_path: Path to training parquet
|
| 134 |
+
- test_path: Path to test parquet
|
| 135 |
+
- scaler_path: Path to scaler pickle
|
| 136 |
+
"""
|
| 137 |
+
project_root = get_project_root()
|
| 138 |
+
processed_dir = project_root / "data" / "processed"
|
| 139 |
+
processed_dir.mkdir(parents=True, exist_ok=True)
|
| 140 |
+
|
| 141 |
+
# Load raw data
|
| 142 |
+
print("Loading raw data...")
|
| 143 |
+
df = load_raw_data()
|
| 144 |
+
print(f"Loaded {len(df)} rows")
|
| 145 |
+
|
| 146 |
+
# Check for required columns
|
| 147 |
+
required_cols = TASTE_FEATURES + [TARGET_COLUMN]
|
| 148 |
+
missing_cols = [col for col in required_cols if col not in df.columns]
|
| 149 |
+
if missing_cols:
|
| 150 |
+
raise ValueError(f"Missing required columns: {missing_cols}")
|
| 151 |
+
|
| 152 |
+
# Select columns to keep
|
| 153 |
+
cols_to_keep = TASTE_FEATURES + [TARGET_COLUMN]
|
| 154 |
+
for col in METADATA_COLUMNS:
|
| 155 |
+
if col in df.columns:
|
| 156 |
+
cols_to_keep.append(col)
|
| 157 |
+
|
| 158 |
+
df = df[cols_to_keep].copy()
|
| 159 |
+
print(f"Selected {len(cols_to_keep)} columns")
|
| 160 |
+
|
| 161 |
+
# Report missing values before dropping
|
| 162 |
+
quality_cols = TASTE_FEATURES + [TARGET_COLUMN]
|
| 163 |
+
missing_before = df[quality_cols].isna().sum()
|
| 164 |
+
rows_with_missing = df[quality_cols].isna().any(axis=1).sum()
|
| 165 |
+
print(f"Rows with missing quality scores: {rows_with_missing}")
|
| 166 |
+
|
| 167 |
+
if missing_before.sum() > 0:
|
| 168 |
+
print("Missing values per column:")
|
| 169 |
+
for col, count in missing_before[missing_before > 0].items():
|
| 170 |
+
print(f" {col}: {count}")
|
| 171 |
+
|
| 172 |
+
# Drop rows with missing quality scores
|
| 173 |
+
df_clean = df.dropna(subset=quality_cols)
|
| 174 |
+
dropped = len(df) - len(df_clean)
|
| 175 |
+
print(f"Dropped {dropped} rows with missing quality scores")
|
| 176 |
+
print(f"Remaining: {len(df_clean)} rows")
|
| 177 |
+
|
| 178 |
+
if len(df_clean) == 0:
|
| 179 |
+
raise ValueError("No data remaining after dropping missing values")
|
| 180 |
+
|
| 181 |
+
# Split into features and target
|
| 182 |
+
X = df_clean[TASTE_FEATURES].values
|
| 183 |
+
y = df_clean[TARGET_COLUMN].values
|
| 184 |
+
metadata = df_clean[[c for c in METADATA_COLUMNS if c in df_clean.columns]]
|
| 185 |
+
|
| 186 |
+
# Train/test split
|
| 187 |
+
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(
|
| 188 |
+
X,
|
| 189 |
+
y,
|
| 190 |
+
df_clean.index.values,
|
| 191 |
+
test_size=test_size,
|
| 192 |
+
random_state=random_state,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
print(f"Train set: {len(X_train)} samples")
|
| 196 |
+
print(f"Test set: {len(X_test)} samples")
|
| 197 |
+
|
| 198 |
+
# Fit scaler on training data only
|
| 199 |
+
scaler = StandardScaler()
|
| 200 |
+
X_train_scaled = scaler.fit_transform(X_train)
|
| 201 |
+
X_test_scaled = scaler.transform(X_test)
|
| 202 |
+
|
| 203 |
+
print(f"Features normalized (mean={X_train_scaled.mean():.4f}, std={X_train_scaled.std():.4f})")
|
| 204 |
+
|
| 205 |
+
# Create DataFrames with scaled features
|
| 206 |
+
train_df = pd.DataFrame(X_train_scaled, columns=TASTE_FEATURES)
|
| 207 |
+
train_df[TARGET_COLUMN] = y_train
|
| 208 |
+
|
| 209 |
+
# Add metadata using original indices
|
| 210 |
+
for col in metadata.columns:
|
| 211 |
+
train_df[col] = metadata.loc[idx_train, col].values
|
| 212 |
+
|
| 213 |
+
test_df = pd.DataFrame(X_test_scaled, columns=TASTE_FEATURES)
|
| 214 |
+
test_df[TARGET_COLUMN] = y_test
|
| 215 |
+
|
| 216 |
+
for col in metadata.columns:
|
| 217 |
+
test_df[col] = metadata.loc[idx_test, col].values
|
| 218 |
+
|
| 219 |
+
# Save processed data
|
| 220 |
+
train_path = processed_dir / "train.parquet"
|
| 221 |
+
test_path = processed_dir / "test.parquet"
|
| 222 |
+
scaler_path = processed_dir / "scaler.pkl"
|
| 223 |
+
|
| 224 |
+
train_df.to_parquet(train_path, index=False)
|
| 225 |
+
test_df.to_parquet(test_path, index=False)
|
| 226 |
+
|
| 227 |
+
with open(scaler_path, "wb") as f:
|
| 228 |
+
pickle.dump(scaler, f)
|
| 229 |
+
|
| 230 |
+
print(f"\nSaved processed data:")
|
| 231 |
+
print(f" Train: {train_path}")
|
| 232 |
+
print(f" Test: {test_path}")
|
| 233 |
+
print(f" Scaler: {scaler_path}")
|
| 234 |
+
|
| 235 |
+
return {
|
| 236 |
+
"train_path": train_path,
|
| 237 |
+
"test_path": test_path,
|
| 238 |
+
"scaler_path": scaler_path,
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def load_processed_data() -> dict[str, Any]:
|
| 243 |
+
"""
|
| 244 |
+
Load preprocessed data from data/processed/.
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Dictionary containing:
|
| 248 |
+
- train_df: Training DataFrame
|
| 249 |
+
- test_df: Test DataFrame
|
| 250 |
+
- scaler: Fitted StandardScaler
|
| 251 |
+
- taste_features: List of taste feature column names
|
| 252 |
+
- target_column: Name of target column
|
| 253 |
+
|
| 254 |
+
Raises:
|
| 255 |
+
FileNotFoundError: If processed data doesn't exist.
|
| 256 |
+
"""
|
| 257 |
+
project_root = get_project_root()
|
| 258 |
+
processed_dir = project_root / "data" / "processed"
|
| 259 |
+
|
| 260 |
+
train_path = processed_dir / "train.parquet"
|
| 261 |
+
test_path = processed_dir / "test.parquet"
|
| 262 |
+
scaler_path = processed_dir / "scaler.pkl"
|
| 263 |
+
|
| 264 |
+
# Check all files exist
|
| 265 |
+
for path in [train_path, test_path, scaler_path]:
|
| 266 |
+
if not path.exists():
|
| 267 |
+
raise FileNotFoundError(
|
| 268 |
+
f"Processed data not found: {path}. Run `uv run preprocess` first."
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
train_df = pd.read_parquet(train_path)
|
| 272 |
+
test_df = pd.read_parquet(test_path)
|
| 273 |
+
|
| 274 |
+
with open(scaler_path, "rb") as f:
|
| 275 |
+
scaler = pickle.load(f)
|
| 276 |
+
|
| 277 |
+
return {
|
| 278 |
+
"train_df": train_df,
|
| 279 |
+
"test_df": test_df,
|
| 280 |
+
"scaler": scaler,
|
| 281 |
+
"taste_features": TASTE_FEATURES,
|
| 282 |
+
"target_column": TARGET_COLUMN,
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def main() -> None:
|
| 287 |
+
"""Entry point for `uv run preprocess`."""
|
| 288 |
+
import argparse
|
| 289 |
+
|
| 290 |
+
parser = argparse.ArgumentParser(
|
| 291 |
+
description="Preprocess the CQI coffee quality dataset"
|
| 292 |
+
)
|
| 293 |
+
parser.add_argument(
|
| 294 |
+
"--test-size",
|
| 295 |
+
type=float,
|
| 296 |
+
default=0.2,
|
| 297 |
+
help="Fraction of data for test set (default: 0.2)",
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--seed",
|
| 301 |
+
type=int,
|
| 302 |
+
default=42,
|
| 303 |
+
help="Random seed for reproducibility (default: 42)",
|
| 304 |
+
)
|
| 305 |
+
args = parser.parse_args()
|
| 306 |
+
|
| 307 |
+
preprocess_data(test_size=args.test_size, random_state=args.seed)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
if __name__ == "__main__":
|
| 311 |
+
main()
|
src/brewmatch/device.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Device detection and selection utilities.
|
| 2 |
+
|
| 3 |
+
Provides automatic device selection with fallback:
|
| 4 |
+
CUDA (if available) > MPS (Apple Silicon) > CPU
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_device(preferred: str | None = None) -> torch.device:
|
| 11 |
+
"""
|
| 12 |
+
Get the best available device for PyTorch operations.
|
| 13 |
+
|
| 14 |
+
Priority: CUDA > MPS > CPU (unless preferred is specified)
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
preferred: Optional preferred device ("cuda", "mps", "cpu").
|
| 18 |
+
If specified and available, uses that device.
|
| 19 |
+
If not available, falls back to best available.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
torch.device for the selected device.
|
| 23 |
+
"""
|
| 24 |
+
if preferred:
|
| 25 |
+
preferred = preferred.lower()
|
| 26 |
+
if preferred == "cuda" and torch.cuda.is_available():
|
| 27 |
+
return torch.device("cuda")
|
| 28 |
+
elif preferred == "mps" and torch.backends.mps.is_available():
|
| 29 |
+
return torch.device("mps")
|
| 30 |
+
elif preferred == "cpu":
|
| 31 |
+
return torch.device("cpu")
|
| 32 |
+
# Fall through to auto-detection if preferred not available
|
| 33 |
+
|
| 34 |
+
# Auto-detect best available
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
return torch.device("cuda")
|
| 37 |
+
elif torch.backends.mps.is_available():
|
| 38 |
+
return torch.device("mps")
|
| 39 |
+
else:
|
| 40 |
+
return torch.device("cpu")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_device_info() -> dict[str, bool | str]:
|
| 44 |
+
"""
|
| 45 |
+
Get information about available devices.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Dictionary with device availability and selected device.
|
| 49 |
+
"""
|
| 50 |
+
info = {
|
| 51 |
+
"cuda_available": torch.cuda.is_available(),
|
| 52 |
+
"mps_available": torch.backends.mps.is_available(),
|
| 53 |
+
"selected": str(get_device()),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
if torch.cuda.is_available():
|
| 57 |
+
info["cuda_device_name"] = torch.cuda.get_device_name(0)
|
| 58 |
+
info["cuda_device_count"] = torch.cuda.device_count()
|
| 59 |
+
|
| 60 |
+
return info
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def print_device_info() -> None:
|
| 64 |
+
"""Print device information to console."""
|
| 65 |
+
info = get_device_info()
|
| 66 |
+
print(f"Device: {info['selected']}")
|
| 67 |
+
|
| 68 |
+
if info["cuda_available"]:
|
| 69 |
+
print(f" CUDA: {info.get('cuda_device_name', 'Unknown')} "
|
| 70 |
+
f"(x{info.get('cuda_device_count', 1)})")
|
| 71 |
+
elif info["mps_available"]:
|
| 72 |
+
print(" MPS: Apple Silicon GPU")
|
| 73 |
+
else:
|
| 74 |
+
print(" CPU: No GPU acceleration available")
|
src/brewmatch/evaluate.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation script for BrewMatch models."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
from brewmatch.config import (
|
| 12 |
+
CHECKPOINTS_DIR,
|
| 13 |
+
K_VALUES,
|
| 14 |
+
TASTE_FEATURES,
|
| 15 |
+
)
|
| 16 |
+
from brewmatch.data import load_processed_data
|
| 17 |
+
from brewmatch.models import (
|
| 18 |
+
NaiveBaselineRecommender,
|
| 19 |
+
ClassicalMLRecommender,
|
| 20 |
+
NeuralRecommender,
|
| 21 |
+
)
|
| 22 |
+
from brewmatch.evaluation import evaluate_model, generate_error_report
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_models() -> dict[str, Any]:
|
| 26 |
+
"""Load all trained models."""
|
| 27 |
+
models = {}
|
| 28 |
+
|
| 29 |
+
baseline_path = CHECKPOINTS_DIR / "baseline.pkl"
|
| 30 |
+
if baseline_path.exists():
|
| 31 |
+
models["baseline"] = NaiveBaselineRecommender.load(baseline_path)
|
| 32 |
+
print(f"Loaded baseline model from {baseline_path}")
|
| 33 |
+
|
| 34 |
+
classical_path = CHECKPOINTS_DIR / "classical.pkl"
|
| 35 |
+
if classical_path.exists():
|
| 36 |
+
models["classical"] = ClassicalMLRecommender.load(classical_path)
|
| 37 |
+
print(f"Loaded classical model from {classical_path}")
|
| 38 |
+
|
| 39 |
+
neural_path = CHECKPOINTS_DIR / "neural.pt"
|
| 40 |
+
if neural_path.exists():
|
| 41 |
+
models["neural"] = NeuralRecommender.load(neural_path)
|
| 42 |
+
print(f"Loaded neural model from {neural_path}")
|
| 43 |
+
|
| 44 |
+
return models
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def compare_models(results: dict[str, dict[str, Any]]) -> None:
|
| 48 |
+
"""Print comparison table of all models."""
|
| 49 |
+
print("\n" + "=" * 60)
|
| 50 |
+
print("MODEL COMPARISON")
|
| 51 |
+
print("=" * 60)
|
| 52 |
+
|
| 53 |
+
# Flatten nested dicts (precision@k, recall@k, etc.)
|
| 54 |
+
flat_results = {}
|
| 55 |
+
for model_name, metrics in results.items():
|
| 56 |
+
flat_metrics = {}
|
| 57 |
+
for key, value in metrics.items():
|
| 58 |
+
if isinstance(value, dict):
|
| 59 |
+
for k, v in value.items():
|
| 60 |
+
flat_metrics[f"{key.replace('@k', '')}@{k}"] = v
|
| 61 |
+
elif isinstance(value, (int, float)) and not isinstance(value, bool):
|
| 62 |
+
flat_metrics[key] = value
|
| 63 |
+
flat_results[model_name] = flat_metrics
|
| 64 |
+
|
| 65 |
+
if not flat_results:
|
| 66 |
+
print("No results to compare.")
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
# Get all metric keys
|
| 70 |
+
all_keys = set()
|
| 71 |
+
for metrics in flat_results.values():
|
| 72 |
+
all_keys.update(metrics.keys())
|
| 73 |
+
all_keys = sorted(all_keys)
|
| 74 |
+
|
| 75 |
+
# Print table
|
| 76 |
+
header = f"{'Model':<12}" + "".join(f"{k:>12}" for k in all_keys)
|
| 77 |
+
print(header)
|
| 78 |
+
print("-" * len(header))
|
| 79 |
+
|
| 80 |
+
for model_name, metrics in flat_results.items():
|
| 81 |
+
row = f"{model_name:<12}"
|
| 82 |
+
for key in all_keys:
|
| 83 |
+
val = metrics.get(key, float("nan"))
|
| 84 |
+
if isinstance(val, float):
|
| 85 |
+
row += f"{val:>12.4f}"
|
| 86 |
+
else:
|
| 87 |
+
row += f"{val:>12}"
|
| 88 |
+
print(row)
|
| 89 |
+
|
| 90 |
+
# Find best model for primary metrics
|
| 91 |
+
print("\nBest model per metric:")
|
| 92 |
+
for key in ["precision@5", "ndcg@5", "recall@5"]:
|
| 93 |
+
if key in all_keys:
|
| 94 |
+
best_model = max(
|
| 95 |
+
flat_results.keys(),
|
| 96 |
+
key=lambda m: flat_results[m].get(key, 0)
|
| 97 |
+
)
|
| 98 |
+
best_value = flat_results[best_model].get(key, 0)
|
| 99 |
+
print(f" - {key}: {best_model} ({best_value:.4f})")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main():
|
| 103 |
+
"""Main evaluation entry point."""
|
| 104 |
+
parser = argparse.ArgumentParser(description="Evaluate BrewMatch models")
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--models",
|
| 107 |
+
nargs="+",
|
| 108 |
+
choices=["baseline", "classical", "neural", "all"],
|
| 109 |
+
default=["all"],
|
| 110 |
+
help="Which models to evaluate",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--error-analysis",
|
| 114 |
+
action="store_true",
|
| 115 |
+
help="Generate detailed error analysis",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--output",
|
| 119 |
+
type=str,
|
| 120 |
+
default=None,
|
| 121 |
+
help="Save results to JSON file",
|
| 122 |
+
)
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
|
| 125 |
+
# Load data
|
| 126 |
+
print("Loading test data...")
|
| 127 |
+
data = load_processed_data()
|
| 128 |
+
test_df = data["test_df"]
|
| 129 |
+
print(f"Test samples: {len(test_df)}")
|
| 130 |
+
print()
|
| 131 |
+
|
| 132 |
+
# Load models
|
| 133 |
+
print("Loading models...")
|
| 134 |
+
all_models = load_models()
|
| 135 |
+
|
| 136 |
+
if "all" in args.models:
|
| 137 |
+
models_to_eval = all_models
|
| 138 |
+
else:
|
| 139 |
+
models_to_eval = {k: v for k, v in all_models.items() if k in args.models}
|
| 140 |
+
|
| 141 |
+
if not models_to_eval:
|
| 142 |
+
print("No models found to evaluate!")
|
| 143 |
+
return
|
| 144 |
+
|
| 145 |
+
print(f"\nEvaluating: {list(models_to_eval.keys())}")
|
| 146 |
+
print()
|
| 147 |
+
|
| 148 |
+
# Prepare test data dict for evaluation
|
| 149 |
+
test_data = {
|
| 150 |
+
"X": test_df[TASTE_FEATURES].values,
|
| 151 |
+
"metadata": test_df,
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
# Evaluate each model
|
| 155 |
+
results = {}
|
| 156 |
+
for name, model in models_to_eval.items():
|
| 157 |
+
print(f"\n{'=' * 40}")
|
| 158 |
+
print(f"Evaluating: {name.upper()}")
|
| 159 |
+
print("=" * 40)
|
| 160 |
+
|
| 161 |
+
metrics = evaluate_model(
|
| 162 |
+
model=model,
|
| 163 |
+
test_data=test_data,
|
| 164 |
+
k_values=K_VALUES,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
results[name] = metrics
|
| 168 |
+
|
| 169 |
+
print(f"\nResults for {name}:")
|
| 170 |
+
for metric, value in metrics.items():
|
| 171 |
+
if isinstance(value, dict):
|
| 172 |
+
for k, v in value.items():
|
| 173 |
+
print(f" {metric}@{k}: {v:.4f}")
|
| 174 |
+
elif isinstance(value, float):
|
| 175 |
+
print(f" {metric}: {value:.4f}")
|
| 176 |
+
else:
|
| 177 |
+
print(f" {metric}: {value}")
|
| 178 |
+
|
| 179 |
+
# Error analysis
|
| 180 |
+
if args.error_analysis:
|
| 181 |
+
print(f"\nError Analysis for {name}:")
|
| 182 |
+
report = generate_error_report(
|
| 183 |
+
model=model,
|
| 184 |
+
test_data=test_data,
|
| 185 |
+
)
|
| 186 |
+
print(f" Error rate: {report.error_rate:.1%}")
|
| 187 |
+
print(f" Total errors: {report.total_errors}/{report.total_queries}")
|
| 188 |
+
print("\n Worst errors:")
|
| 189 |
+
for i, err in enumerate(report.worst_errors[:5], 1):
|
| 190 |
+
print(f" {i}. Query {err.query_idx}: magnitude={err.error_magnitude:.3f}")
|
| 191 |
+
if "_root_cause" in err.query_metadata:
|
| 192 |
+
print(f" Root cause: {err.query_metadata['_root_cause']}")
|
| 193 |
+
print("\n Patterns:")
|
| 194 |
+
for pattern in report.patterns[:3]:
|
| 195 |
+
print(f" - {pattern.description} (freq: {pattern.frequency})")
|
| 196 |
+
print("\n Mitigations:")
|
| 197 |
+
for mitigation in report.mitigations[:3]:
|
| 198 |
+
print(f" - {mitigation[:80]}...")
|
| 199 |
+
|
| 200 |
+
# Compare models
|
| 201 |
+
if len(results) > 1:
|
| 202 |
+
compare_models(results)
|
| 203 |
+
|
| 204 |
+
# Save results
|
| 205 |
+
if args.output:
|
| 206 |
+
output_path = Path(args.output)
|
| 207 |
+
# Convert results to JSON-serializable format
|
| 208 |
+
json_results = {}
|
| 209 |
+
for model_name, metrics in results.items():
|
| 210 |
+
json_results[model_name] = {}
|
| 211 |
+
for key, value in metrics.items():
|
| 212 |
+
if isinstance(value, dict):
|
| 213 |
+
json_results[model_name][key] = {str(k): v for k, v in value.items()}
|
| 214 |
+
else:
|
| 215 |
+
json_results[model_name][key] = value
|
| 216 |
+
with open(output_path, "w") as f:
|
| 217 |
+
json.dump(json_results, f, indent=2)
|
| 218 |
+
print(f"\nResults saved to {output_path}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
main()
|
src/brewmatch/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation metrics and analysis modules."""
|
| 2 |
+
|
| 3 |
+
from .metrics import (
|
| 4 |
+
precision_at_k,
|
| 5 |
+
recall_at_k,
|
| 6 |
+
ndcg_at_k,
|
| 7 |
+
mean_squared_error,
|
| 8 |
+
mean_absolute_error,
|
| 9 |
+
evaluate_model,
|
| 10 |
+
)
|
| 11 |
+
from .error_analysis import (
|
| 12 |
+
analyze_errors,
|
| 13 |
+
identify_error_patterns,
|
| 14 |
+
generate_error_report,
|
| 15 |
+
PredictionError,
|
| 16 |
+
ErrorPattern,
|
| 17 |
+
ErrorReport,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"precision_at_k",
|
| 22 |
+
"recall_at_k",
|
| 23 |
+
"ndcg_at_k",
|
| 24 |
+
"mean_squared_error",
|
| 25 |
+
"mean_absolute_error",
|
| 26 |
+
"evaluate_model",
|
| 27 |
+
"analyze_errors",
|
| 28 |
+
"identify_error_patterns",
|
| 29 |
+
"generate_error_report",
|
| 30 |
+
"PredictionError",
|
| 31 |
+
"ErrorPattern",
|
| 32 |
+
"ErrorReport",
|
| 33 |
+
]
|
src/brewmatch/evaluation/error_analysis.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Error analysis module for the coffee recommendation system.
|
| 2 |
+
|
| 3 |
+
This module provides tools for analyzing model errors:
|
| 4 |
+
- Finding worst predictions
|
| 5 |
+
- Identifying error patterns by origin, processing method, etc.
|
| 6 |
+
- Generating comprehensive error reports with mitigation strategies
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any, Protocol, runtime_checkable
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@runtime_checkable
|
| 17 |
+
class Recommender(Protocol):
|
| 18 |
+
"""Protocol for recommender models used in error analysis."""
|
| 19 |
+
|
| 20 |
+
def recommend(self, preferences: np.ndarray, k: int = 5) -> list[dict[str, Any]]:
|
| 21 |
+
"""Recommend coffees matching user taste preferences."""
|
| 22 |
+
...
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class PredictionError:
|
| 27 |
+
"""Represents a single misprediction for analysis.
|
| 28 |
+
|
| 29 |
+
Attributes:
|
| 30 |
+
query_idx: Index of the query coffee in the test set.
|
| 31 |
+
query_preferences: The query taste profile used.
|
| 32 |
+
query_metadata: Metadata of the query coffee.
|
| 33 |
+
recommended_idx: Index of the top recommended coffee.
|
| 34 |
+
recommended_metadata: Metadata of the recommended coffee.
|
| 35 |
+
recommended_profile: Taste profile of the recommended coffee.
|
| 36 |
+
expected_indices: Set of indices that would have been correct.
|
| 37 |
+
error_magnitude: Quantified error (e.g., Euclidean distance or rank loss).
|
| 38 |
+
rank_of_first_relevant: Position of first relevant item in recommendations.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
query_idx: int
|
| 42 |
+
query_preferences: np.ndarray
|
| 43 |
+
query_metadata: dict[str, Any]
|
| 44 |
+
recommended_idx: int
|
| 45 |
+
recommended_metadata: dict[str, Any]
|
| 46 |
+
recommended_profile: np.ndarray
|
| 47 |
+
expected_indices: set[int]
|
| 48 |
+
error_magnitude: float
|
| 49 |
+
rank_of_first_relevant: int | None = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class ErrorPattern:
|
| 54 |
+
"""Represents a common error pattern identified in the analysis.
|
| 55 |
+
|
| 56 |
+
Attributes:
|
| 57 |
+
pattern_type: Category of the pattern (e.g., 'origin', 'processing', 'profile').
|
| 58 |
+
description: Human-readable description of the pattern.
|
| 59 |
+
frequency: Number of errors exhibiting this pattern.
|
| 60 |
+
affected_queries: List of query indices affected by this pattern.
|
| 61 |
+
severity: Average error magnitude for this pattern.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
pattern_type: str
|
| 65 |
+
description: str
|
| 66 |
+
frequency: int
|
| 67 |
+
affected_queries: list[int] = field(default_factory=list)
|
| 68 |
+
severity: float = 0.0
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class ErrorReport:
|
| 73 |
+
"""Comprehensive error analysis report.
|
| 74 |
+
|
| 75 |
+
Attributes:
|
| 76 |
+
total_queries: Total number of queries evaluated.
|
| 77 |
+
total_errors: Number of queries where top recommendation was incorrect.
|
| 78 |
+
error_rate: Proportion of queries with errors.
|
| 79 |
+
worst_errors: List of the worst prediction errors.
|
| 80 |
+
patterns: Identified error patterns.
|
| 81 |
+
mitigations: Suggested mitigation strategies.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
total_queries: int
|
| 85 |
+
total_errors: int
|
| 86 |
+
error_rate: float
|
| 87 |
+
worst_errors: list[PredictionError]
|
| 88 |
+
patterns: list[ErrorPattern]
|
| 89 |
+
mitigations: list[str]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _compute_euclidean_distance(a: np.ndarray, b: np.ndarray) -> float:
|
| 93 |
+
"""Compute Euclidean distance between two arrays."""
|
| 94 |
+
return float(np.sqrt(np.sum((a - b) ** 2)))
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _compute_similarity(profile_a: np.ndarray, profile_b: np.ndarray) -> float:
|
| 98 |
+
"""Compute cosine similarity between two taste profiles."""
|
| 99 |
+
norm_a = np.linalg.norm(profile_a)
|
| 100 |
+
norm_b = np.linalg.norm(profile_b)
|
| 101 |
+
if norm_a == 0 or norm_b == 0:
|
| 102 |
+
return 0.0
|
| 103 |
+
return float(np.dot(profile_a, profile_b) / (norm_a * norm_b))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _find_relevant_items(
|
| 107 |
+
query_metadata: dict[str, Any],
|
| 108 |
+
query_profile: np.ndarray,
|
| 109 |
+
all_metadata: list[dict[str, Any]],
|
| 110 |
+
all_profiles: np.ndarray,
|
| 111 |
+
query_idx: int,
|
| 112 |
+
similarity_threshold: float = 0.95,
|
| 113 |
+
) -> set[int]:
|
| 114 |
+
"""Identify relevant items for a query coffee.
|
| 115 |
+
|
| 116 |
+
An item is considered relevant if:
|
| 117 |
+
- It shares the same country AND processing method, OR
|
| 118 |
+
- It has high taste profile similarity (>= threshold)
|
| 119 |
+
"""
|
| 120 |
+
relevant = set()
|
| 121 |
+
query_country = query_metadata.get("Country of Origin", "")
|
| 122 |
+
query_processing = query_metadata.get("Processing Method", "")
|
| 123 |
+
|
| 124 |
+
for i, meta in enumerate(all_metadata):
|
| 125 |
+
if i == query_idx:
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
same_country = meta.get("Country of Origin", "") == query_country
|
| 129 |
+
same_processing = meta.get("Processing Method", "") == query_processing
|
| 130 |
+
|
| 131 |
+
if same_country and same_processing and query_country and query_processing:
|
| 132 |
+
relevant.add(i)
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
similarity = _compute_similarity(query_profile, all_profiles[i])
|
| 136 |
+
if similarity >= similarity_threshold:
|
| 137 |
+
relevant.add(i)
|
| 138 |
+
|
| 139 |
+
return relevant
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def analyze_errors(
|
| 143 |
+
model: Recommender,
|
| 144 |
+
test_data: dict[str, Any],
|
| 145 |
+
n_errors: int = 5,
|
| 146 |
+
) -> list[PredictionError]:
|
| 147 |
+
"""Find the worst predictions made by the model.
|
| 148 |
+
|
| 149 |
+
Analyzes each test coffee as a query and identifies cases where the
|
| 150 |
+
model's top recommendations were most incorrect.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
model: A fitted recommender model with a recommend() method.
|
| 154 |
+
test_data: Dictionary containing:
|
| 155 |
+
- 'X': Feature matrix of shape (n_samples, 9) with taste profiles.
|
| 156 |
+
- 'metadata': List of metadata dicts or DataFrame.
|
| 157 |
+
n_errors: Number of worst errors to return.
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
List of PredictionError objects for the n_errors worst predictions,
|
| 161 |
+
sorted by error magnitude (descending).
|
| 162 |
+
|
| 163 |
+
Example:
|
| 164 |
+
>>> errors = analyze_errors(model, test_data, n_errors=5)
|
| 165 |
+
>>> for err in errors:
|
| 166 |
+
... print(f"Query {err.query_idx}: magnitude={err.error_magnitude:.3f}")
|
| 167 |
+
"""
|
| 168 |
+
X = np.asarray(test_data["X"], dtype=np.float32)
|
| 169 |
+
metadata_raw = test_data["metadata"]
|
| 170 |
+
|
| 171 |
+
if hasattr(metadata_raw, "to_dict"):
|
| 172 |
+
all_metadata = metadata_raw.to_dict("records")
|
| 173 |
+
else:
|
| 174 |
+
all_metadata = list(metadata_raw)
|
| 175 |
+
|
| 176 |
+
n_samples = len(X)
|
| 177 |
+
errors: list[PredictionError] = []
|
| 178 |
+
taste_features = [
|
| 179 |
+
"Aroma", "Flavor", "Aftertaste", "Acidity", "Body",
|
| 180 |
+
"Balance", "Uniformity", "Clean Cup", "Sweetness"
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
for query_idx in range(n_samples):
|
| 184 |
+
query_profile = X[query_idx]
|
| 185 |
+
query_metadata = all_metadata[query_idx]
|
| 186 |
+
|
| 187 |
+
relevant = _find_relevant_items(
|
| 188 |
+
query_metadata=query_metadata,
|
| 189 |
+
query_profile=query_profile,
|
| 190 |
+
all_metadata=all_metadata,
|
| 191 |
+
all_profiles=X,
|
| 192 |
+
query_idx=query_idx,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
if not relevant:
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
recommendations = model.recommend(query_profile, k=max(10, n_errors))
|
| 199 |
+
recommended_indices = [rec["index"] for rec in recommendations]
|
| 200 |
+
|
| 201 |
+
if not recommendations:
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
top_rec = recommendations[0]
|
| 205 |
+
top_idx = top_rec["index"]
|
| 206 |
+
|
| 207 |
+
# Check if top recommendation is relevant
|
| 208 |
+
if top_idx in relevant:
|
| 209 |
+
continue
|
| 210 |
+
|
| 211 |
+
# This is an error - compute error magnitude
|
| 212 |
+
rec_profile = np.array(
|
| 213 |
+
[top_rec["taste_profile"][f] for f in taste_features],
|
| 214 |
+
dtype=np.float32
|
| 215 |
+
)
|
| 216 |
+
error_magnitude = _compute_euclidean_distance(query_profile, rec_profile)
|
| 217 |
+
|
| 218 |
+
# Find rank of first relevant item
|
| 219 |
+
rank_of_first_relevant = None
|
| 220 |
+
for rank, idx in enumerate(recommended_indices):
|
| 221 |
+
if idx in relevant:
|
| 222 |
+
rank_of_first_relevant = rank + 1 # 1-indexed
|
| 223 |
+
break
|
| 224 |
+
|
| 225 |
+
error = PredictionError(
|
| 226 |
+
query_idx=query_idx,
|
| 227 |
+
query_preferences=query_profile.copy(),
|
| 228 |
+
query_metadata=query_metadata,
|
| 229 |
+
recommended_idx=top_idx,
|
| 230 |
+
recommended_metadata=all_metadata[top_idx],
|
| 231 |
+
recommended_profile=rec_profile,
|
| 232 |
+
expected_indices=relevant,
|
| 233 |
+
error_magnitude=error_magnitude,
|
| 234 |
+
rank_of_first_relevant=rank_of_first_relevant,
|
| 235 |
+
)
|
| 236 |
+
errors.append(error)
|
| 237 |
+
|
| 238 |
+
# Sort by error magnitude (descending) and return top n
|
| 239 |
+
errors.sort(key=lambda e: e.error_magnitude, reverse=True)
|
| 240 |
+
return errors[:n_errors]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def identify_error_patterns(errors: list[PredictionError]) -> list[ErrorPattern]:
|
| 244 |
+
"""Analyze a list of errors to identify common failure patterns.
|
| 245 |
+
|
| 246 |
+
Looks for patterns such as:
|
| 247 |
+
- Failures on specific origins (countries)
|
| 248 |
+
- Failures on specific processing methods
|
| 249 |
+
- Failures on certain taste profile characteristics
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
errors: List of PredictionError objects to analyze.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
List of ErrorPattern objects describing common failure modes,
|
| 256 |
+
sorted by frequency (descending).
|
| 257 |
+
|
| 258 |
+
Example:
|
| 259 |
+
>>> errors = analyze_errors(model, test_data, n_errors=20)
|
| 260 |
+
>>> patterns = identify_error_patterns(errors)
|
| 261 |
+
>>> for p in patterns:
|
| 262 |
+
... print(f"{p.pattern_type}: {p.description} ({p.frequency} occurrences)")
|
| 263 |
+
"""
|
| 264 |
+
if not errors:
|
| 265 |
+
return []
|
| 266 |
+
|
| 267 |
+
patterns: list[ErrorPattern] = []
|
| 268 |
+
|
| 269 |
+
# Pattern 1: Failures by query origin
|
| 270 |
+
origin_counter: Counter[str] = Counter()
|
| 271 |
+
origin_errors: dict[str, list[int]] = {}
|
| 272 |
+
origin_severity: dict[str, list[float]] = {}
|
| 273 |
+
|
| 274 |
+
for err in errors:
|
| 275 |
+
origin = err.query_metadata.get("Country of Origin", "Unknown")
|
| 276 |
+
origin_counter[origin] += 1
|
| 277 |
+
origin_errors.setdefault(origin, []).append(err.query_idx)
|
| 278 |
+
origin_severity.setdefault(origin, []).append(err.error_magnitude)
|
| 279 |
+
|
| 280 |
+
for origin, count in origin_counter.most_common():
|
| 281 |
+
if count >= 2: # Only report patterns with multiple occurrences
|
| 282 |
+
avg_severity = np.mean(origin_severity[origin])
|
| 283 |
+
patterns.append(ErrorPattern(
|
| 284 |
+
pattern_type="origin",
|
| 285 |
+
description=f"Model fails frequently on coffees from {origin}",
|
| 286 |
+
frequency=count,
|
| 287 |
+
affected_queries=origin_errors[origin],
|
| 288 |
+
severity=float(avg_severity),
|
| 289 |
+
))
|
| 290 |
+
|
| 291 |
+
# Pattern 2: Failures by query processing method
|
| 292 |
+
processing_counter: Counter[str] = Counter()
|
| 293 |
+
processing_errors: dict[str, list[int]] = {}
|
| 294 |
+
processing_severity: dict[str, list[float]] = {}
|
| 295 |
+
|
| 296 |
+
for err in errors:
|
| 297 |
+
method = err.query_metadata.get("Processing Method", "Unknown")
|
| 298 |
+
processing_counter[method] += 1
|
| 299 |
+
processing_errors.setdefault(method, []).append(err.query_idx)
|
| 300 |
+
processing_severity.setdefault(method, []).append(err.error_magnitude)
|
| 301 |
+
|
| 302 |
+
for method, count in processing_counter.most_common():
|
| 303 |
+
if count >= 2:
|
| 304 |
+
avg_severity = np.mean(processing_severity[method])
|
| 305 |
+
patterns.append(ErrorPattern(
|
| 306 |
+
pattern_type="processing",
|
| 307 |
+
description=f"Model fails frequently on {method} processed coffees",
|
| 308 |
+
frequency=count,
|
| 309 |
+
affected_queries=processing_errors[method],
|
| 310 |
+
severity=float(avg_severity),
|
| 311 |
+
))
|
| 312 |
+
|
| 313 |
+
# Pattern 3: Cross-origin confusion
|
| 314 |
+
confusion_counter: Counter[tuple[str, str]] = Counter()
|
| 315 |
+
confusion_errors: dict[tuple[str, str], list[int]] = {}
|
| 316 |
+
confusion_severity: dict[tuple[str, str], list[float]] = {}
|
| 317 |
+
|
| 318 |
+
for err in errors:
|
| 319 |
+
query_origin = err.query_metadata.get("Country of Origin", "Unknown")
|
| 320 |
+
rec_origin = err.recommended_metadata.get("Country of Origin", "Unknown")
|
| 321 |
+
if query_origin != rec_origin:
|
| 322 |
+
key = (query_origin, rec_origin)
|
| 323 |
+
confusion_counter[key] += 1
|
| 324 |
+
confusion_errors.setdefault(key, []).append(err.query_idx)
|
| 325 |
+
confusion_severity.setdefault(key, []).append(err.error_magnitude)
|
| 326 |
+
|
| 327 |
+
for (q_origin, r_origin), count in confusion_counter.most_common(5):
|
| 328 |
+
if count >= 2:
|
| 329 |
+
avg_severity = np.mean(confusion_severity[(q_origin, r_origin)])
|
| 330 |
+
patterns.append(ErrorPattern(
|
| 331 |
+
pattern_type="cross_origin_confusion",
|
| 332 |
+
description=f"Model confuses {q_origin} with {r_origin}",
|
| 333 |
+
frequency=count,
|
| 334 |
+
affected_queries=confusion_errors[(q_origin, r_origin)],
|
| 335 |
+
severity=float(avg_severity),
|
| 336 |
+
))
|
| 337 |
+
|
| 338 |
+
# Pattern 4: High acidity/low body confusion (taste profile patterns)
|
| 339 |
+
high_acidity_errors = []
|
| 340 |
+
high_acidity_indices = []
|
| 341 |
+
high_acidity_severities = []
|
| 342 |
+
|
| 343 |
+
low_body_errors = []
|
| 344 |
+
low_body_indices = []
|
| 345 |
+
low_body_severities = []
|
| 346 |
+
|
| 347 |
+
for err in errors:
|
| 348 |
+
# Check for high acidity queries (above 7.5 on typical 6-10 scale)
|
| 349 |
+
acidity_idx = 3 # Index of Acidity in taste features
|
| 350 |
+
if err.query_preferences[acidity_idx] > 7.5:
|
| 351 |
+
high_acidity_errors.append(err)
|
| 352 |
+
high_acidity_indices.append(err.query_idx)
|
| 353 |
+
high_acidity_severities.append(err.error_magnitude)
|
| 354 |
+
|
| 355 |
+
# Check for low body queries (below 7.0)
|
| 356 |
+
body_idx = 4 # Index of Body in taste features
|
| 357 |
+
if err.query_preferences[body_idx] < 7.0:
|
| 358 |
+
low_body_errors.append(err)
|
| 359 |
+
low_body_indices.append(err.query_idx)
|
| 360 |
+
low_body_severities.append(err.error_magnitude)
|
| 361 |
+
|
| 362 |
+
if len(high_acidity_errors) >= 2:
|
| 363 |
+
patterns.append(ErrorPattern(
|
| 364 |
+
pattern_type="taste_profile",
|
| 365 |
+
description="Model struggles with high-acidity coffee recommendations",
|
| 366 |
+
frequency=len(high_acidity_errors),
|
| 367 |
+
affected_queries=high_acidity_indices,
|
| 368 |
+
severity=float(np.mean(high_acidity_severities)),
|
| 369 |
+
))
|
| 370 |
+
|
| 371 |
+
if len(low_body_errors) >= 2:
|
| 372 |
+
patterns.append(ErrorPattern(
|
| 373 |
+
pattern_type="taste_profile",
|
| 374 |
+
description="Model struggles with low-body coffee recommendations",
|
| 375 |
+
frequency=len(low_body_errors),
|
| 376 |
+
affected_queries=low_body_indices,
|
| 377 |
+
severity=float(np.mean(low_body_severities)),
|
| 378 |
+
))
|
| 379 |
+
|
| 380 |
+
# Pattern 5: Rank degradation (first relevant item is far down)
|
| 381 |
+
severe_rank_errors = []
|
| 382 |
+
severe_rank_indices = []
|
| 383 |
+
severe_rank_severities = []
|
| 384 |
+
|
| 385 |
+
for err in errors:
|
| 386 |
+
if err.rank_of_first_relevant is not None and err.rank_of_first_relevant > 5:
|
| 387 |
+
severe_rank_errors.append(err)
|
| 388 |
+
severe_rank_indices.append(err.query_idx)
|
| 389 |
+
severe_rank_severities.append(err.error_magnitude)
|
| 390 |
+
|
| 391 |
+
if len(severe_rank_errors) >= 2:
|
| 392 |
+
avg_rank = np.mean([e.rank_of_first_relevant for e in severe_rank_errors
|
| 393 |
+
if e.rank_of_first_relevant is not None])
|
| 394 |
+
patterns.append(ErrorPattern(
|
| 395 |
+
pattern_type="ranking",
|
| 396 |
+
description=f"First relevant item ranked very low (avg rank: {avg_rank:.1f})",
|
| 397 |
+
frequency=len(severe_rank_errors),
|
| 398 |
+
affected_queries=severe_rank_indices,
|
| 399 |
+
severity=float(np.mean(severe_rank_severities)),
|
| 400 |
+
))
|
| 401 |
+
|
| 402 |
+
# Sort patterns by frequency
|
| 403 |
+
patterns.sort(key=lambda p: p.frequency, reverse=True)
|
| 404 |
+
return patterns
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _generate_root_cause(error: PredictionError) -> str:
|
| 408 |
+
"""Generate a root cause analysis for a single error."""
|
| 409 |
+
query_origin = error.query_metadata.get("Country of Origin", "Unknown")
|
| 410 |
+
query_processing = error.query_metadata.get("Processing Method", "Unknown")
|
| 411 |
+
rec_origin = error.recommended_metadata.get("Country of Origin", "Unknown")
|
| 412 |
+
rec_processing = error.recommended_metadata.get("Processing Method", "Unknown")
|
| 413 |
+
|
| 414 |
+
causes = []
|
| 415 |
+
|
| 416 |
+
# Check for origin mismatch
|
| 417 |
+
if query_origin != rec_origin:
|
| 418 |
+
causes.append(
|
| 419 |
+
f"Origin mismatch: queried {query_origin}, recommended {rec_origin}"
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Check for processing mismatch
|
| 423 |
+
if query_processing != rec_processing:
|
| 424 |
+
causes.append(
|
| 425 |
+
f"Processing mismatch: queried {query_processing}, recommended {rec_processing}"
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
# Check for taste profile deviation
|
| 429 |
+
taste_features = [
|
| 430 |
+
"Aroma", "Flavor", "Aftertaste", "Acidity", "Body",
|
| 431 |
+
"Balance", "Uniformity", "Clean Cup", "Sweetness"
|
| 432 |
+
]
|
| 433 |
+
large_deviations = []
|
| 434 |
+
for i, feature in enumerate(taste_features):
|
| 435 |
+
diff = abs(error.query_preferences[i] - error.recommended_profile[i])
|
| 436 |
+
if diff > 0.5: # Significant deviation
|
| 437 |
+
large_deviations.append(f"{feature} (diff: {diff:.2f})")
|
| 438 |
+
|
| 439 |
+
if large_deviations:
|
| 440 |
+
causes.append(f"Large taste deviations: {', '.join(large_deviations[:3])}")
|
| 441 |
+
|
| 442 |
+
if not causes:
|
| 443 |
+
causes.append("Minor deviations across multiple dimensions")
|
| 444 |
+
|
| 445 |
+
return "; ".join(causes)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _generate_mitigations(patterns: list[ErrorPattern]) -> list[str]:
|
| 449 |
+
"""Generate mitigation strategies based on identified patterns."""
|
| 450 |
+
mitigations = []
|
| 451 |
+
|
| 452 |
+
pattern_types = {p.pattern_type for p in patterns}
|
| 453 |
+
|
| 454 |
+
if "origin" in pattern_types:
|
| 455 |
+
mitigations.append(
|
| 456 |
+
"Consider adding origin-aware features or embeddings to better "
|
| 457 |
+
"capture regional flavor characteristics."
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
if "processing" in pattern_types:
|
| 461 |
+
mitigations.append(
|
| 462 |
+
"Include processing method as an explicit feature or learn "
|
| 463 |
+
"processing-specific taste profile transformations."
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if "cross_origin_confusion" in pattern_types:
|
| 467 |
+
mitigations.append(
|
| 468 |
+
"Add contrastive learning or negative sampling to better distinguish "
|
| 469 |
+
"coffees from commonly confused origins."
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
if "taste_profile" in pattern_types:
|
| 473 |
+
mitigations.append(
|
| 474 |
+
"Review feature scaling and consider non-linear transformations "
|
| 475 |
+
"for extreme taste profile values (high acidity, low body)."
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
if "ranking" in pattern_types:
|
| 479 |
+
mitigations.append(
|
| 480 |
+
"Incorporate a re-ranking stage or listwise learning objective "
|
| 481 |
+
"to improve early-rank precision."
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
# General mitigations based on error severity
|
| 485 |
+
if patterns and max(p.severity for p in patterns) > 2.0:
|
| 486 |
+
mitigations.append(
|
| 487 |
+
"High severity errors suggest the model may benefit from ensemble "
|
| 488 |
+
"methods or calibration techniques."
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if not mitigations:
|
| 492 |
+
mitigations.append(
|
| 493 |
+
"No strong error patterns detected. Consider increasing training "
|
| 494 |
+
"data or fine-tuning hyperparameters."
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
return mitigations
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def generate_error_report(
|
| 501 |
+
model: Recommender,
|
| 502 |
+
test_data: dict[str, Any],
|
| 503 |
+
) -> ErrorReport:
|
| 504 |
+
"""Generate a comprehensive error analysis report.
|
| 505 |
+
|
| 506 |
+
Analyzes the model's predictions on test data to identify:
|
| 507 |
+
- The 5 worst mispredictions with root cause analysis
|
| 508 |
+
- Common failure patterns across errors
|
| 509 |
+
- Proposed mitigation strategies
|
| 510 |
+
|
| 511 |
+
Args:
|
| 512 |
+
model: A fitted recommender model with a recommend() method.
|
| 513 |
+
test_data: Dictionary containing:
|
| 514 |
+
- 'X': Feature matrix of shape (n_samples, 9) with taste profiles.
|
| 515 |
+
- 'metadata': List of metadata dicts or DataFrame.
|
| 516 |
+
|
| 517 |
+
Returns:
|
| 518 |
+
ErrorReport containing detailed analysis and recommendations.
|
| 519 |
+
|
| 520 |
+
Example:
|
| 521 |
+
>>> report = generate_error_report(model, test_data)
|
| 522 |
+
>>> print(f"Error rate: {report.error_rate:.1%}")
|
| 523 |
+
>>> for pattern in report.patterns:
|
| 524 |
+
... print(f"- {pattern.description}")
|
| 525 |
+
>>> for mitigation in report.mitigations:
|
| 526 |
+
... print(f"* {mitigation}")
|
| 527 |
+
"""
|
| 528 |
+
X = np.asarray(test_data["X"], dtype=np.float32)
|
| 529 |
+
metadata_raw = test_data["metadata"]
|
| 530 |
+
|
| 531 |
+
if hasattr(metadata_raw, "to_dict"):
|
| 532 |
+
all_metadata = metadata_raw.to_dict("records")
|
| 533 |
+
else:
|
| 534 |
+
all_metadata = list(metadata_raw)
|
| 535 |
+
|
| 536 |
+
n_samples = len(X)
|
| 537 |
+
|
| 538 |
+
# Get worst errors (more than 5 for pattern analysis)
|
| 539 |
+
all_errors = analyze_errors(model, test_data, n_errors=50)
|
| 540 |
+
worst_5_errors = all_errors[:5]
|
| 541 |
+
|
| 542 |
+
# Count total errors (queries where top recommendation is not relevant)
|
| 543 |
+
total_errors = 0
|
| 544 |
+
total_valid_queries = 0
|
| 545 |
+
|
| 546 |
+
for query_idx in range(n_samples):
|
| 547 |
+
query_profile = X[query_idx]
|
| 548 |
+
query_metadata = all_metadata[query_idx]
|
| 549 |
+
|
| 550 |
+
relevant = _find_relevant_items(
|
| 551 |
+
query_metadata=query_metadata,
|
| 552 |
+
query_profile=query_profile,
|
| 553 |
+
all_metadata=all_metadata,
|
| 554 |
+
all_profiles=X,
|
| 555 |
+
query_idx=query_idx,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
if not relevant:
|
| 559 |
+
continue
|
| 560 |
+
|
| 561 |
+
total_valid_queries += 1
|
| 562 |
+
|
| 563 |
+
recommendations = model.recommend(query_profile, k=1)
|
| 564 |
+
if recommendations and recommendations[0]["index"] not in relevant:
|
| 565 |
+
total_errors += 1
|
| 566 |
+
|
| 567 |
+
# Identify patterns from all errors
|
| 568 |
+
patterns = identify_error_patterns(all_errors)
|
| 569 |
+
|
| 570 |
+
# Generate mitigations
|
| 571 |
+
mitigations = _generate_mitigations(patterns)
|
| 572 |
+
|
| 573 |
+
# Add root cause to worst errors (stored in metadata for reporting)
|
| 574 |
+
for err in worst_5_errors:
|
| 575 |
+
err.query_metadata["_root_cause"] = _generate_root_cause(err)
|
| 576 |
+
|
| 577 |
+
error_rate = total_errors / total_valid_queries if total_valid_queries > 0 else 0.0
|
| 578 |
+
|
| 579 |
+
return ErrorReport(
|
| 580 |
+
total_queries=total_valid_queries,
|
| 581 |
+
total_errors=total_errors,
|
| 582 |
+
error_rate=error_rate,
|
| 583 |
+
worst_errors=worst_5_errors,
|
| 584 |
+
patterns=patterns,
|
| 585 |
+
mitigations=mitigations,
|
| 586 |
+
)
|
src/brewmatch/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation metrics for the coffee recommendation system.
|
| 2 |
+
|
| 3 |
+
This module provides metrics for evaluating recommendation quality:
|
| 4 |
+
- Ranking metrics: Precision@K, Recall@K, NDCG@K
|
| 5 |
+
- Regression metrics: MSE, MAE for quality prediction
|
| 6 |
+
- Comprehensive evaluation combining all metrics
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Protocol, runtime_checkable
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from numpy.typing import ArrayLike
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@runtime_checkable
|
| 16 |
+
class Recommender(Protocol):
|
| 17 |
+
"""Protocol for recommender models used in evaluation."""
|
| 18 |
+
|
| 19 |
+
def recommend(self, preferences: np.ndarray, k: int = 5) -> list[dict[str, Any]]:
|
| 20 |
+
"""Recommend coffees matching user taste preferences."""
|
| 21 |
+
...
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def precision_at_k(
|
| 25 |
+
recommended: list[int],
|
| 26 |
+
relevant: set[int],
|
| 27 |
+
k: int,
|
| 28 |
+
) -> float:
|
| 29 |
+
"""Calculate Precision@K for a recommendation list.
|
| 30 |
+
|
| 31 |
+
Precision@K measures the proportion of recommended items in the top-K
|
| 32 |
+
that are relevant.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
recommended: List of recommended item indices, ordered by rank.
|
| 36 |
+
relevant: Set of relevant item indices.
|
| 37 |
+
k: Number of top recommendations to consider.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Precision@K score in range [0, 1].
|
| 41 |
+
|
| 42 |
+
Example:
|
| 43 |
+
>>> recommended = [1, 2, 3, 4, 5]
|
| 44 |
+
>>> relevant = {1, 3, 5, 7, 9}
|
| 45 |
+
>>> precision_at_k(recommended, relevant, k=5)
|
| 46 |
+
0.6
|
| 47 |
+
"""
|
| 48 |
+
if k <= 0:
|
| 49 |
+
raise ValueError(f"k must be positive, got {k}")
|
| 50 |
+
if not recommended:
|
| 51 |
+
return 0.0
|
| 52 |
+
|
| 53 |
+
top_k = recommended[:k]
|
| 54 |
+
hits = sum(1 for item in top_k if item in relevant)
|
| 55 |
+
return hits / k
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def recall_at_k(
|
| 59 |
+
recommended: list[int],
|
| 60 |
+
relevant: set[int],
|
| 61 |
+
k: int,
|
| 62 |
+
) -> float:
|
| 63 |
+
"""Calculate Recall@K for a recommendation list.
|
| 64 |
+
|
| 65 |
+
Recall@K measures the proportion of relevant items that appear
|
| 66 |
+
in the top-K recommendations.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
recommended: List of recommended item indices, ordered by rank.
|
| 70 |
+
relevant: Set of relevant item indices.
|
| 71 |
+
k: Number of top recommendations to consider.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Recall@K score in range [0, 1]. Returns 0.0 if there are no relevant items.
|
| 75 |
+
|
| 76 |
+
Example:
|
| 77 |
+
>>> recommended = [1, 2, 3, 4, 5]
|
| 78 |
+
>>> relevant = {1, 3, 5, 7, 9}
|
| 79 |
+
>>> recall_at_k(recommended, relevant, k=5)
|
| 80 |
+
0.6
|
| 81 |
+
"""
|
| 82 |
+
if k <= 0:
|
| 83 |
+
raise ValueError(f"k must be positive, got {k}")
|
| 84 |
+
if not relevant:
|
| 85 |
+
return 0.0
|
| 86 |
+
if not recommended:
|
| 87 |
+
return 0.0
|
| 88 |
+
|
| 89 |
+
top_k = recommended[:k]
|
| 90 |
+
hits = sum(1 for item in top_k if item in relevant)
|
| 91 |
+
return hits / len(relevant)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def ndcg_at_k(
|
| 95 |
+
recommended: list[int],
|
| 96 |
+
relevant: set[int],
|
| 97 |
+
k: int,
|
| 98 |
+
) -> float:
|
| 99 |
+
"""Calculate Normalized Discounted Cumulative Gain at K.
|
| 100 |
+
|
| 101 |
+
NDCG@K measures ranking quality by giving higher scores when
|
| 102 |
+
relevant items appear earlier in the recommendation list.
|
| 103 |
+
|
| 104 |
+
Uses binary relevance (1 if relevant, 0 otherwise).
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
recommended: List of recommended item indices, ordered by rank.
|
| 108 |
+
relevant: Set of relevant item indices.
|
| 109 |
+
k: Number of top recommendations to consider.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
NDCG@K score in range [0, 1]. Returns 0.0 if there are no relevant items.
|
| 113 |
+
|
| 114 |
+
Example:
|
| 115 |
+
>>> recommended = [1, 2, 3, 4, 5]
|
| 116 |
+
>>> relevant = {1, 3, 5}
|
| 117 |
+
>>> ndcg_at_k(recommended, relevant, k=5) # Higher because relevant items are ranked well
|
| 118 |
+
0.934...
|
| 119 |
+
"""
|
| 120 |
+
if k <= 0:
|
| 121 |
+
raise ValueError(f"k must be positive, got {k}")
|
| 122 |
+
if not relevant:
|
| 123 |
+
return 0.0
|
| 124 |
+
if not recommended:
|
| 125 |
+
return 0.0
|
| 126 |
+
|
| 127 |
+
top_k = recommended[:k]
|
| 128 |
+
|
| 129 |
+
# DCG: sum of relevance / log2(position + 1)
|
| 130 |
+
dcg = 0.0
|
| 131 |
+
for i, item in enumerate(top_k):
|
| 132 |
+
if item in relevant:
|
| 133 |
+
# Position is 1-indexed for the log
|
| 134 |
+
dcg += 1.0 / np.log2(i + 2)
|
| 135 |
+
|
| 136 |
+
# Ideal DCG: all relevant items ranked first
|
| 137 |
+
ideal_k = min(k, len(relevant))
|
| 138 |
+
idcg = sum(1.0 / np.log2(i + 2) for i in range(ideal_k))
|
| 139 |
+
|
| 140 |
+
if idcg == 0:
|
| 141 |
+
return 0.0
|
| 142 |
+
|
| 143 |
+
return dcg / idcg
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def mean_squared_error(
|
| 147 |
+
predicted: ArrayLike,
|
| 148 |
+
actual: ArrayLike,
|
| 149 |
+
) -> float:
|
| 150 |
+
"""Calculate Mean Squared Error between predictions and actual values.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
predicted: Predicted values.
|
| 154 |
+
actual: Actual/ground truth values.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Mean squared error (non-negative).
|
| 158 |
+
|
| 159 |
+
Raises:
|
| 160 |
+
ValueError: If arrays have different lengths.
|
| 161 |
+
|
| 162 |
+
Example:
|
| 163 |
+
>>> predicted = [3.0, 4.0, 5.0]
|
| 164 |
+
>>> actual = [3.5, 4.0, 4.5]
|
| 165 |
+
>>> mean_squared_error(predicted, actual)
|
| 166 |
+
0.166...
|
| 167 |
+
"""
|
| 168 |
+
predicted = np.asarray(predicted, dtype=np.float64)
|
| 169 |
+
actual = np.asarray(actual, dtype=np.float64)
|
| 170 |
+
|
| 171 |
+
if predicted.shape != actual.shape:
|
| 172 |
+
raise ValueError(
|
| 173 |
+
f"Shape mismatch: predicted {predicted.shape} vs actual {actual.shape}"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return float(np.mean((predicted - actual) ** 2))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def mean_absolute_error(
|
| 180 |
+
predicted: ArrayLike,
|
| 181 |
+
actual: ArrayLike,
|
| 182 |
+
) -> float:
|
| 183 |
+
"""Calculate Mean Absolute Error between predictions and actual values.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
predicted: Predicted values.
|
| 187 |
+
actual: Actual/ground truth values.
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Mean absolute error (non-negative).
|
| 191 |
+
|
| 192 |
+
Raises:
|
| 193 |
+
ValueError: If arrays have different lengths.
|
| 194 |
+
|
| 195 |
+
Example:
|
| 196 |
+
>>> predicted = [3.0, 4.0, 5.0]
|
| 197 |
+
>>> actual = [3.5, 4.0, 4.5]
|
| 198 |
+
>>> mean_absolute_error(predicted, actual)
|
| 199 |
+
0.333...
|
| 200 |
+
"""
|
| 201 |
+
predicted = np.asarray(predicted, dtype=np.float64)
|
| 202 |
+
actual = np.asarray(actual, dtype=np.float64)
|
| 203 |
+
|
| 204 |
+
if predicted.shape != actual.shape:
|
| 205 |
+
raise ValueError(
|
| 206 |
+
f"Shape mismatch: predicted {predicted.shape} vs actual {actual.shape}"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
return float(np.mean(np.abs(predicted - actual)))
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _compute_similarity(profile_a: np.ndarray, profile_b: np.ndarray) -> float:
|
| 213 |
+
"""Compute cosine similarity between two taste profiles."""
|
| 214 |
+
norm_a = np.linalg.norm(profile_a)
|
| 215 |
+
norm_b = np.linalg.norm(profile_b)
|
| 216 |
+
if norm_a == 0 or norm_b == 0:
|
| 217 |
+
return 0.0
|
| 218 |
+
return float(np.dot(profile_a, profile_b) / (norm_a * norm_b))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _find_relevant_items(
|
| 222 |
+
query_metadata: dict[str, Any],
|
| 223 |
+
query_profile: np.ndarray,
|
| 224 |
+
all_metadata: list[dict[str, Any]],
|
| 225 |
+
all_profiles: np.ndarray,
|
| 226 |
+
query_idx: int,
|
| 227 |
+
similarity_threshold: float = 0.95,
|
| 228 |
+
) -> set[int]:
|
| 229 |
+
"""Identify relevant items for a query coffee.
|
| 230 |
+
|
| 231 |
+
An item is considered relevant if:
|
| 232 |
+
- It shares the same country AND processing method, OR
|
| 233 |
+
- It has high taste profile similarity (>= threshold)
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
query_metadata: Metadata dict for the query coffee.
|
| 237 |
+
query_profile: Taste profile array for the query coffee.
|
| 238 |
+
all_metadata: List of metadata dicts for all coffees.
|
| 239 |
+
all_profiles: Array of all taste profiles, shape (n_samples, n_features).
|
| 240 |
+
query_idx: Index of the query item (excluded from relevant set).
|
| 241 |
+
similarity_threshold: Cosine similarity threshold for profile-based relevance.
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
Set of indices of relevant items.
|
| 245 |
+
"""
|
| 246 |
+
relevant = set()
|
| 247 |
+
query_country = query_metadata.get("Country of Origin", "")
|
| 248 |
+
query_processing = query_metadata.get("Processing Method", "")
|
| 249 |
+
|
| 250 |
+
for i, meta in enumerate(all_metadata):
|
| 251 |
+
if i == query_idx:
|
| 252 |
+
continue
|
| 253 |
+
|
| 254 |
+
# Check metadata-based relevance
|
| 255 |
+
same_country = meta.get("Country of Origin", "") == query_country
|
| 256 |
+
same_processing = meta.get("Processing Method", "") == query_processing
|
| 257 |
+
|
| 258 |
+
if same_country and same_processing and query_country and query_processing:
|
| 259 |
+
relevant.add(i)
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
# Check similarity-based relevance
|
| 263 |
+
similarity = _compute_similarity(query_profile, all_profiles[i])
|
| 264 |
+
if similarity >= similarity_threshold:
|
| 265 |
+
relevant.add(i)
|
| 266 |
+
|
| 267 |
+
return relevant
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def evaluate_model(
|
| 271 |
+
model: Recommender,
|
| 272 |
+
test_data: dict[str, Any],
|
| 273 |
+
k_values: list[int] | None = None,
|
| 274 |
+
) -> dict[str, Any]:
|
| 275 |
+
"""Comprehensive evaluation of a recommendation model.
|
| 276 |
+
|
| 277 |
+
Evaluates the model using each test coffee as a query, measuring how well
|
| 278 |
+
the model recommends similar coffees (same country/processing or high similarity).
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
model: A fitted recommender model with a recommend() method.
|
| 282 |
+
test_data: Dictionary containing:
|
| 283 |
+
- 'X': Feature matrix of shape (n_samples, 9) with taste profiles.
|
| 284 |
+
- 'metadata': List of metadata dicts or DataFrame.
|
| 285 |
+
k_values: List of K values for ranking metrics. Defaults to [1, 3, 5, 10].
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
Dictionary containing:
|
| 289 |
+
- 'precision@k': Dict mapping k to average Precision@K.
|
| 290 |
+
- 'recall@k': Dict mapping k to average Recall@K.
|
| 291 |
+
- 'ndcg@k': Dict mapping k to average NDCG@K.
|
| 292 |
+
- 'mse': Mean squared error of predicted vs actual taste profiles.
|
| 293 |
+
- 'mae': Mean absolute error of predicted vs actual taste profiles.
|
| 294 |
+
- 'n_queries': Number of test queries evaluated.
|
| 295 |
+
- 'avg_relevant_items': Average number of relevant items per query.
|
| 296 |
+
|
| 297 |
+
Example:
|
| 298 |
+
>>> results = evaluate_model(model, test_data, k_values=[1, 5, 10])
|
| 299 |
+
>>> print(f"Precision@5: {results['precision@k'][5]:.3f}")
|
| 300 |
+
>>> print(f"NDCG@10: {results['ndcg@k'][10]:.3f}")
|
| 301 |
+
"""
|
| 302 |
+
if k_values is None:
|
| 303 |
+
k_values = [1, 3, 5, 10]
|
| 304 |
+
|
| 305 |
+
X = np.asarray(test_data["X"], dtype=np.float32)
|
| 306 |
+
metadata_raw = test_data["metadata"]
|
| 307 |
+
|
| 308 |
+
# Convert metadata to list of dicts if it's a DataFrame
|
| 309 |
+
if hasattr(metadata_raw, "to_dict"):
|
| 310 |
+
all_metadata = metadata_raw.to_dict("records")
|
| 311 |
+
else:
|
| 312 |
+
all_metadata = list(metadata_raw)
|
| 313 |
+
|
| 314 |
+
n_samples = len(X)
|
| 315 |
+
max_k = max(k_values)
|
| 316 |
+
|
| 317 |
+
# Initialize accumulators
|
| 318 |
+
precision_sums = {k: 0.0 for k in k_values}
|
| 319 |
+
recall_sums = {k: 0.0 for k in k_values}
|
| 320 |
+
ndcg_sums = {k: 0.0 for k in k_values}
|
| 321 |
+
total_relevant = 0
|
| 322 |
+
valid_queries = 0
|
| 323 |
+
|
| 324 |
+
# For MSE/MAE: collect predicted vs actual taste profiles
|
| 325 |
+
all_predicted_profiles = []
|
| 326 |
+
all_actual_profiles = []
|
| 327 |
+
|
| 328 |
+
for query_idx in range(n_samples):
|
| 329 |
+
query_profile = X[query_idx]
|
| 330 |
+
query_metadata = all_metadata[query_idx]
|
| 331 |
+
|
| 332 |
+
# Find relevant items for this query
|
| 333 |
+
relevant = _find_relevant_items(
|
| 334 |
+
query_metadata=query_metadata,
|
| 335 |
+
query_profile=query_profile,
|
| 336 |
+
all_metadata=all_metadata,
|
| 337 |
+
all_profiles=X,
|
| 338 |
+
query_idx=query_idx,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Skip queries with no relevant items
|
| 342 |
+
if not relevant:
|
| 343 |
+
continue
|
| 344 |
+
|
| 345 |
+
valid_queries += 1
|
| 346 |
+
total_relevant += len(relevant)
|
| 347 |
+
|
| 348 |
+
# Get recommendations
|
| 349 |
+
recommendations = model.recommend(query_profile, k=max_k)
|
| 350 |
+
recommended_indices = [rec["index"] for rec in recommendations]
|
| 351 |
+
|
| 352 |
+
# Calculate ranking metrics for each k
|
| 353 |
+
for k in k_values:
|
| 354 |
+
precision_sums[k] += precision_at_k(recommended_indices, relevant, k)
|
| 355 |
+
recall_sums[k] += recall_at_k(recommended_indices, relevant, k)
|
| 356 |
+
ndcg_sums[k] += ndcg_at_k(recommended_indices, relevant, k)
|
| 357 |
+
|
| 358 |
+
# For MSE/MAE: compare top recommendation's profile to query profile
|
| 359 |
+
if recommendations:
|
| 360 |
+
top_rec = recommendations[0]
|
| 361 |
+
predicted_profile = np.array(
|
| 362 |
+
[top_rec["taste_profile"][f] for f in [
|
| 363 |
+
"Aroma", "Flavor", "Aftertaste", "Acidity", "Body",
|
| 364 |
+
"Balance", "Uniformity", "Clean Cup", "Sweetness"
|
| 365 |
+
]],
|
| 366 |
+
dtype=np.float32
|
| 367 |
+
)
|
| 368 |
+
all_predicted_profiles.append(predicted_profile)
|
| 369 |
+
all_actual_profiles.append(query_profile)
|
| 370 |
+
|
| 371 |
+
# Compute averages
|
| 372 |
+
results: dict[str, Any] = {
|
| 373 |
+
"precision@k": {},
|
| 374 |
+
"recall@k": {},
|
| 375 |
+
"ndcg@k": {},
|
| 376 |
+
"n_queries": valid_queries,
|
| 377 |
+
"avg_relevant_items": total_relevant / valid_queries if valid_queries > 0 else 0.0,
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
if valid_queries > 0:
|
| 381 |
+
for k in k_values:
|
| 382 |
+
results["precision@k"][k] = precision_sums[k] / valid_queries
|
| 383 |
+
results["recall@k"][k] = recall_sums[k] / valid_queries
|
| 384 |
+
results["ndcg@k"][k] = ndcg_sums[k] / valid_queries
|
| 385 |
+
else:
|
| 386 |
+
for k in k_values:
|
| 387 |
+
results["precision@k"][k] = 0.0
|
| 388 |
+
results["recall@k"][k] = 0.0
|
| 389 |
+
results["ndcg@k"][k] = 0.0
|
| 390 |
+
|
| 391 |
+
# Compute MSE and MAE
|
| 392 |
+
if all_predicted_profiles:
|
| 393 |
+
predicted = np.array(all_predicted_profiles)
|
| 394 |
+
actual = np.array(all_actual_profiles)
|
| 395 |
+
results["mse"] = mean_squared_error(predicted.flatten(), actual.flatten())
|
| 396 |
+
results["mae"] = mean_absolute_error(predicted.flatten(), actual.flatten())
|
| 397 |
+
else:
|
| 398 |
+
results["mse"] = float("nan")
|
| 399 |
+
results["mae"] = float("nan")
|
| 400 |
+
|
| 401 |
+
return results
|
src/brewmatch/experiment.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Focused Experiment: Training Set Size Sensitivity Analysis
|
| 3 |
+
|
| 4 |
+
This experiment investigates how model performance varies with training set size.
|
| 5 |
+
We train all three models (baseline, classical, neural) on progressively larger
|
| 6 |
+
subsets of the training data and measure their performance on a held-out test set.
|
| 7 |
+
|
| 8 |
+
Hypothesis: Deep learning model will show greater improvement with more data,
|
| 9 |
+
while classical models may plateau earlier.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import seaborn as sns
|
| 21 |
+
import torch
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from brewmatch.config import (
|
| 25 |
+
K_VALUES,
|
| 26 |
+
NEURAL_CONFIG,
|
| 27 |
+
PROJECT_ROOT,
|
| 28 |
+
RANDOM_SEED,
|
| 29 |
+
TASTE_FEATURES,
|
| 30 |
+
)
|
| 31 |
+
from brewmatch.data import load_processed_data
|
| 32 |
+
from brewmatch.models import (
|
| 33 |
+
NaiveBaselineRecommender,
|
| 34 |
+
ClassicalMLRecommender,
|
| 35 |
+
NeuralRecommender,
|
| 36 |
+
)
|
| 37 |
+
from brewmatch.evaluation import evaluate_model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Experiment configuration
|
| 41 |
+
TRAIN_FRACTIONS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
| 42 |
+
N_TRIALS = 3 # Number of trials per fraction for variance estimation
|
| 43 |
+
RESULTS_DIR = PROJECT_ROOT / "experiments"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def subsample_data(
|
| 47 |
+
df: pd.DataFrame,
|
| 48 |
+
fraction: float,
|
| 49 |
+
seed: int,
|
| 50 |
+
) -> pd.DataFrame:
|
| 51 |
+
"""Subsample training data to a given fraction."""
|
| 52 |
+
np.random.seed(seed)
|
| 53 |
+
n_samples = int(len(df) * fraction)
|
| 54 |
+
indices = np.random.choice(len(df), n_samples, replace=False)
|
| 55 |
+
return df.iloc[indices].reset_index(drop=True)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def train_and_evaluate_baseline(
|
| 59 |
+
train_df: pd.DataFrame,
|
| 60 |
+
test_df: pd.DataFrame,
|
| 61 |
+
) -> dict[str, Any]:
|
| 62 |
+
"""Train and evaluate baseline model."""
|
| 63 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 64 |
+
model = NaiveBaselineRecommender(strategy="mean")
|
| 65 |
+
model.fit(X_train, train_df)
|
| 66 |
+
|
| 67 |
+
test_data = {
|
| 68 |
+
"X": test_df[TASTE_FEATURES].values,
|
| 69 |
+
"metadata": test_df,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
return evaluate_model(
|
| 73 |
+
model=model,
|
| 74 |
+
test_data=test_data,
|
| 75 |
+
k_values=K_VALUES,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def train_and_evaluate_classical(
|
| 80 |
+
train_df: pd.DataFrame,
|
| 81 |
+
test_df: pd.DataFrame,
|
| 82 |
+
) -> dict[str, Any]:
|
| 83 |
+
"""Train and evaluate classical ML model."""
|
| 84 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 85 |
+
model = ClassicalMLRecommender(method="knn", n_neighbors=50, normalize=True)
|
| 86 |
+
model.fit(X_train, train_df)
|
| 87 |
+
|
| 88 |
+
test_data = {
|
| 89 |
+
"X": test_df[TASTE_FEATURES].values,
|
| 90 |
+
"metadata": test_df,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return evaluate_model(
|
| 94 |
+
model=model,
|
| 95 |
+
test_data=test_data,
|
| 96 |
+
k_values=K_VALUES,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def train_and_evaluate_neural(
|
| 101 |
+
train_df: pd.DataFrame,
|
| 102 |
+
test_df: pd.DataFrame,
|
| 103 |
+
device: str,
|
| 104 |
+
) -> dict[str, Any]:
|
| 105 |
+
"""Train and evaluate neural network model."""
|
| 106 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 107 |
+
|
| 108 |
+
model = NeuralRecommender(
|
| 109 |
+
embedding_dim=NEURAL_CONFIG["embedding_dim"],
|
| 110 |
+
hidden_dim=NEURAL_CONFIG.get("hidden_dim", 64),
|
| 111 |
+
learning_rate=NEURAL_CONFIG["learning_rate"],
|
| 112 |
+
margin=NEURAL_CONFIG["margin"],
|
| 113 |
+
device=device,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Use reduced epochs for experiment speed
|
| 117 |
+
model.fit(
|
| 118 |
+
X=X_train,
|
| 119 |
+
metadata=train_df,
|
| 120 |
+
epochs=30, # Reduced for speed
|
| 121 |
+
batch_size=NEURAL_CONFIG["batch_size"],
|
| 122 |
+
verbose=False,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
test_data = {
|
| 126 |
+
"X": test_df[TASTE_FEATURES].values,
|
| 127 |
+
"metadata": test_df,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
return evaluate_model(
|
| 131 |
+
model=model,
|
| 132 |
+
test_data=test_data,
|
| 133 |
+
k_values=K_VALUES,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def run_experiment(
|
| 138 |
+
train_df: pd.DataFrame,
|
| 139 |
+
test_df: pd.DataFrame,
|
| 140 |
+
device: str,
|
| 141 |
+
fractions: list[float] = TRAIN_FRACTIONS,
|
| 142 |
+
n_trials: int = N_TRIALS,
|
| 143 |
+
) -> dict[str, dict[str, list[dict[str, Any]]]]:
|
| 144 |
+
"""
|
| 145 |
+
Run the full sensitivity analysis experiment.
|
| 146 |
+
|
| 147 |
+
Returns nested dict: {model_name: {fraction: [trial_results]}}
|
| 148 |
+
"""
|
| 149 |
+
results = {
|
| 150 |
+
"baseline": {str(f): [] for f in fractions},
|
| 151 |
+
"classical": {str(f): [] for f in fractions},
|
| 152 |
+
"neural": {str(f): [] for f in fractions},
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
total_runs = len(fractions) * n_trials * 3
|
| 156 |
+
pbar = tqdm(total=total_runs, desc="Running experiment")
|
| 157 |
+
|
| 158 |
+
for fraction in fractions:
|
| 159 |
+
for trial in range(n_trials):
|
| 160 |
+
seed = RANDOM_SEED + trial
|
| 161 |
+
|
| 162 |
+
# Subsample training data
|
| 163 |
+
sub_train_df = subsample_data(train_df, fraction, seed)
|
| 164 |
+
|
| 165 |
+
# Baseline
|
| 166 |
+
try:
|
| 167 |
+
baseline_metrics = train_and_evaluate_baseline(sub_train_df, test_df)
|
| 168 |
+
results["baseline"][str(fraction)].append(baseline_metrics)
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(f"Baseline failed at fraction {fraction}, trial {trial}: {e}")
|
| 171 |
+
pbar.update(1)
|
| 172 |
+
|
| 173 |
+
# Classical
|
| 174 |
+
try:
|
| 175 |
+
classical_metrics = train_and_evaluate_classical(sub_train_df, test_df)
|
| 176 |
+
results["classical"][str(fraction)].append(classical_metrics)
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f"Classical failed at fraction {fraction}, trial {trial}: {e}")
|
| 179 |
+
pbar.update(1)
|
| 180 |
+
|
| 181 |
+
# Neural
|
| 182 |
+
try:
|
| 183 |
+
neural_metrics = train_and_evaluate_neural(sub_train_df, test_df, device)
|
| 184 |
+
results["neural"][str(fraction)].append(neural_metrics)
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f"Neural failed at fraction {fraction}, trial {trial}: {e}")
|
| 187 |
+
pbar.update(1)
|
| 188 |
+
|
| 189 |
+
pbar.close()
|
| 190 |
+
return results
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def aggregate_results(
|
| 194 |
+
results: dict[str, dict[str, list[dict[str, Any]]]]
|
| 195 |
+
) -> pd.DataFrame:
|
| 196 |
+
"""Aggregate results into a DataFrame with mean and std."""
|
| 197 |
+
rows = []
|
| 198 |
+
|
| 199 |
+
for model_name, fraction_results in results.items():
|
| 200 |
+
for fraction, trials in fraction_results.items():
|
| 201 |
+
if not trials:
|
| 202 |
+
continue
|
| 203 |
+
|
| 204 |
+
# Flatten nested dicts and aggregate across trials
|
| 205 |
+
flat_metrics: dict[str, list[float]] = {}
|
| 206 |
+
for trial in trials:
|
| 207 |
+
for key, value in trial.items():
|
| 208 |
+
if isinstance(value, dict):
|
| 209 |
+
# Handle nested metrics like precision@k
|
| 210 |
+
for k, v in value.items():
|
| 211 |
+
metric_name = f"{key.replace('@k', '')}@{k}"
|
| 212 |
+
flat_metrics.setdefault(metric_name, []).append(v)
|
| 213 |
+
elif isinstance(value, (int, float)) and not isinstance(value, bool):
|
| 214 |
+
flat_metrics.setdefault(key, []).append(value)
|
| 215 |
+
|
| 216 |
+
# Compute mean and std
|
| 217 |
+
aggregated = {}
|
| 218 |
+
for metric, values in flat_metrics.items():
|
| 219 |
+
aggregated[f"{metric}_mean"] = np.mean(values)
|
| 220 |
+
aggregated[f"{metric}_std"] = np.std(values)
|
| 221 |
+
|
| 222 |
+
rows.append({
|
| 223 |
+
"model": model_name,
|
| 224 |
+
"fraction": float(fraction),
|
| 225 |
+
**aggregated,
|
| 226 |
+
})
|
| 227 |
+
|
| 228 |
+
return pd.DataFrame(rows)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def plot_results(df: pd.DataFrame, output_dir: Path) -> None:
|
| 232 |
+
"""Generate visualization of experiment results."""
|
| 233 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 234 |
+
|
| 235 |
+
# Set style
|
| 236 |
+
sns.set_style("whitegrid")
|
| 237 |
+
plt.rcParams["figure.figsize"] = (12, 8)
|
| 238 |
+
|
| 239 |
+
# Get main metric (Precision@5)
|
| 240 |
+
metric = "precision@5"
|
| 241 |
+
mean_col = f"{metric}_mean"
|
| 242 |
+
std_col = f"{metric}_std"
|
| 243 |
+
|
| 244 |
+
if mean_col not in df.columns:
|
| 245 |
+
# Try first available metric
|
| 246 |
+
metric_cols = [c for c in df.columns if c.endswith("_mean")]
|
| 247 |
+
if metric_cols:
|
| 248 |
+
mean_col = metric_cols[0]
|
| 249 |
+
std_col = mean_col.replace("_mean", "_std")
|
| 250 |
+
metric = mean_col.replace("_mean", "")
|
| 251 |
+
|
| 252 |
+
fig, ax = plt.subplots()
|
| 253 |
+
|
| 254 |
+
colors = {"baseline": "#e74c3c", "classical": "#3498db", "neural": "#2ecc71"}
|
| 255 |
+
|
| 256 |
+
for model in ["baseline", "classical", "neural"]:
|
| 257 |
+
model_df = df[df["model"] == model].sort_values("fraction")
|
| 258 |
+
|
| 259 |
+
if model_df.empty:
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
x = model_df["fraction"] * 100 # Convert to percentage
|
| 263 |
+
y = model_df[mean_col]
|
| 264 |
+
yerr = model_df[std_col] if std_col in model_df.columns else None
|
| 265 |
+
|
| 266 |
+
ax.errorbar(
|
| 267 |
+
x, y,
|
| 268 |
+
yerr=yerr,
|
| 269 |
+
label=model.capitalize(),
|
| 270 |
+
color=colors[model],
|
| 271 |
+
marker="o",
|
| 272 |
+
linewidth=2,
|
| 273 |
+
markersize=8,
|
| 274 |
+
capsize=3,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
ax.set_xlabel("Training Data Size (%)", fontsize=12)
|
| 278 |
+
ax.set_ylabel(f"{metric.replace('@', ' @ ').title()}", fontsize=12)
|
| 279 |
+
ax.set_title("Model Performance vs Training Set Size", fontsize=14)
|
| 280 |
+
ax.legend(fontsize=11)
|
| 281 |
+
ax.grid(True, alpha=0.3)
|
| 282 |
+
|
| 283 |
+
plt.tight_layout()
|
| 284 |
+
plt.savefig(output_dir / "sensitivity_analysis.png", dpi=150)
|
| 285 |
+
plt.close()
|
| 286 |
+
|
| 287 |
+
# Also create a multi-metric plot
|
| 288 |
+
metric_cols = [c for c in df.columns if c.endswith("_mean") and "@" in c]
|
| 289 |
+
if len(metric_cols) > 1:
|
| 290 |
+
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
| 291 |
+
axes = axes.flatten()
|
| 292 |
+
|
| 293 |
+
for idx, mean_col in enumerate(metric_cols[:4]):
|
| 294 |
+
ax = axes[idx]
|
| 295 |
+
metric = mean_col.replace("_mean", "")
|
| 296 |
+
std_col = mean_col.replace("_mean", "_std")
|
| 297 |
+
|
| 298 |
+
for model in ["baseline", "classical", "neural"]:
|
| 299 |
+
model_df = df[df["model"] == model].sort_values("fraction")
|
| 300 |
+
if model_df.empty:
|
| 301 |
+
continue
|
| 302 |
+
|
| 303 |
+
x = model_df["fraction"] * 100
|
| 304 |
+
y = model_df[mean_col]
|
| 305 |
+
yerr = model_df.get(std_col)
|
| 306 |
+
|
| 307 |
+
ax.errorbar(
|
| 308 |
+
x, y,
|
| 309 |
+
yerr=yerr,
|
| 310 |
+
label=model.capitalize(),
|
| 311 |
+
color=colors[model],
|
| 312 |
+
marker="o",
|
| 313 |
+
linewidth=2,
|
| 314 |
+
capsize=2,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
ax.set_xlabel("Training Data (%)")
|
| 318 |
+
ax.set_ylabel(metric.replace("@", " @ ").title())
|
| 319 |
+
ax.set_title(metric.replace("@", " @ ").title())
|
| 320 |
+
ax.legend(fontsize=9)
|
| 321 |
+
ax.grid(True, alpha=0.3)
|
| 322 |
+
|
| 323 |
+
plt.suptitle("Training Set Size Sensitivity Analysis", fontsize=14)
|
| 324 |
+
plt.tight_layout()
|
| 325 |
+
plt.savefig(output_dir / "sensitivity_analysis_multi.png", dpi=150)
|
| 326 |
+
plt.close()
|
| 327 |
+
|
| 328 |
+
print(f"Plots saved to {output_dir}")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def generate_report(df: pd.DataFrame, output_dir: Path) -> str:
|
| 332 |
+
"""Generate a text report of the experiment results."""
|
| 333 |
+
report = []
|
| 334 |
+
report.append("=" * 60)
|
| 335 |
+
report.append("SENSITIVITY ANALYSIS: TRAINING SET SIZE VS PERFORMANCE")
|
| 336 |
+
report.append("=" * 60)
|
| 337 |
+
report.append("")
|
| 338 |
+
|
| 339 |
+
# Summary statistics
|
| 340 |
+
report.append("EXPERIMENT SUMMARY")
|
| 341 |
+
report.append("-" * 40)
|
| 342 |
+
report.append(f"Training fractions tested: {sorted(df['fraction'].unique())}")
|
| 343 |
+
report.append(f"Models compared: {sorted(df['model'].unique())}")
|
| 344 |
+
report.append("")
|
| 345 |
+
|
| 346 |
+
# Best performance per model
|
| 347 |
+
report.append("BEST PERFORMANCE PER MODEL")
|
| 348 |
+
report.append("-" * 40)
|
| 349 |
+
|
| 350 |
+
metric_col = [c for c in df.columns if "precision" in c and "_mean" in c]
|
| 351 |
+
if metric_col:
|
| 352 |
+
metric_col = metric_col[0]
|
| 353 |
+
for model in ["baseline", "classical", "neural"]:
|
| 354 |
+
model_df = df[df["model"] == model]
|
| 355 |
+
if model_df.empty:
|
| 356 |
+
continue
|
| 357 |
+
best_idx = model_df[metric_col].idxmax()
|
| 358 |
+
best_row = model_df.loc[best_idx]
|
| 359 |
+
report.append(
|
| 360 |
+
f"{model.capitalize()}: {best_row[metric_col]:.4f} "
|
| 361 |
+
f"at {best_row['fraction']*100:.0f}% training data"
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
report.append("")
|
| 365 |
+
|
| 366 |
+
# Key findings
|
| 367 |
+
report.append("KEY FINDINGS")
|
| 368 |
+
report.append("-" * 40)
|
| 369 |
+
|
| 370 |
+
# Check if neural improves more with data
|
| 371 |
+
if "neural" in df["model"].values and metric_col:
|
| 372 |
+
neural_df = df[df["model"] == "neural"].sort_values("fraction")
|
| 373 |
+
if len(neural_df) >= 2:
|
| 374 |
+
start_perf = neural_df.iloc[0][metric_col]
|
| 375 |
+
end_perf = neural_df.iloc[-1][metric_col]
|
| 376 |
+
improvement = (end_perf - start_perf) / start_perf * 100
|
| 377 |
+
report.append(
|
| 378 |
+
f"1. Neural model improvement from 10% to 100% data: {improvement:.1f}%"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Compare final performance
|
| 382 |
+
if metric_col:
|
| 383 |
+
final_perfs = df[df["fraction"] == 1.0].set_index("model")[metric_col]
|
| 384 |
+
if len(final_perfs) > 0:
|
| 385 |
+
best_model = final_perfs.idxmax()
|
| 386 |
+
report.append(f"2. Best model at full data: {best_model}")
|
| 387 |
+
|
| 388 |
+
# Check for diminishing returns
|
| 389 |
+
report.append("3. Diminishing returns analysis: See sensitivity_analysis.png")
|
| 390 |
+
|
| 391 |
+
report.append("")
|
| 392 |
+
report.append("RECOMMENDATIONS")
|
| 393 |
+
report.append("-" * 40)
|
| 394 |
+
report.append("- If data collection is expensive, 50-70% of data may suffice")
|
| 395 |
+
report.append("- Neural model benefits most from additional data")
|
| 396 |
+
report.append("- Baseline provides a strong floor with minimal data")
|
| 397 |
+
|
| 398 |
+
report_text = "\n".join(report)
|
| 399 |
+
|
| 400 |
+
# Save report
|
| 401 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 402 |
+
with open(output_dir / "experiment_report.txt", "w") as f:
|
| 403 |
+
f.write(report_text)
|
| 404 |
+
|
| 405 |
+
return report_text
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def main():
|
| 409 |
+
"""Main experiment entry point."""
|
| 410 |
+
parser = argparse.ArgumentParser(
|
| 411 |
+
description="Run sensitivity analysis experiment"
|
| 412 |
+
)
|
| 413 |
+
parser.add_argument(
|
| 414 |
+
"--fractions",
|
| 415 |
+
nargs="+",
|
| 416 |
+
type=float,
|
| 417 |
+
default=TRAIN_FRACTIONS,
|
| 418 |
+
help="Training set fractions to test",
|
| 419 |
+
)
|
| 420 |
+
parser.add_argument(
|
| 421 |
+
"--trials",
|
| 422 |
+
type=int,
|
| 423 |
+
default=N_TRIALS,
|
| 424 |
+
help="Number of trials per fraction",
|
| 425 |
+
)
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
"--device",
|
| 428 |
+
type=str,
|
| 429 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 430 |
+
help="Device for neural network training",
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
"--output-dir",
|
| 434 |
+
type=str,
|
| 435 |
+
default=str(RESULTS_DIR),
|
| 436 |
+
help="Directory to save results",
|
| 437 |
+
)
|
| 438 |
+
args = parser.parse_args()
|
| 439 |
+
|
| 440 |
+
output_dir = Path(args.output_dir)
|
| 441 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 442 |
+
|
| 443 |
+
print("SENSITIVITY ANALYSIS EXPERIMENT")
|
| 444 |
+
print("=" * 40)
|
| 445 |
+
print(f"Training fractions: {args.fractions}")
|
| 446 |
+
print(f"Trials per fraction: {args.trials}")
|
| 447 |
+
print(f"Device: {args.device}")
|
| 448 |
+
print(f"Output directory: {output_dir}")
|
| 449 |
+
print()
|
| 450 |
+
|
| 451 |
+
# Load data
|
| 452 |
+
print("Loading data...")
|
| 453 |
+
data = load_processed_data()
|
| 454 |
+
train_df = data["train_df"]
|
| 455 |
+
test_df = data["test_df"]
|
| 456 |
+
|
| 457 |
+
print(f"Training samples: {len(train_df)}")
|
| 458 |
+
print(f"Test samples: {len(test_df)}")
|
| 459 |
+
print()
|
| 460 |
+
|
| 461 |
+
# Run experiment
|
| 462 |
+
print("Running experiment...")
|
| 463 |
+
results = run_experiment(
|
| 464 |
+
train_df=train_df,
|
| 465 |
+
test_df=test_df,
|
| 466 |
+
device=args.device,
|
| 467 |
+
fractions=args.fractions,
|
| 468 |
+
n_trials=args.trials,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# Save raw results
|
| 472 |
+
with open(output_dir / "raw_results.json", "w") as f:
|
| 473 |
+
json.dump(results, f, indent=2)
|
| 474 |
+
print(f"\nRaw results saved to {output_dir / 'raw_results.json'}")
|
| 475 |
+
|
| 476 |
+
# Aggregate and analyze
|
| 477 |
+
df = aggregate_results(results)
|
| 478 |
+
df.to_csv(output_dir / "aggregated_results.csv", index=False)
|
| 479 |
+
print(f"Aggregated results saved to {output_dir / 'aggregated_results.csv'}")
|
| 480 |
+
|
| 481 |
+
# Generate visualizations
|
| 482 |
+
plot_results(df, output_dir)
|
| 483 |
+
|
| 484 |
+
# Generate report
|
| 485 |
+
report = generate_report(df, output_dir)
|
| 486 |
+
print("\n" + report)
|
| 487 |
+
|
| 488 |
+
print("\nExperiment complete!")
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if __name__ == "__main__":
|
| 492 |
+
main()
|
src/brewmatch/models/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ML models for coffee recommendation."""
|
| 2 |
+
|
| 3 |
+
from .base import BaseRecommender
|
| 4 |
+
from .baseline import NaiveBaselineRecommender
|
| 5 |
+
from .classical import ClassicalMLRecommender
|
| 6 |
+
from .neural import NeuralRecommender
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"BaseRecommender",
|
| 10 |
+
"NaiveBaselineRecommender",
|
| 11 |
+
"ClassicalMLRecommender",
|
| 12 |
+
"NeuralRecommender",
|
| 13 |
+
]
|
src/brewmatch/models/base.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract base class for all recommender models."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseRecommender(ABC):
|
| 12 |
+
"""Abstract base class for coffee recommender models.
|
| 13 |
+
|
| 14 |
+
All recommender implementations must inherit from this class and
|
| 15 |
+
implement the required methods.
|
| 16 |
+
|
| 17 |
+
Attributes:
|
| 18 |
+
TASTE_FEATURES: The 9 taste feature columns used for recommendations.
|
| 19 |
+
is_fitted: Whether the model has been fitted to data.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
TASTE_FEATURES = [
|
| 23 |
+
"Aroma",
|
| 24 |
+
"Flavor",
|
| 25 |
+
"Aftertaste",
|
| 26 |
+
"Acidity",
|
| 27 |
+
"Body",
|
| 28 |
+
"Balance",
|
| 29 |
+
"Uniformity",
|
| 30 |
+
"Clean Cup",
|
| 31 |
+
"Sweetness",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
def __init__(self) -> None:
|
| 35 |
+
self.is_fitted = False
|
| 36 |
+
self._metadata: pd.DataFrame | None = None
|
| 37 |
+
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def fit(self, X: np.ndarray, metadata: pd.DataFrame) -> "BaseRecommender":
|
| 40 |
+
"""Fit the recommender to coffee taste profiles.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
X: Feature matrix of shape (n_samples, 9) containing taste scores.
|
| 44 |
+
Columns correspond to TASTE_FEATURES in order.
|
| 45 |
+
metadata: DataFrame containing coffee metadata (country, processing
|
| 46 |
+
method, variety, etc.). Must have same number of rows as X.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
self: The fitted recommender instance.
|
| 50 |
+
"""
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
@abstractmethod
|
| 54 |
+
def recommend(
|
| 55 |
+
self, preferences: np.ndarray, k: int = 5
|
| 56 |
+
) -> list[dict[str, Any]]:
|
| 57 |
+
"""Recommend coffees matching user taste preferences.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
preferences: Array of shape (9,) containing desired taste scores.
|
| 61 |
+
Values correspond to TASTE_FEATURES in order.
|
| 62 |
+
k: Number of recommendations to return.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
List of k recommendation dictionaries, each containing:
|
| 66 |
+
- 'index': Original index in the training data
|
| 67 |
+
- 'score': Similarity/relevance score (higher is better)
|
| 68 |
+
- 'metadata': Dict of coffee metadata
|
| 69 |
+
- 'taste_profile': Dict of the coffee's taste scores
|
| 70 |
+
"""
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
@abstractmethod
|
| 74 |
+
def save(self, path: str | Path) -> None:
|
| 75 |
+
"""Save the fitted model to disk.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
path: File path to save the model to.
|
| 79 |
+
"""
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
@abstractmethod
|
| 84 |
+
def load(cls, path: str | Path) -> "BaseRecommender":
|
| 85 |
+
"""Load a fitted model from disk.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
path: File path to load the model from.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
The loaded recommender instance.
|
| 92 |
+
"""
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
def _validate_fitted(self) -> None:
|
| 96 |
+
"""Raise error if model is not fitted."""
|
| 97 |
+
if not self.is_fitted:
|
| 98 |
+
raise RuntimeError(
|
| 99 |
+
f"{self.__class__.__name__} must be fitted before calling this method. "
|
| 100 |
+
"Call fit() first."
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def _validate_preferences(self, preferences: np.ndarray) -> np.ndarray:
|
| 104 |
+
"""Validate and reshape preference array."""
|
| 105 |
+
preferences = np.asarray(preferences, dtype=np.float32)
|
| 106 |
+
if preferences.ndim == 1:
|
| 107 |
+
if preferences.shape[0] != 9:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
f"Expected 9 taste features, got {preferences.shape[0]}"
|
| 110 |
+
)
|
| 111 |
+
elif preferences.ndim == 2:
|
| 112 |
+
if preferences.shape[1] != 9:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
f"Expected 9 taste features, got {preferences.shape[1]}"
|
| 115 |
+
)
|
| 116 |
+
preferences = preferences.squeeze()
|
| 117 |
+
else:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"Preferences must be 1D or 2D array, got {preferences.ndim}D"
|
| 120 |
+
)
|
| 121 |
+
return preferences
|
| 122 |
+
|
| 123 |
+
def _format_recommendation(
|
| 124 |
+
self, idx: int, score: float, taste_profile: np.ndarray
|
| 125 |
+
) -> dict[str, Any]:
|
| 126 |
+
"""Format a single recommendation as a dictionary."""
|
| 127 |
+
metadata_dict = {}
|
| 128 |
+
if self._metadata is not None:
|
| 129 |
+
row = self._metadata.iloc[idx]
|
| 130 |
+
metadata_dict = row.to_dict()
|
| 131 |
+
|
| 132 |
+
taste_dict = {
|
| 133 |
+
feature: float(taste_profile[i])
|
| 134 |
+
for i, feature in enumerate(self.TASTE_FEATURES)
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
"index": idx,
|
| 139 |
+
"score": float(score),
|
| 140 |
+
"metadata": metadata_dict,
|
| 141 |
+
"taste_profile": taste_dict,
|
| 142 |
+
}
|
src/brewmatch/models/baseline.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Naive baseline recommender for establishing performance floor."""
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from .base import BaseRecommender
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NaiveBaselineRecommender(BaseRecommender):
|
| 14 |
+
"""Naive baseline recommender using global mean or weighted random.
|
| 15 |
+
|
| 16 |
+
This establishes a performance floor for comparison with more
|
| 17 |
+
sophisticated approaches. It supports two strategies:
|
| 18 |
+
|
| 19 |
+
- 'mean': Recommends coffees closest to the global mean profile,
|
| 20 |
+
ignoring user preferences entirely.
|
| 21 |
+
- 'weighted_random': Randomly samples coffees weighted by their
|
| 22 |
+
Total Cup Points score.
|
| 23 |
+
|
| 24 |
+
Attributes:
|
| 25 |
+
strategy: The recommendation strategy ('mean' or 'weighted_random').
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, strategy: str = "mean") -> None:
|
| 29 |
+
"""Initialize the baseline recommender.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
strategy: Recommendation strategy. One of:
|
| 33 |
+
- 'mean': Recommend coffees closest to global mean profile
|
| 34 |
+
- 'weighted_random': Random sampling weighted by Total Cup Points
|
| 35 |
+
"""
|
| 36 |
+
super().__init__()
|
| 37 |
+
if strategy not in ("mean", "weighted_random"):
|
| 38 |
+
raise ValueError(f"Unknown strategy: {strategy}")
|
| 39 |
+
self.strategy = strategy
|
| 40 |
+
self._X: np.ndarray | None = None
|
| 41 |
+
self._global_mean: np.ndarray | None = None
|
| 42 |
+
self._weights: np.ndarray | None = None
|
| 43 |
+
self._rng = np.random.default_rng()
|
| 44 |
+
|
| 45 |
+
def fit(self, X: np.ndarray, metadata: pd.DataFrame) -> "NaiveBaselineRecommender":
|
| 46 |
+
"""Fit the baseline recommender.
|
| 47 |
+
|
| 48 |
+
For 'mean' strategy, computes the global mean profile and
|
| 49 |
+
distances from each coffee to it.
|
| 50 |
+
|
| 51 |
+
For 'weighted_random' strategy, extracts Total Cup Points as
|
| 52 |
+
sampling weights.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
X: Feature matrix of shape (n_samples, 9).
|
| 56 |
+
metadata: DataFrame with coffee metadata. For 'weighted_random',
|
| 57 |
+
must contain 'Total Cup Points' column.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
self: The fitted recommender.
|
| 61 |
+
"""
|
| 62 |
+
X = np.asarray(X, dtype=np.float32)
|
| 63 |
+
if X.shape[1] != 9:
|
| 64 |
+
raise ValueError(f"Expected 9 features, got {X.shape[1]}")
|
| 65 |
+
|
| 66 |
+
self._X = X
|
| 67 |
+
self._metadata = metadata.copy()
|
| 68 |
+
self._global_mean = X.mean(axis=0)
|
| 69 |
+
|
| 70 |
+
if self.strategy == "weighted_random":
|
| 71 |
+
if "Total Cup Points" not in metadata.columns:
|
| 72 |
+
raise ValueError(
|
| 73 |
+
"metadata must contain 'Total Cup Points' for weighted_random strategy"
|
| 74 |
+
)
|
| 75 |
+
scores = metadata["Total Cup Points"].values.astype(np.float32)
|
| 76 |
+
# Shift to positive and normalize
|
| 77 |
+
scores = scores - scores.min() + 1.0
|
| 78 |
+
self._weights = scores / scores.sum()
|
| 79 |
+
|
| 80 |
+
self.is_fitted = True
|
| 81 |
+
return self
|
| 82 |
+
|
| 83 |
+
def recommend(
|
| 84 |
+
self, preferences: np.ndarray, k: int = 5
|
| 85 |
+
) -> list[dict[str, Any]]:
|
| 86 |
+
"""Generate recommendations.
|
| 87 |
+
|
| 88 |
+
For 'mean' strategy, returns coffees closest to the global mean,
|
| 89 |
+
ignoring the provided preferences entirely.
|
| 90 |
+
|
| 91 |
+
For 'weighted_random' strategy, returns random coffees sampled
|
| 92 |
+
proportionally to their Total Cup Points.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
preferences: User taste preferences (ignored for baseline).
|
| 96 |
+
k: Number of recommendations.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
List of k recommendation dictionaries.
|
| 100 |
+
"""
|
| 101 |
+
self._validate_fitted()
|
| 102 |
+
preferences = self._validate_preferences(preferences)
|
| 103 |
+
|
| 104 |
+
n_samples = self._X.shape[0]
|
| 105 |
+
k = min(k, n_samples)
|
| 106 |
+
|
| 107 |
+
if self.strategy == "mean":
|
| 108 |
+
# Find coffees closest to global mean (ignoring user preferences)
|
| 109 |
+
distances = np.linalg.norm(self._X - self._global_mean, axis=1)
|
| 110 |
+
indices = np.argsort(distances)[:k]
|
| 111 |
+
# Convert distance to similarity score (higher is better)
|
| 112 |
+
scores = 1.0 / (1.0 + distances[indices])
|
| 113 |
+
else:
|
| 114 |
+
# Weighted random sampling
|
| 115 |
+
indices = self._rng.choice(
|
| 116 |
+
n_samples, size=k, replace=False, p=self._weights
|
| 117 |
+
)
|
| 118 |
+
# Use weight as score
|
| 119 |
+
scores = self._weights[indices]
|
| 120 |
+
|
| 121 |
+
recommendations = []
|
| 122 |
+
for idx, score in zip(indices, scores):
|
| 123 |
+
rec = self._format_recommendation(idx, score, self._X[idx])
|
| 124 |
+
recommendations.append(rec)
|
| 125 |
+
|
| 126 |
+
return recommendations
|
| 127 |
+
|
| 128 |
+
def save(self, path: str | Path) -> None:
|
| 129 |
+
"""Save the fitted model to disk using pickle.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
path: File path to save the model to.
|
| 133 |
+
"""
|
| 134 |
+
self._validate_fitted()
|
| 135 |
+
path = Path(path)
|
| 136 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
|
| 138 |
+
state = {
|
| 139 |
+
"strategy": self.strategy,
|
| 140 |
+
"X": self._X,
|
| 141 |
+
"metadata": self._metadata,
|
| 142 |
+
"global_mean": self._global_mean,
|
| 143 |
+
"weights": self._weights,
|
| 144 |
+
}
|
| 145 |
+
with open(path, "wb") as f:
|
| 146 |
+
pickle.dump(state, f)
|
| 147 |
+
|
| 148 |
+
@classmethod
|
| 149 |
+
def load(cls, path: str | Path) -> "NaiveBaselineRecommender":
|
| 150 |
+
"""Load a fitted model from disk.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
path: File path to load the model from.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
The loaded recommender instance.
|
| 157 |
+
"""
|
| 158 |
+
with open(path, "rb") as f:
|
| 159 |
+
state = pickle.load(f)
|
| 160 |
+
|
| 161 |
+
model = cls(strategy=state["strategy"])
|
| 162 |
+
model._X = state["X"]
|
| 163 |
+
model._metadata = state["metadata"]
|
| 164 |
+
model._global_mean = state["global_mean"]
|
| 165 |
+
model._weights = state["weights"]
|
| 166 |
+
model.is_fitted = True
|
| 167 |
+
return model
|
src/brewmatch/models/classical.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Classical ML recommender using KNN and cosine similarity."""
|
| 2 |
+
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from sklearn.neighbors import NearestNeighbors
|
| 10 |
+
from sklearn.preprocessing import StandardScaler
|
| 11 |
+
|
| 12 |
+
from .base import BaseRecommender
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ClassicalMLRecommender(BaseRecommender):
|
| 16 |
+
"""Classical ML recommender using KNN or cosine similarity.
|
| 17 |
+
|
| 18 |
+
Finds coffees with taste profiles most similar to user preferences
|
| 19 |
+
using either KNN with Euclidean distance or cosine similarity ranking.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
method: Similarity method ('knn' or 'cosine').
|
| 23 |
+
n_neighbors: Maximum neighbors for KNN (used for internal index).
|
| 24 |
+
normalize: Whether to standardize features before similarity computation.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
method: str = "knn",
|
| 30 |
+
n_neighbors: int = 50,
|
| 31 |
+
normalize: bool = True,
|
| 32 |
+
) -> None:
|
| 33 |
+
"""Initialize the classical ML recommender.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
method: Similarity method. One of:
|
| 37 |
+
- 'knn': K-nearest neighbors with Euclidean distance
|
| 38 |
+
- 'cosine': Cosine similarity ranking
|
| 39 |
+
n_neighbors: Maximum neighbors to index for KNN. Actual k in
|
| 40 |
+
recommend() can be smaller.
|
| 41 |
+
normalize: Whether to standardize features (recommended for
|
| 42 |
+
Euclidean distance).
|
| 43 |
+
"""
|
| 44 |
+
super().__init__()
|
| 45 |
+
if method not in ("knn", "cosine"):
|
| 46 |
+
raise ValueError(f"Unknown method: {method}")
|
| 47 |
+
self.method = method
|
| 48 |
+
self.n_neighbors = n_neighbors
|
| 49 |
+
self.normalize = normalize
|
| 50 |
+
|
| 51 |
+
self._X: np.ndarray | None = None
|
| 52 |
+
self._X_normalized: np.ndarray | None = None
|
| 53 |
+
self._scaler: StandardScaler | None = None
|
| 54 |
+
self._knn: NearestNeighbors | None = None
|
| 55 |
+
|
| 56 |
+
def fit(self, X: np.ndarray, metadata: pd.DataFrame) -> "ClassicalMLRecommender":
|
| 57 |
+
"""Fit the recommender to coffee taste profiles.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
X: Feature matrix of shape (n_samples, 9).
|
| 61 |
+
metadata: DataFrame with coffee metadata.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
self: The fitted recommender.
|
| 65 |
+
"""
|
| 66 |
+
X = np.asarray(X, dtype=np.float32)
|
| 67 |
+
if X.shape[1] != 9:
|
| 68 |
+
raise ValueError(f"Expected 9 features, got {X.shape[1]}")
|
| 69 |
+
|
| 70 |
+
self._X = X
|
| 71 |
+
self._metadata = metadata.copy()
|
| 72 |
+
|
| 73 |
+
if self.normalize:
|
| 74 |
+
self._scaler = StandardScaler()
|
| 75 |
+
self._X_normalized = self._scaler.fit_transform(X).astype(np.float32)
|
| 76 |
+
else:
|
| 77 |
+
self._X_normalized = X
|
| 78 |
+
|
| 79 |
+
if self.method == "knn":
|
| 80 |
+
# Build KNN index
|
| 81 |
+
n_neighbors = min(self.n_neighbors, X.shape[0])
|
| 82 |
+
self._knn = NearestNeighbors(
|
| 83 |
+
n_neighbors=n_neighbors,
|
| 84 |
+
metric="euclidean",
|
| 85 |
+
algorithm="auto",
|
| 86 |
+
)
|
| 87 |
+
self._knn.fit(self._X_normalized)
|
| 88 |
+
|
| 89 |
+
self.is_fitted = True
|
| 90 |
+
return self
|
| 91 |
+
|
| 92 |
+
def recommend(
|
| 93 |
+
self, preferences: np.ndarray, k: int = 5
|
| 94 |
+
) -> list[dict[str, Any]]:
|
| 95 |
+
"""Find coffees most similar to user preferences.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
preferences: User taste preferences of shape (9,).
|
| 99 |
+
k: Number of recommendations.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of k recommendation dictionaries.
|
| 103 |
+
"""
|
| 104 |
+
self._validate_fitted()
|
| 105 |
+
preferences = self._validate_preferences(preferences)
|
| 106 |
+
|
| 107 |
+
n_samples = self._X.shape[0]
|
| 108 |
+
k = min(k, n_samples)
|
| 109 |
+
|
| 110 |
+
# Normalize preferences if needed
|
| 111 |
+
if self.normalize:
|
| 112 |
+
pref_normalized = self._scaler.transform(
|
| 113 |
+
preferences.reshape(1, -1)
|
| 114 |
+
).astype(np.float32)
|
| 115 |
+
else:
|
| 116 |
+
pref_normalized = preferences.reshape(1, -1)
|
| 117 |
+
|
| 118 |
+
if self.method == "knn":
|
| 119 |
+
indices, scores = self._recommend_knn(pref_normalized, k)
|
| 120 |
+
else:
|
| 121 |
+
indices, scores = self._recommend_cosine(pref_normalized, k)
|
| 122 |
+
|
| 123 |
+
recommendations = []
|
| 124 |
+
for idx, score in zip(indices, scores):
|
| 125 |
+
rec = self._format_recommendation(idx, score, self._X[idx])
|
| 126 |
+
recommendations.append(rec)
|
| 127 |
+
|
| 128 |
+
return recommendations
|
| 129 |
+
|
| 130 |
+
def _recommend_knn(
|
| 131 |
+
self, pref_normalized: np.ndarray, k: int
|
| 132 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 133 |
+
"""Find k nearest neighbors using KNN index."""
|
| 134 |
+
k = min(k, self._knn.n_neighbors)
|
| 135 |
+
distances, indices = self._knn.kneighbors(pref_normalized, n_neighbors=k)
|
| 136 |
+
|
| 137 |
+
# Convert distance to similarity (higher is better)
|
| 138 |
+
# Using inverse distance with offset
|
| 139 |
+
scores = 1.0 / (1.0 + distances.squeeze())
|
| 140 |
+
|
| 141 |
+
return indices.squeeze(), scores
|
| 142 |
+
|
| 143 |
+
def _recommend_cosine(
|
| 144 |
+
self, pref_normalized: np.ndarray, k: int
|
| 145 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 146 |
+
"""Rank all coffees by cosine similarity."""
|
| 147 |
+
# Compute cosine similarity
|
| 148 |
+
pref_norm = pref_normalized / (
|
| 149 |
+
np.linalg.norm(pref_normalized) + 1e-8
|
| 150 |
+
)
|
| 151 |
+
X_norms = np.linalg.norm(self._X_normalized, axis=1, keepdims=True) + 1e-8
|
| 152 |
+
X_normalized_unit = self._X_normalized / X_norms
|
| 153 |
+
|
| 154 |
+
similarities = (X_normalized_unit @ pref_norm.T).squeeze()
|
| 155 |
+
|
| 156 |
+
# Get top k
|
| 157 |
+
indices = np.argsort(similarities)[::-1][:k]
|
| 158 |
+
scores = similarities[indices]
|
| 159 |
+
|
| 160 |
+
# Shift to [0, 1] range (cosine similarity is in [-1, 1])
|
| 161 |
+
scores = (scores + 1.0) / 2.0
|
| 162 |
+
|
| 163 |
+
return indices, scores
|
| 164 |
+
|
| 165 |
+
def save(self, path: str | Path) -> None:
|
| 166 |
+
"""Save the fitted model to disk using pickle.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
path: File path to save the model to.
|
| 170 |
+
"""
|
| 171 |
+
self._validate_fitted()
|
| 172 |
+
path = Path(path)
|
| 173 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 174 |
+
|
| 175 |
+
state = {
|
| 176 |
+
"method": self.method,
|
| 177 |
+
"n_neighbors": self.n_neighbors,
|
| 178 |
+
"normalize": self.normalize,
|
| 179 |
+
"X": self._X,
|
| 180 |
+
"X_normalized": self._X_normalized,
|
| 181 |
+
"metadata": self._metadata,
|
| 182 |
+
"scaler": self._scaler,
|
| 183 |
+
"knn": self._knn,
|
| 184 |
+
}
|
| 185 |
+
with open(path, "wb") as f:
|
| 186 |
+
pickle.dump(state, f)
|
| 187 |
+
|
| 188 |
+
@classmethod
|
| 189 |
+
def load(cls, path: str | Path) -> "ClassicalMLRecommender":
|
| 190 |
+
"""Load a fitted model from disk.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
path: File path to load the model from.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
The loaded recommender instance.
|
| 197 |
+
"""
|
| 198 |
+
with open(path, "rb") as f:
|
| 199 |
+
state = pickle.load(f)
|
| 200 |
+
|
| 201 |
+
model = cls(
|
| 202 |
+
method=state["method"],
|
| 203 |
+
n_neighbors=state["n_neighbors"],
|
| 204 |
+
normalize=state["normalize"],
|
| 205 |
+
)
|
| 206 |
+
model._X = state["X"]
|
| 207 |
+
model._X_normalized = state["X_normalized"]
|
| 208 |
+
model._metadata = state["metadata"]
|
| 209 |
+
model._scaler = state["scaler"]
|
| 210 |
+
model._knn = state["knn"]
|
| 211 |
+
model.is_fitted = True
|
| 212 |
+
return model
|
src/brewmatch/models/neural.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Neural network recommender using learned coffee embeddings."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.utils.data import DataLoader, Dataset
|
| 12 |
+
|
| 13 |
+
from .base import BaseRecommender
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TasteEncoder(nn.Module):
|
| 17 |
+
"""Neural network that encodes taste profiles into embeddings.
|
| 18 |
+
|
| 19 |
+
Architecture: MLP with residual connections that maps 9 taste features
|
| 20 |
+
to a lower-dimensional embedding space.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
input_dim: int = 9,
|
| 26 |
+
hidden_dim: int = 64,
|
| 27 |
+
embedding_dim: int = 32,
|
| 28 |
+
dropout: float = 0.1,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
self.input_proj = nn.Linear(input_dim, hidden_dim)
|
| 33 |
+
self.hidden1 = nn.Linear(hidden_dim, hidden_dim)
|
| 34 |
+
self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
|
| 35 |
+
self.output_proj = nn.Linear(hidden_dim, embedding_dim)
|
| 36 |
+
|
| 37 |
+
self.norm1 = nn.LayerNorm(hidden_dim)
|
| 38 |
+
self.norm2 = nn.LayerNorm(hidden_dim)
|
| 39 |
+
self.dropout = nn.Dropout(dropout)
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""Encode taste profiles to embeddings.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
x: Taste features of shape (batch_size, 9).
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Embeddings of shape (batch_size, embedding_dim).
|
| 49 |
+
"""
|
| 50 |
+
# Project to hidden dimension
|
| 51 |
+
h = F.gelu(self.input_proj(x))
|
| 52 |
+
|
| 53 |
+
# Residual block 1
|
| 54 |
+
residual = h
|
| 55 |
+
h = self.norm1(h)
|
| 56 |
+
h = F.gelu(self.hidden1(h))
|
| 57 |
+
h = self.dropout(h)
|
| 58 |
+
h = h + residual
|
| 59 |
+
|
| 60 |
+
# Residual block 2
|
| 61 |
+
residual = h
|
| 62 |
+
h = self.norm2(h)
|
| 63 |
+
h = F.gelu(self.hidden2(h))
|
| 64 |
+
h = self.dropout(h)
|
| 65 |
+
h = h + residual
|
| 66 |
+
|
| 67 |
+
# Project to embedding space
|
| 68 |
+
embedding = self.output_proj(h)
|
| 69 |
+
|
| 70 |
+
# L2 normalize for cosine similarity
|
| 71 |
+
embedding = F.normalize(embedding, p=2, dim=-1)
|
| 72 |
+
|
| 73 |
+
return embedding
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TripletDataset(Dataset):
|
| 77 |
+
"""Dataset that generates triplets for contrastive learning.
|
| 78 |
+
|
| 79 |
+
For each anchor, samples a positive (similar coffee) and negative
|
| 80 |
+
(dissimilar coffee) based on taste profile distance.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
X: np.ndarray,
|
| 86 |
+
margin_quantile: float = 0.3,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""Initialize triplet dataset.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
X: Feature matrix of shape (n_samples, 9).
|
| 92 |
+
margin_quantile: Quantile for positive/negative threshold.
|
| 93 |
+
Coffees within this distance quantile are positives.
|
| 94 |
+
"""
|
| 95 |
+
self.X = torch.tensor(X, dtype=torch.float32)
|
| 96 |
+
self.n_samples = X.shape[0]
|
| 97 |
+
|
| 98 |
+
# Precompute pairwise distances
|
| 99 |
+
X_tensor = self.X
|
| 100 |
+
self.distances = torch.cdist(X_tensor, X_tensor, p=2)
|
| 101 |
+
|
| 102 |
+
# Determine threshold for positive/negative
|
| 103 |
+
flat_distances = self.distances.flatten()
|
| 104 |
+
self.positive_threshold = torch.quantile(flat_distances, margin_quantile)
|
| 105 |
+
|
| 106 |
+
def __len__(self) -> int:
|
| 107 |
+
return self.n_samples
|
| 108 |
+
|
| 109 |
+
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 110 |
+
"""Get a triplet (anchor, positive, negative).
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
idx: Anchor index.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Tuple of (anchor, positive, negative) taste profiles.
|
| 117 |
+
"""
|
| 118 |
+
anchor = self.X[idx]
|
| 119 |
+
|
| 120 |
+
# Get distances from anchor
|
| 121 |
+
dists = self.distances[idx]
|
| 122 |
+
|
| 123 |
+
# Find positives (close) and negatives (far), excluding self
|
| 124 |
+
mask = torch.arange(self.n_samples) != idx
|
| 125 |
+
positive_mask = mask & (dists <= self.positive_threshold)
|
| 126 |
+
negative_mask = mask & (dists > self.positive_threshold)
|
| 127 |
+
|
| 128 |
+
# Handle edge cases
|
| 129 |
+
if positive_mask.sum() == 0:
|
| 130 |
+
# Use closest non-self sample
|
| 131 |
+
dists_masked = dists.clone()
|
| 132 |
+
dists_masked[idx] = float("inf")
|
| 133 |
+
positive_idx = dists_masked.argmin().item()
|
| 134 |
+
else:
|
| 135 |
+
positive_indices = torch.where(positive_mask)[0]
|
| 136 |
+
positive_idx = positive_indices[
|
| 137 |
+
torch.randint(len(positive_indices), (1,))
|
| 138 |
+
].item()
|
| 139 |
+
|
| 140 |
+
if negative_mask.sum() == 0:
|
| 141 |
+
# Use farthest sample
|
| 142 |
+
negative_idx = dists.argmax().item()
|
| 143 |
+
else:
|
| 144 |
+
negative_indices = torch.where(negative_mask)[0]
|
| 145 |
+
negative_idx = negative_indices[
|
| 146 |
+
torch.randint(len(negative_indices), (1,))
|
| 147 |
+
].item()
|
| 148 |
+
|
| 149 |
+
return anchor, self.X[positive_idx], self.X[negative_idx]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class NeuralRecommender(BaseRecommender):
|
| 153 |
+
"""Neural recommender using learned coffee embeddings.
|
| 154 |
+
|
| 155 |
+
Uses contrastive learning with triplet loss to learn embeddings
|
| 156 |
+
that capture taste similarity. Similar coffees have nearby embeddings.
|
| 157 |
+
|
| 158 |
+
Attributes:
|
| 159 |
+
embedding_dim: Dimension of learned embeddings.
|
| 160 |
+
hidden_dim: Hidden layer dimension in encoder.
|
| 161 |
+
learning_rate: Learning rate for training.
|
| 162 |
+
margin: Triplet loss margin.
|
| 163 |
+
device: Torch device (cuda/cpu).
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
embedding_dim: int = 32,
|
| 169 |
+
hidden_dim: int = 64,
|
| 170 |
+
learning_rate: float = 1e-3,
|
| 171 |
+
margin: float = 0.5,
|
| 172 |
+
device: str | None = None,
|
| 173 |
+
) -> None:
|
| 174 |
+
"""Initialize the neural recommender.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
embedding_dim: Embedding dimension.
|
| 178 |
+
hidden_dim: Hidden layer dimension.
|
| 179 |
+
learning_rate: Learning rate.
|
| 180 |
+
margin: Triplet loss margin.
|
| 181 |
+
device: Torch device. Auto-detects CUDA if not specified.
|
| 182 |
+
"""
|
| 183 |
+
super().__init__()
|
| 184 |
+
self.embedding_dim = embedding_dim
|
| 185 |
+
self.hidden_dim = hidden_dim
|
| 186 |
+
self.learning_rate = learning_rate
|
| 187 |
+
self.margin = margin
|
| 188 |
+
|
| 189 |
+
if device is None:
|
| 190 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 191 |
+
else:
|
| 192 |
+
self.device = device
|
| 193 |
+
|
| 194 |
+
self._encoder: TasteEncoder | None = None
|
| 195 |
+
self._X: np.ndarray | None = None
|
| 196 |
+
self._embeddings: np.ndarray | None = None
|
| 197 |
+
self._feature_mean: np.ndarray | None = None
|
| 198 |
+
self._feature_std: np.ndarray | None = None
|
| 199 |
+
|
| 200 |
+
def fit(
|
| 201 |
+
self,
|
| 202 |
+
X: np.ndarray,
|
| 203 |
+
metadata: pd.DataFrame,
|
| 204 |
+
epochs: int = 100,
|
| 205 |
+
batch_size: int = 64,
|
| 206 |
+
verbose: bool = True,
|
| 207 |
+
) -> "NeuralRecommender":
|
| 208 |
+
"""Fit the neural recommender using contrastive learning.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
X: Feature matrix of shape (n_samples, 9).
|
| 212 |
+
metadata: DataFrame with coffee metadata.
|
| 213 |
+
epochs: Number of training epochs.
|
| 214 |
+
batch_size: Training batch size.
|
| 215 |
+
verbose: Whether to print training progress.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
self: The fitted recommender.
|
| 219 |
+
"""
|
| 220 |
+
X = np.asarray(X, dtype=np.float32)
|
| 221 |
+
if X.shape[1] != 9:
|
| 222 |
+
raise ValueError(f"Expected 9 features, got {X.shape[1]}")
|
| 223 |
+
|
| 224 |
+
self._X = X
|
| 225 |
+
self._metadata = metadata.copy()
|
| 226 |
+
|
| 227 |
+
# Normalize features
|
| 228 |
+
self._feature_mean = X.mean(axis=0)
|
| 229 |
+
self._feature_std = X.std(axis=0) + 1e-8
|
| 230 |
+
X_normalized = (X - self._feature_mean) / self._feature_std
|
| 231 |
+
|
| 232 |
+
# Create encoder
|
| 233 |
+
self._encoder = TasteEncoder(
|
| 234 |
+
input_dim=9,
|
| 235 |
+
hidden_dim=self.hidden_dim,
|
| 236 |
+
embedding_dim=self.embedding_dim,
|
| 237 |
+
).to(self.device)
|
| 238 |
+
|
| 239 |
+
# Create dataset and dataloader
|
| 240 |
+
dataset = TripletDataset(X_normalized)
|
| 241 |
+
dataloader = DataLoader(
|
| 242 |
+
dataset,
|
| 243 |
+
batch_size=batch_size,
|
| 244 |
+
shuffle=True,
|
| 245 |
+
drop_last=False,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Training
|
| 249 |
+
optimizer = torch.optim.AdamW(
|
| 250 |
+
self._encoder.parameters(),
|
| 251 |
+
lr=self.learning_rate,
|
| 252 |
+
weight_decay=0.01,
|
| 253 |
+
)
|
| 254 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 255 |
+
optimizer, T_max=epochs
|
| 256 |
+
)
|
| 257 |
+
triplet_loss = nn.TripletMarginLoss(margin=self.margin, p=2)
|
| 258 |
+
|
| 259 |
+
self._encoder.train()
|
| 260 |
+
for epoch in range(epochs):
|
| 261 |
+
total_loss = 0.0
|
| 262 |
+
n_batches = 0
|
| 263 |
+
|
| 264 |
+
for anchor, positive, negative in dataloader:
|
| 265 |
+
anchor = anchor.to(self.device)
|
| 266 |
+
positive = positive.to(self.device)
|
| 267 |
+
negative = negative.to(self.device)
|
| 268 |
+
|
| 269 |
+
optimizer.zero_grad()
|
| 270 |
+
|
| 271 |
+
anchor_emb = self._encoder(anchor)
|
| 272 |
+
positive_emb = self._encoder(positive)
|
| 273 |
+
negative_emb = self._encoder(negative)
|
| 274 |
+
|
| 275 |
+
loss = triplet_loss(anchor_emb, positive_emb, negative_emb)
|
| 276 |
+
loss.backward()
|
| 277 |
+
optimizer.step()
|
| 278 |
+
|
| 279 |
+
total_loss += loss.item()
|
| 280 |
+
n_batches += 1
|
| 281 |
+
|
| 282 |
+
scheduler.step()
|
| 283 |
+
|
| 284 |
+
if verbose and (epoch + 1) % 10 == 0:
|
| 285 |
+
avg_loss = total_loss / n_batches
|
| 286 |
+
print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
|
| 287 |
+
|
| 288 |
+
# Compute embeddings for all coffees
|
| 289 |
+
self._encoder.eval()
|
| 290 |
+
with torch.no_grad():
|
| 291 |
+
X_tensor = torch.tensor(X_normalized, dtype=torch.float32).to(self.device)
|
| 292 |
+
self._embeddings = self._encoder(X_tensor).cpu().numpy()
|
| 293 |
+
|
| 294 |
+
self.is_fitted = True
|
| 295 |
+
return self
|
| 296 |
+
|
| 297 |
+
def recommend(
|
| 298 |
+
self, preferences: np.ndarray, k: int = 5
|
| 299 |
+
) -> list[dict[str, Any]]:
|
| 300 |
+
"""Find coffees with embeddings closest to user preferences.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
preferences: User taste preferences of shape (9,).
|
| 304 |
+
k: Number of recommendations.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
List of k recommendation dictionaries.
|
| 308 |
+
"""
|
| 309 |
+
self._validate_fitted()
|
| 310 |
+
preferences = self._validate_preferences(preferences)
|
| 311 |
+
|
| 312 |
+
n_samples = self._X.shape[0]
|
| 313 |
+
k = min(k, n_samples)
|
| 314 |
+
|
| 315 |
+
# Normalize and encode preferences
|
| 316 |
+
pref_normalized = (preferences - self._feature_mean) / self._feature_std
|
| 317 |
+
pref_tensor = torch.tensor(
|
| 318 |
+
pref_normalized, dtype=torch.float32
|
| 319 |
+
).unsqueeze(0).to(self.device)
|
| 320 |
+
|
| 321 |
+
self._encoder.eval()
|
| 322 |
+
with torch.no_grad():
|
| 323 |
+
pref_embedding = self._encoder(pref_tensor).cpu().numpy()
|
| 324 |
+
|
| 325 |
+
# Find nearest embeddings using cosine similarity
|
| 326 |
+
# (embeddings are already L2 normalized)
|
| 327 |
+
similarities = (self._embeddings @ pref_embedding.T).squeeze()
|
| 328 |
+
|
| 329 |
+
# Get top k
|
| 330 |
+
indices = np.argsort(similarities)[::-1][:k]
|
| 331 |
+
scores = similarities[indices]
|
| 332 |
+
|
| 333 |
+
# Shift to [0, 1] range
|
| 334 |
+
scores = (scores + 1.0) / 2.0
|
| 335 |
+
|
| 336 |
+
recommendations = []
|
| 337 |
+
for idx, score in zip(indices, scores):
|
| 338 |
+
rec = self._format_recommendation(idx, score, self._X[idx])
|
| 339 |
+
recommendations.append(rec)
|
| 340 |
+
|
| 341 |
+
return recommendations
|
| 342 |
+
|
| 343 |
+
def get_embedding(self, preferences: np.ndarray) -> np.ndarray:
|
| 344 |
+
"""Get the embedding for a taste profile.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
preferences: Taste preferences of shape (9,) or (n, 9).
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
Embedding(s) of shape (embedding_dim,) or (n, embedding_dim).
|
| 351 |
+
"""
|
| 352 |
+
self._validate_fitted()
|
| 353 |
+
|
| 354 |
+
preferences = np.asarray(preferences, dtype=np.float32)
|
| 355 |
+
squeeze_output = False
|
| 356 |
+
if preferences.ndim == 1:
|
| 357 |
+
preferences = preferences.reshape(1, -1)
|
| 358 |
+
squeeze_output = True
|
| 359 |
+
|
| 360 |
+
pref_normalized = (preferences - self._feature_mean) / self._feature_std
|
| 361 |
+
pref_tensor = torch.tensor(pref_normalized, dtype=torch.float32).to(self.device)
|
| 362 |
+
|
| 363 |
+
self._encoder.eval()
|
| 364 |
+
with torch.no_grad():
|
| 365 |
+
embeddings = self._encoder(pref_tensor).cpu().numpy()
|
| 366 |
+
|
| 367 |
+
if squeeze_output:
|
| 368 |
+
return embeddings.squeeze(0)
|
| 369 |
+
return embeddings
|
| 370 |
+
|
| 371 |
+
def save(self, path: str | Path) -> None:
|
| 372 |
+
"""Save the fitted model to disk.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
path: File path to save the model to.
|
| 376 |
+
"""
|
| 377 |
+
self._validate_fitted()
|
| 378 |
+
path = Path(path)
|
| 379 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 380 |
+
|
| 381 |
+
state = {
|
| 382 |
+
"embedding_dim": self.embedding_dim,
|
| 383 |
+
"hidden_dim": self.hidden_dim,
|
| 384 |
+
"learning_rate": self.learning_rate,
|
| 385 |
+
"margin": self.margin,
|
| 386 |
+
"encoder_state_dict": self._encoder.state_dict(),
|
| 387 |
+
"X": self._X,
|
| 388 |
+
"metadata": self._metadata,
|
| 389 |
+
"embeddings": self._embeddings,
|
| 390 |
+
"feature_mean": self._feature_mean,
|
| 391 |
+
"feature_std": self._feature_std,
|
| 392 |
+
}
|
| 393 |
+
torch.save(state, path)
|
| 394 |
+
|
| 395 |
+
@classmethod
|
| 396 |
+
def load(cls, path: str | Path, device: str | None = None) -> "NeuralRecommender":
|
| 397 |
+
"""Load a fitted model from disk.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
path: File path to load the model from.
|
| 401 |
+
device: Torch device. Auto-detects if not specified.
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
The loaded recommender instance.
|
| 405 |
+
"""
|
| 406 |
+
state = torch.load(path, map_location="cpu", weights_only=False)
|
| 407 |
+
|
| 408 |
+
model = cls(
|
| 409 |
+
embedding_dim=state["embedding_dim"],
|
| 410 |
+
hidden_dim=state["hidden_dim"],
|
| 411 |
+
learning_rate=state["learning_rate"],
|
| 412 |
+
margin=state["margin"],
|
| 413 |
+
device=device,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
model._encoder = TasteEncoder(
|
| 417 |
+
input_dim=9,
|
| 418 |
+
hidden_dim=state["hidden_dim"],
|
| 419 |
+
embedding_dim=state["embedding_dim"],
|
| 420 |
+
).to(model.device)
|
| 421 |
+
model._encoder.load_state_dict(state["encoder_state_dict"])
|
| 422 |
+
model._encoder.eval()
|
| 423 |
+
|
| 424 |
+
model._X = state["X"]
|
| 425 |
+
model._metadata = state["metadata"]
|
| 426 |
+
model._embeddings = state["embeddings"]
|
| 427 |
+
model._feature_mean = state["feature_mean"]
|
| 428 |
+
model._feature_std = state["feature_std"]
|
| 429 |
+
model.is_fitted = True
|
| 430 |
+
|
| 431 |
+
return model
|
src/brewmatch/train.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training script for all BrewMatch models.
|
| 2 |
+
|
| 3 |
+
Supports hyperparameter tuning with Optuna:
|
| 4 |
+
- `uv run train` - Train with defaults or previously tuned hyperparameters
|
| 5 |
+
- `uv run train --tune` - Run Optuna tuning, save params, then train
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import optuna
|
| 15 |
+
from optuna.samplers import TPESampler
|
| 16 |
+
from optuna.pruners import MedianPruner
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
from brewmatch.config import (
|
| 20 |
+
CHECKPOINTS_DIR,
|
| 21 |
+
PROJECT_ROOT,
|
| 22 |
+
TASTE_FEATURES,
|
| 23 |
+
)
|
| 24 |
+
from brewmatch.data import load_processed_data
|
| 25 |
+
from brewmatch.device import get_device, print_device_info
|
| 26 |
+
from brewmatch.models import (
|
| 27 |
+
NaiveBaselineRecommender,
|
| 28 |
+
ClassicalMLRecommender,
|
| 29 |
+
NeuralRecommender,
|
| 30 |
+
)
|
| 31 |
+
from brewmatch.evaluation import evaluate_model
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Where tuned hyperparameters are saved
|
| 35 |
+
HYPERPARAMS_FILE = CHECKPOINTS_DIR / "hyperparameters.json"
|
| 36 |
+
|
| 37 |
+
# Default hyperparameters (used if no tuning has been done)
|
| 38 |
+
DEFAULT_NEURAL_PARAMS = {
|
| 39 |
+
"embedding_dim": 32,
|
| 40 |
+
"hidden_dim": 64,
|
| 41 |
+
"learning_rate": 0.001,
|
| 42 |
+
"margin": 0.5,
|
| 43 |
+
"batch_size": 32,
|
| 44 |
+
"epochs": 100,
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
DEFAULT_CLASSICAL_PARAMS = {
|
| 48 |
+
"method": "knn",
|
| 49 |
+
"n_neighbors": 50,
|
| 50 |
+
"normalize": True,
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_hyperparameters() -> dict[str, Any]:
|
| 55 |
+
"""Load saved hyperparameters if they exist."""
|
| 56 |
+
if HYPERPARAMS_FILE.exists():
|
| 57 |
+
with open(HYPERPARAMS_FILE) as f:
|
| 58 |
+
return json.load(f)
|
| 59 |
+
return {}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def save_hyperparameters(params: dict[str, Any]) -> None:
|
| 63 |
+
"""Save hyperparameters for future runs."""
|
| 64 |
+
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
with open(HYPERPARAMS_FILE, "w") as f:
|
| 66 |
+
json.dump(params, f, indent=2)
|
| 67 |
+
print(f"Hyperparameters saved to {HYPERPARAMS_FILE}")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_neural_params(saved: dict[str, Any]) -> dict[str, Any]:
|
| 71 |
+
"""Get neural network params (saved or defaults)."""
|
| 72 |
+
if "neural" in saved:
|
| 73 |
+
print("Using tuned neural hyperparameters")
|
| 74 |
+
return {**DEFAULT_NEURAL_PARAMS, **saved["neural"]}
|
| 75 |
+
print("Using default neural hyperparameters")
|
| 76 |
+
return DEFAULT_NEURAL_PARAMS.copy()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_classical_params(saved: dict[str, Any]) -> dict[str, Any]:
|
| 80 |
+
"""Get classical ML params (saved or defaults)."""
|
| 81 |
+
if "classical" in saved:
|
| 82 |
+
print("Using tuned classical hyperparameters")
|
| 83 |
+
return {**DEFAULT_CLASSICAL_PARAMS, **saved["classical"]}
|
| 84 |
+
print("Using default classical hyperparameters")
|
| 85 |
+
return DEFAULT_CLASSICAL_PARAMS.copy()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# =============================================================================
|
| 89 |
+
# Training Functions
|
| 90 |
+
# =============================================================================
|
| 91 |
+
|
| 92 |
+
def train_baseline(train_df: pd.DataFrame) -> NaiveBaselineRecommender:
|
| 93 |
+
"""Train the naive baseline model."""
|
| 94 |
+
print("Training Naive Baseline Model...")
|
| 95 |
+
|
| 96 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 97 |
+
model = NaiveBaselineRecommender(strategy="mean")
|
| 98 |
+
model.fit(X_train, train_df)
|
| 99 |
+
|
| 100 |
+
print(f" Strategy: {model.strategy}")
|
| 101 |
+
print(f" Coffees indexed: {len(model._X)}")
|
| 102 |
+
|
| 103 |
+
return model
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def train_classical(train_df: pd.DataFrame, params: dict[str, Any]) -> ClassicalMLRecommender:
|
| 107 |
+
"""Train the classical ML model with given hyperparameters."""
|
| 108 |
+
print("Training Classical ML Model...")
|
| 109 |
+
print(f" Params: {params}")
|
| 110 |
+
|
| 111 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 112 |
+
model = ClassicalMLRecommender(
|
| 113 |
+
method=params["method"],
|
| 114 |
+
n_neighbors=params["n_neighbors"],
|
| 115 |
+
normalize=params["normalize"],
|
| 116 |
+
)
|
| 117 |
+
model.fit(X_train, train_df)
|
| 118 |
+
|
| 119 |
+
print(f" Coffees indexed: {len(model._X)}")
|
| 120 |
+
|
| 121 |
+
return model
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def train_neural(
|
| 125 |
+
train_df: pd.DataFrame,
|
| 126 |
+
params: dict[str, Any],
|
| 127 |
+
device: str,
|
| 128 |
+
) -> NeuralRecommender:
|
| 129 |
+
"""Train the neural network model with given hyperparameters."""
|
| 130 |
+
print("Training Neural Network Model...")
|
| 131 |
+
print(f" Params: {params}")
|
| 132 |
+
|
| 133 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 134 |
+
|
| 135 |
+
model = NeuralRecommender(
|
| 136 |
+
embedding_dim=params["embedding_dim"],
|
| 137 |
+
hidden_dim=params["hidden_dim"],
|
| 138 |
+
learning_rate=params["learning_rate"],
|
| 139 |
+
margin=params["margin"],
|
| 140 |
+
device=device,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
model.fit(
|
| 144 |
+
X=X_train,
|
| 145 |
+
metadata=train_df,
|
| 146 |
+
epochs=params.get("epochs", 100),
|
| 147 |
+
batch_size=params["batch_size"],
|
| 148 |
+
verbose=True,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
return model
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def save_models(
|
| 155 |
+
baseline: NaiveBaselineRecommender | None,
|
| 156 |
+
classical: ClassicalMLRecommender | None,
|
| 157 |
+
neural: NeuralRecommender | None,
|
| 158 |
+
params: dict[str, Any],
|
| 159 |
+
) -> None:
|
| 160 |
+
"""Save all trained models."""
|
| 161 |
+
print("\nSaving models...")
|
| 162 |
+
|
| 163 |
+
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 164 |
+
|
| 165 |
+
if baseline:
|
| 166 |
+
baseline.save(CHECKPOINTS_DIR / "baseline.pkl")
|
| 167 |
+
print(f" Baseline: {CHECKPOINTS_DIR / 'baseline.pkl'}")
|
| 168 |
+
|
| 169 |
+
if classical:
|
| 170 |
+
classical.save(CHECKPOINTS_DIR / "classical.pkl")
|
| 171 |
+
print(f" Classical: {CHECKPOINTS_DIR / 'classical.pkl'}")
|
| 172 |
+
|
| 173 |
+
if neural:
|
| 174 |
+
neural.save(CHECKPOINTS_DIR / "neural.pt")
|
| 175 |
+
print(f" Neural: {CHECKPOINTS_DIR / 'neural.pt'}")
|
| 176 |
+
|
| 177 |
+
# Save model metadata
|
| 178 |
+
model_info = {
|
| 179 |
+
"models": ["baseline", "classical", "neural"],
|
| 180 |
+
"taste_features": TASTE_FEATURES,
|
| 181 |
+
"hyperparameters": params,
|
| 182 |
+
}
|
| 183 |
+
with open(CHECKPOINTS_DIR / "model_info.json", "w") as f:
|
| 184 |
+
json.dump(model_info, f, indent=2)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# =============================================================================
|
| 188 |
+
# Optuna Hyperparameter Tuning
|
| 189 |
+
# =============================================================================
|
| 190 |
+
|
| 191 |
+
def create_cv_splits(
|
| 192 |
+
df: pd.DataFrame,
|
| 193 |
+
n_folds: int = 3,
|
| 194 |
+
seed: int = 42,
|
| 195 |
+
) -> list[tuple[pd.DataFrame, pd.DataFrame]]:
|
| 196 |
+
"""Create cross-validation splits."""
|
| 197 |
+
np.random.seed(seed)
|
| 198 |
+
indices = np.random.permutation(len(df))
|
| 199 |
+
fold_size = len(df) // n_folds
|
| 200 |
+
|
| 201 |
+
splits = []
|
| 202 |
+
for i in range(n_folds):
|
| 203 |
+
start = i * fold_size
|
| 204 |
+
end = start + fold_size if i < n_folds - 1 else len(df)
|
| 205 |
+
|
| 206 |
+
val_idx = indices[start:end]
|
| 207 |
+
train_idx = np.concatenate([indices[:start], indices[end:]])
|
| 208 |
+
|
| 209 |
+
splits.append((
|
| 210 |
+
df.iloc[train_idx].reset_index(drop=True),
|
| 211 |
+
df.iloc[val_idx].reset_index(drop=True),
|
| 212 |
+
))
|
| 213 |
+
|
| 214 |
+
return splits
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def tune_neural(
|
| 218 |
+
train_df: pd.DataFrame,
|
| 219 |
+
device: str,
|
| 220 |
+
n_trials: int = 50,
|
| 221 |
+
n_folds: int = 3,
|
| 222 |
+
) -> dict[str, Any]:
|
| 223 |
+
"""Tune neural network hyperparameters with Optuna."""
|
| 224 |
+
print(f"\n{'='*60}")
|
| 225 |
+
print("TUNING NEURAL NETWORK HYPERPARAMETERS")
|
| 226 |
+
print(f"{'='*60}")
|
| 227 |
+
print(f"Trials: {n_trials}, CV Folds: {n_folds}")
|
| 228 |
+
|
| 229 |
+
splits = create_cv_splits(train_df, n_folds)
|
| 230 |
+
|
| 231 |
+
def objective(trial: optuna.Trial) -> float:
|
| 232 |
+
params = {
|
| 233 |
+
"embedding_dim": trial.suggest_int("embedding_dim", 16, 128, step=16),
|
| 234 |
+
"hidden_dim": trial.suggest_int("hidden_dim", 32, 256, step=32),
|
| 235 |
+
"learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
|
| 236 |
+
"margin": trial.suggest_float("margin", 0.1, 1.0),
|
| 237 |
+
"batch_size": trial.suggest_categorical("batch_size", [16, 32, 64, 128]),
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
scores = []
|
| 241 |
+
for fold_idx, (fold_train, fold_val) in enumerate(splits):
|
| 242 |
+
model = NeuralRecommender(
|
| 243 |
+
embedding_dim=params["embedding_dim"],
|
| 244 |
+
hidden_dim=params["hidden_dim"],
|
| 245 |
+
learning_rate=params["learning_rate"],
|
| 246 |
+
margin=params["margin"],
|
| 247 |
+
device=device,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
model.fit(
|
| 251 |
+
X=fold_train[TASTE_FEATURES].values,
|
| 252 |
+
metadata=fold_train,
|
| 253 |
+
epochs=30, # Reduced for tuning speed
|
| 254 |
+
batch_size=params["batch_size"],
|
| 255 |
+
verbose=False,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
metrics = evaluate_model(
|
| 259 |
+
model,
|
| 260 |
+
{"X": fold_val[TASTE_FEATURES].values, "metadata": fold_val},
|
| 261 |
+
k_values=[5],
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
score = metrics.get("precision@k", {}).get(5, 0.0)
|
| 265 |
+
scores.append(score)
|
| 266 |
+
|
| 267 |
+
trial.report(np.mean(scores), fold_idx)
|
| 268 |
+
if trial.should_prune():
|
| 269 |
+
raise optuna.TrialPruned()
|
| 270 |
+
|
| 271 |
+
return np.mean(scores)
|
| 272 |
+
|
| 273 |
+
study = optuna.create_study(
|
| 274 |
+
direction="maximize",
|
| 275 |
+
sampler=TPESampler(seed=42),
|
| 276 |
+
pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=1),
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# Suppress Optuna logging
|
| 280 |
+
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
| 281 |
+
|
| 282 |
+
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
| 283 |
+
|
| 284 |
+
print(f"\nBest Precision@5: {study.best_value:.4f}")
|
| 285 |
+
print(f"Best params: {study.best_params}")
|
| 286 |
+
|
| 287 |
+
return study.best_params
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def tune_classical(
|
| 291 |
+
train_df: pd.DataFrame,
|
| 292 |
+
n_trials: int = 30,
|
| 293 |
+
n_folds: int = 3,
|
| 294 |
+
) -> dict[str, Any]:
|
| 295 |
+
"""Tune classical ML hyperparameters with Optuna."""
|
| 296 |
+
print(f"\n{'='*60}")
|
| 297 |
+
print("TUNING CLASSICAL ML HYPERPARAMETERS")
|
| 298 |
+
print(f"{'='*60}")
|
| 299 |
+
print(f"Trials: {n_trials}, CV Folds: {n_folds}")
|
| 300 |
+
|
| 301 |
+
splits = create_cv_splits(train_df, n_folds)
|
| 302 |
+
|
| 303 |
+
def objective(trial: optuna.Trial) -> float:
|
| 304 |
+
params = {
|
| 305 |
+
"method": trial.suggest_categorical("method", ["knn", "cosine"]),
|
| 306 |
+
"n_neighbors": trial.suggest_int("n_neighbors", 5, 100),
|
| 307 |
+
"normalize": trial.suggest_categorical("normalize", [True, False]),
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
scores = []
|
| 311 |
+
for fold_train, fold_val in splits:
|
| 312 |
+
model = ClassicalMLRecommender(**params)
|
| 313 |
+
model.fit(fold_train[TASTE_FEATURES].values, fold_train)
|
| 314 |
+
|
| 315 |
+
metrics = evaluate_model(
|
| 316 |
+
model,
|
| 317 |
+
{"X": fold_val[TASTE_FEATURES].values, "metadata": fold_val},
|
| 318 |
+
k_values=[5],
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
scores.append(metrics.get("precision@k", {}).get(5, 0.0))
|
| 322 |
+
|
| 323 |
+
return np.mean(scores)
|
| 324 |
+
|
| 325 |
+
study = optuna.create_study(
|
| 326 |
+
direction="maximize",
|
| 327 |
+
sampler=TPESampler(seed=42),
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
| 331 |
+
|
| 332 |
+
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
| 333 |
+
|
| 334 |
+
print(f"\nBest Precision@5: {study.best_value:.4f}")
|
| 335 |
+
print(f"Best params: {study.best_params}")
|
| 336 |
+
|
| 337 |
+
return study.best_params
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# =============================================================================
|
| 341 |
+
# Main Entry Point
|
| 342 |
+
# =============================================================================
|
| 343 |
+
|
| 344 |
+
def main():
|
| 345 |
+
"""Main training entry point."""
|
| 346 |
+
parser = argparse.ArgumentParser(
|
| 347 |
+
description="Train BrewMatch recommendation models",
|
| 348 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 349 |
+
epilog="""
|
| 350 |
+
Examples:
|
| 351 |
+
uv run train # Train with defaults or saved hyperparameters
|
| 352 |
+
uv run train --tune # Tune hyperparameters, then train
|
| 353 |
+
uv run train --models neural # Train only neural network
|
| 354 |
+
uv run train --tune --neural-trials 100
|
| 355 |
+
""",
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
"--models",
|
| 359 |
+
nargs="+",
|
| 360 |
+
choices=["baseline", "classical", "neural", "all"],
|
| 361 |
+
default=["all"],
|
| 362 |
+
help="Which models to train (default: all)",
|
| 363 |
+
)
|
| 364 |
+
parser.add_argument(
|
| 365 |
+
"--tune",
|
| 366 |
+
action="store_true",
|
| 367 |
+
help="Run Optuna hyperparameter tuning before training",
|
| 368 |
+
)
|
| 369 |
+
parser.add_argument(
|
| 370 |
+
"--neural-trials",
|
| 371 |
+
type=int,
|
| 372 |
+
default=50,
|
| 373 |
+
help="Number of Optuna trials for neural network (default: 50)",
|
| 374 |
+
)
|
| 375 |
+
parser.add_argument(
|
| 376 |
+
"--classical-trials",
|
| 377 |
+
type=int,
|
| 378 |
+
default=30,
|
| 379 |
+
help="Number of Optuna trials for classical ML (default: 30)",
|
| 380 |
+
)
|
| 381 |
+
parser.add_argument(
|
| 382 |
+
"--cv-folds",
|
| 383 |
+
type=int,
|
| 384 |
+
default=3,
|
| 385 |
+
help="Cross-validation folds for tuning (default: 3)",
|
| 386 |
+
)
|
| 387 |
+
parser.add_argument(
|
| 388 |
+
"--device",
|
| 389 |
+
type=str,
|
| 390 |
+
default=None,
|
| 391 |
+
help="Device to train on (cuda/mps/cpu, auto-detected if not specified)",
|
| 392 |
+
)
|
| 393 |
+
args = parser.parse_args()
|
| 394 |
+
|
| 395 |
+
# Device selection
|
| 396 |
+
device = get_device(args.device)
|
| 397 |
+
print_device_info()
|
| 398 |
+
print()
|
| 399 |
+
|
| 400 |
+
# Expand "all" to all models
|
| 401 |
+
models_to_train = args.models
|
| 402 |
+
if "all" in models_to_train:
|
| 403 |
+
models_to_train = ["baseline", "classical", "neural"]
|
| 404 |
+
|
| 405 |
+
print(f"Models to train: {models_to_train}")
|
| 406 |
+
|
| 407 |
+
# Load data
|
| 408 |
+
print("\nLoading processed data...")
|
| 409 |
+
data = load_processed_data()
|
| 410 |
+
train_df = data["train_df"]
|
| 411 |
+
test_df = data["test_df"]
|
| 412 |
+
print(f" Train: {len(train_df)} samples")
|
| 413 |
+
print(f" Test: {len(test_df)} samples")
|
| 414 |
+
|
| 415 |
+
# Load or tune hyperparameters
|
| 416 |
+
saved_params = load_hyperparameters()
|
| 417 |
+
|
| 418 |
+
if args.tune:
|
| 419 |
+
print("\n" + "=" * 60)
|
| 420 |
+
print("HYPERPARAMETER TUNING WITH OPTUNA")
|
| 421 |
+
print("=" * 60)
|
| 422 |
+
|
| 423 |
+
if "neural" in models_to_train:
|
| 424 |
+
neural_params = tune_neural(
|
| 425 |
+
train_df,
|
| 426 |
+
device=str(device),
|
| 427 |
+
n_trials=args.neural_trials,
|
| 428 |
+
n_folds=args.cv_folds,
|
| 429 |
+
)
|
| 430 |
+
saved_params["neural"] = neural_params
|
| 431 |
+
|
| 432 |
+
if "classical" in models_to_train:
|
| 433 |
+
classical_params = tune_classical(
|
| 434 |
+
train_df,
|
| 435 |
+
n_trials=args.classical_trials,
|
| 436 |
+
n_folds=args.cv_folds,
|
| 437 |
+
)
|
| 438 |
+
saved_params["classical"] = classical_params
|
| 439 |
+
|
| 440 |
+
# Save tuned hyperparameters
|
| 441 |
+
save_hyperparameters(saved_params)
|
| 442 |
+
|
| 443 |
+
# Get final hyperparameters
|
| 444 |
+
neural_params = get_neural_params(saved_params)
|
| 445 |
+
classical_params = get_classical_params(saved_params)
|
| 446 |
+
|
| 447 |
+
# Train models
|
| 448 |
+
print("\n" + "=" * 60)
|
| 449 |
+
print("TRAINING MODELS")
|
| 450 |
+
print("=" * 60)
|
| 451 |
+
|
| 452 |
+
baseline_model = None
|
| 453 |
+
classical_model = None
|
| 454 |
+
neural_model = None
|
| 455 |
+
|
| 456 |
+
if "baseline" in models_to_train:
|
| 457 |
+
baseline_model = train_baseline(train_df)
|
| 458 |
+
print()
|
| 459 |
+
|
| 460 |
+
if "classical" in models_to_train:
|
| 461 |
+
classical_model = train_classical(train_df, classical_params)
|
| 462 |
+
print()
|
| 463 |
+
|
| 464 |
+
if "neural" in models_to_train:
|
| 465 |
+
neural_model = train_neural(train_df, neural_params, str(device))
|
| 466 |
+
print()
|
| 467 |
+
|
| 468 |
+
# Save models
|
| 469 |
+
all_params = {
|
| 470 |
+
"neural": neural_params,
|
| 471 |
+
"classical": classical_params,
|
| 472 |
+
}
|
| 473 |
+
save_models(baseline_model, classical_model, neural_model, all_params)
|
| 474 |
+
|
| 475 |
+
print("\nTraining complete!")
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
if __name__ == "__main__":
|
| 479 |
+
main()
|
src/brewmatch/tuning.py
ADDED
|
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter Tuning with Optuna
|
| 3 |
+
|
| 4 |
+
This module provides automated hyperparameter optimization for all BrewMatch models
|
| 5 |
+
using Optuna's Bayesian optimization framework.
|
| 6 |
+
|
| 7 |
+
Optimizes:
|
| 8 |
+
- Neural network: embedding_dim, hidden_dim, learning_rate, margin, batch_size, dropout
|
| 9 |
+
- Classical ML: n_neighbors, method (knn/cosine), normalization
|
| 10 |
+
- Baseline: strategy selection
|
| 11 |
+
|
| 12 |
+
Uses cross-validation for robust evaluation and early pruning for efficiency.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import optuna
|
| 22 |
+
from optuna.pruners import MedianPruner
|
| 23 |
+
from optuna.samplers import TPESampler
|
| 24 |
+
import pandas as pd
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
from brewmatch.config import (
|
| 28 |
+
CHECKPOINTS_DIR,
|
| 29 |
+
K_VALUES,
|
| 30 |
+
PROJECT_ROOT,
|
| 31 |
+
TASTE_FEATURES,
|
| 32 |
+
)
|
| 33 |
+
from brewmatch.data import load_processed_data
|
| 34 |
+
from brewmatch.models import (
|
| 35 |
+
NaiveBaselineRecommender,
|
| 36 |
+
ClassicalMLRecommender,
|
| 37 |
+
NeuralRecommender,
|
| 38 |
+
)
|
| 39 |
+
from brewmatch.evaluation import evaluate_model
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
TUNING_DIR = PROJECT_ROOT / "tuning"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def create_cross_validation_splits(
|
| 46 |
+
df: pd.DataFrame,
|
| 47 |
+
n_folds: int = 5,
|
| 48 |
+
random_state: int = 42,
|
| 49 |
+
) -> list[tuple[pd.DataFrame, pd.DataFrame]]:
|
| 50 |
+
"""Create stratified cross-validation splits."""
|
| 51 |
+
np.random.seed(random_state)
|
| 52 |
+
indices = np.random.permutation(len(df))
|
| 53 |
+
fold_size = len(df) // n_folds
|
| 54 |
+
|
| 55 |
+
splits = []
|
| 56 |
+
for i in range(n_folds):
|
| 57 |
+
start = i * fold_size
|
| 58 |
+
end = start + fold_size if i < n_folds - 1 else len(df)
|
| 59 |
+
|
| 60 |
+
val_indices = indices[start:end]
|
| 61 |
+
train_indices = np.concatenate([indices[:start], indices[end:]])
|
| 62 |
+
|
| 63 |
+
train_df = df.iloc[train_indices].reset_index(drop=True)
|
| 64 |
+
val_df = df.iloc[val_indices].reset_index(drop=True)
|
| 65 |
+
splits.append((train_df, val_df))
|
| 66 |
+
|
| 67 |
+
return splits
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def objective_neural(
|
| 71 |
+
trial: optuna.Trial,
|
| 72 |
+
train_df: pd.DataFrame,
|
| 73 |
+
val_df: pd.DataFrame,
|
| 74 |
+
device: str,
|
| 75 |
+
) -> float:
|
| 76 |
+
"""Optuna objective function for neural network hyperparameters."""
|
| 77 |
+
# Sample hyperparameters
|
| 78 |
+
embedding_dim = trial.suggest_int("embedding_dim", 16, 128, step=16)
|
| 79 |
+
hidden_dim = trial.suggest_int("hidden_dim", 32, 256, step=32)
|
| 80 |
+
learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
|
| 81 |
+
margin = trial.suggest_float("margin", 0.1, 1.0)
|
| 82 |
+
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
|
| 83 |
+
dropout = trial.suggest_float("dropout", 0.0, 0.5)
|
| 84 |
+
|
| 85 |
+
# Create and train model
|
| 86 |
+
model = NeuralRecommender(
|
| 87 |
+
embedding_dim=embedding_dim,
|
| 88 |
+
hidden_dim=hidden_dim,
|
| 89 |
+
learning_rate=learning_rate,
|
| 90 |
+
margin=margin,
|
| 91 |
+
device=device,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Train with reduced epochs for tuning
|
| 95 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 96 |
+
model.fit(
|
| 97 |
+
X=X_train,
|
| 98 |
+
metadata=train_df,
|
| 99 |
+
epochs=50, # Reduced for faster tuning
|
| 100 |
+
batch_size=batch_size,
|
| 101 |
+
verbose=False,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Evaluate on validation set
|
| 105 |
+
val_data = {
|
| 106 |
+
"X": val_df[TASTE_FEATURES].values,
|
| 107 |
+
"metadata": val_df,
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
metrics = evaluate_model(model, val_data, k_values=[5])
|
| 111 |
+
|
| 112 |
+
# Return primary metric (Precision@5)
|
| 113 |
+
precision_5 = metrics.get("precision@k", {}).get(5, 0.0)
|
| 114 |
+
|
| 115 |
+
return precision_5
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def objective_classical(
|
| 119 |
+
trial: optuna.Trial,
|
| 120 |
+
train_df: pd.DataFrame,
|
| 121 |
+
val_df: pd.DataFrame,
|
| 122 |
+
) -> float:
|
| 123 |
+
"""Optuna objective function for classical ML hyperparameters."""
|
| 124 |
+
# Sample hyperparameters
|
| 125 |
+
method = trial.suggest_categorical("method", ["knn", "cosine"])
|
| 126 |
+
n_neighbors = trial.suggest_int("n_neighbors", 5, 100)
|
| 127 |
+
normalize = trial.suggest_categorical("normalize", [True, False])
|
| 128 |
+
|
| 129 |
+
# Create and train model
|
| 130 |
+
model = ClassicalMLRecommender(
|
| 131 |
+
method=method,
|
| 132 |
+
n_neighbors=n_neighbors,
|
| 133 |
+
normalize=normalize,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 137 |
+
model.fit(X_train, train_df)
|
| 138 |
+
|
| 139 |
+
# Evaluate on validation set
|
| 140 |
+
val_data = {
|
| 141 |
+
"X": val_df[TASTE_FEATURES].values,
|
| 142 |
+
"metadata": val_df,
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
metrics = evaluate_model(model, val_data, k_values=[5])
|
| 146 |
+
precision_5 = metrics.get("precision@k", {}).get(5, 0.0)
|
| 147 |
+
|
| 148 |
+
return precision_5
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def tune_neural(
|
| 152 |
+
train_df: pd.DataFrame,
|
| 153 |
+
n_trials: int = 50,
|
| 154 |
+
n_folds: int = 3,
|
| 155 |
+
device: str = "cuda",
|
| 156 |
+
study_name: str = "neural_tuning",
|
| 157 |
+
) -> dict[str, Any]:
|
| 158 |
+
"""
|
| 159 |
+
Tune neural network hyperparameters using Optuna.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
train_df: Training data
|
| 163 |
+
n_trials: Number of optimization trials
|
| 164 |
+
n_folds: Number of cross-validation folds
|
| 165 |
+
device: PyTorch device
|
| 166 |
+
study_name: Name for the Optuna study
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Dictionary with best parameters and study results
|
| 170 |
+
"""
|
| 171 |
+
print(f"\n{'='*60}")
|
| 172 |
+
print("NEURAL NETWORK HYPERPARAMETER TUNING")
|
| 173 |
+
print(f"{'='*60}")
|
| 174 |
+
print(f"Trials: {n_trials}, CV Folds: {n_folds}, Device: {device}")
|
| 175 |
+
|
| 176 |
+
# Create cross-validation splits
|
| 177 |
+
splits = create_cross_validation_splits(train_df, n_folds=n_folds)
|
| 178 |
+
|
| 179 |
+
def cv_objective(trial: optuna.Trial) -> float:
|
| 180 |
+
"""Cross-validated objective."""
|
| 181 |
+
scores = []
|
| 182 |
+
for fold_idx, (fold_train, fold_val) in enumerate(splits):
|
| 183 |
+
score = objective_neural(trial, fold_train, fold_val, device)
|
| 184 |
+
scores.append(score)
|
| 185 |
+
|
| 186 |
+
# Report intermediate value for pruning
|
| 187 |
+
trial.report(np.mean(scores), fold_idx)
|
| 188 |
+
if trial.should_prune():
|
| 189 |
+
raise optuna.TrialPruned()
|
| 190 |
+
|
| 191 |
+
return np.mean(scores)
|
| 192 |
+
|
| 193 |
+
# Create study with TPE sampler and median pruner
|
| 194 |
+
study = optuna.create_study(
|
| 195 |
+
study_name=study_name,
|
| 196 |
+
direction="maximize",
|
| 197 |
+
sampler=TPESampler(seed=42),
|
| 198 |
+
pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=1),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
study.optimize(
|
| 202 |
+
cv_objective,
|
| 203 |
+
n_trials=n_trials,
|
| 204 |
+
show_progress_bar=True,
|
| 205 |
+
gc_after_trial=True,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
print(f"\nBest trial:")
|
| 209 |
+
print(f" Value (Precision@5): {study.best_trial.value:.4f}")
|
| 210 |
+
print(f" Params: {study.best_trial.params}")
|
| 211 |
+
|
| 212 |
+
return {
|
| 213 |
+
"best_params": study.best_trial.params,
|
| 214 |
+
"best_value": study.best_trial.value,
|
| 215 |
+
"n_trials": len(study.trials),
|
| 216 |
+
"study_name": study_name,
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def tune_classical(
|
| 221 |
+
train_df: pd.DataFrame,
|
| 222 |
+
n_trials: int = 30,
|
| 223 |
+
n_folds: int = 3,
|
| 224 |
+
study_name: str = "classical_tuning",
|
| 225 |
+
) -> dict[str, Any]:
|
| 226 |
+
"""
|
| 227 |
+
Tune classical ML hyperparameters using Optuna.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
train_df: Training data
|
| 231 |
+
n_trials: Number of optimization trials
|
| 232 |
+
n_folds: Number of cross-validation folds
|
| 233 |
+
study_name: Name for the Optuna study
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
Dictionary with best parameters and study results
|
| 237 |
+
"""
|
| 238 |
+
print(f"\n{'='*60}")
|
| 239 |
+
print("CLASSICAL ML HYPERPARAMETER TUNING")
|
| 240 |
+
print(f"{'='*60}")
|
| 241 |
+
print(f"Trials: {n_trials}, CV Folds: {n_folds}")
|
| 242 |
+
|
| 243 |
+
splits = create_cross_validation_splits(train_df, n_folds=n_folds)
|
| 244 |
+
|
| 245 |
+
def cv_objective(trial: optuna.Trial) -> float:
|
| 246 |
+
scores = []
|
| 247 |
+
for fold_train, fold_val in splits:
|
| 248 |
+
score = objective_classical(trial, fold_train, fold_val)
|
| 249 |
+
scores.append(score)
|
| 250 |
+
return np.mean(scores)
|
| 251 |
+
|
| 252 |
+
study = optuna.create_study(
|
| 253 |
+
study_name=study_name,
|
| 254 |
+
direction="maximize",
|
| 255 |
+
sampler=TPESampler(seed=42),
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
study.optimize(
|
| 259 |
+
cv_objective,
|
| 260 |
+
n_trials=n_trials,
|
| 261 |
+
show_progress_bar=True,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
print(f"\nBest trial:")
|
| 265 |
+
print(f" Value (Precision@5): {study.best_trial.value:.4f}")
|
| 266 |
+
print(f" Params: {study.best_trial.params}")
|
| 267 |
+
|
| 268 |
+
return {
|
| 269 |
+
"best_params": study.best_trial.params,
|
| 270 |
+
"best_value": study.best_trial.value,
|
| 271 |
+
"n_trials": len(study.trials),
|
| 272 |
+
"study_name": study_name,
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def train_with_best_params(
|
| 277 |
+
train_df: pd.DataFrame,
|
| 278 |
+
test_df: pd.DataFrame,
|
| 279 |
+
neural_params: dict[str, Any] | None,
|
| 280 |
+
classical_params: dict[str, Any] | None,
|
| 281 |
+
device: str,
|
| 282 |
+
) -> dict[str, Any]:
|
| 283 |
+
"""Train final models with best hyperparameters and evaluate on test set."""
|
| 284 |
+
print(f"\n{'='*60}")
|
| 285 |
+
print("TRAINING FINAL MODELS WITH BEST PARAMETERS")
|
| 286 |
+
print(f"{'='*60}")
|
| 287 |
+
|
| 288 |
+
results = {}
|
| 289 |
+
|
| 290 |
+
test_data = {
|
| 291 |
+
"X": test_df[TASTE_FEATURES].values,
|
| 292 |
+
"metadata": test_df,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
# Train neural with best params
|
| 296 |
+
if neural_params:
|
| 297 |
+
print("\nTraining Neural Network with tuned hyperparameters...")
|
| 298 |
+
model = NeuralRecommender(
|
| 299 |
+
embedding_dim=neural_params["embedding_dim"],
|
| 300 |
+
hidden_dim=neural_params["hidden_dim"],
|
| 301 |
+
learning_rate=neural_params["learning_rate"],
|
| 302 |
+
margin=neural_params["margin"],
|
| 303 |
+
device=device,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 307 |
+
model.fit(
|
| 308 |
+
X=X_train,
|
| 309 |
+
metadata=train_df,
|
| 310 |
+
epochs=100, # Full training
|
| 311 |
+
batch_size=neural_params["batch_size"],
|
| 312 |
+
verbose=True,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
metrics = evaluate_model(model, test_data, k_values=K_VALUES)
|
| 316 |
+
results["neural"] = {
|
| 317 |
+
"params": neural_params,
|
| 318 |
+
"metrics": metrics,
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
# Save model
|
| 322 |
+
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 323 |
+
model.save(CHECKPOINTS_DIR / "neural.pt")
|
| 324 |
+
print(f"Saved tuned neural model to {CHECKPOINTS_DIR / 'neural.pt'}")
|
| 325 |
+
|
| 326 |
+
# Train classical with best params
|
| 327 |
+
if classical_params:
|
| 328 |
+
print("\nTraining Classical ML with tuned hyperparameters...")
|
| 329 |
+
model = ClassicalMLRecommender(
|
| 330 |
+
method=classical_params["method"],
|
| 331 |
+
n_neighbors=classical_params["n_neighbors"],
|
| 332 |
+
normalize=classical_params["normalize"],
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
X_train = train_df[TASTE_FEATURES].values
|
| 336 |
+
model.fit(X_train, train_df)
|
| 337 |
+
|
| 338 |
+
metrics = evaluate_model(model, test_data, k_values=K_VALUES)
|
| 339 |
+
results["classical"] = {
|
| 340 |
+
"params": classical_params,
|
| 341 |
+
"metrics": metrics,
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
model.save(CHECKPOINTS_DIR / "classical.pkl")
|
| 345 |
+
print(f"Saved tuned classical model to {CHECKPOINTS_DIR / 'classical.pkl'}")
|
| 346 |
+
|
| 347 |
+
# Also train baseline for comparison
|
| 348 |
+
print("\nTraining Baseline for comparison...")
|
| 349 |
+
baseline = NaiveBaselineRecommender(strategy="mean")
|
| 350 |
+
baseline.fit(train_df[TASTE_FEATURES].values, train_df)
|
| 351 |
+
baseline_metrics = evaluate_model(baseline, test_data, k_values=K_VALUES)
|
| 352 |
+
results["baseline"] = {"metrics": baseline_metrics}
|
| 353 |
+
baseline.save(CHECKPOINTS_DIR / "baseline.pkl")
|
| 354 |
+
|
| 355 |
+
return results
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def generate_tuning_report(
|
| 359 |
+
neural_results: dict[str, Any] | None,
|
| 360 |
+
classical_results: dict[str, Any] | None,
|
| 361 |
+
final_results: dict[str, Any],
|
| 362 |
+
output_dir: Path,
|
| 363 |
+
) -> str:
|
| 364 |
+
"""Generate a comprehensive tuning report."""
|
| 365 |
+
report = []
|
| 366 |
+
report.append("=" * 60)
|
| 367 |
+
report.append("HYPERPARAMETER TUNING REPORT")
|
| 368 |
+
report.append("=" * 60)
|
| 369 |
+
report.append("")
|
| 370 |
+
|
| 371 |
+
if neural_results:
|
| 372 |
+
report.append("NEURAL NETWORK")
|
| 373 |
+
report.append("-" * 40)
|
| 374 |
+
report.append(f"Trials completed: {neural_results['n_trials']}")
|
| 375 |
+
report.append(f"Best CV Precision@5: {neural_results['best_value']:.4f}")
|
| 376 |
+
report.append("Best hyperparameters:")
|
| 377 |
+
for param, value in neural_results["best_params"].items():
|
| 378 |
+
report.append(f" - {param}: {value}")
|
| 379 |
+
report.append("")
|
| 380 |
+
|
| 381 |
+
if classical_results:
|
| 382 |
+
report.append("CLASSICAL ML")
|
| 383 |
+
report.append("-" * 40)
|
| 384 |
+
report.append(f"Trials completed: {classical_results['n_trials']}")
|
| 385 |
+
report.append(f"Best CV Precision@5: {classical_results['best_value']:.4f}")
|
| 386 |
+
report.append("Best hyperparameters:")
|
| 387 |
+
for param, value in classical_results["best_params"].items():
|
| 388 |
+
report.append(f" - {param}: {value}")
|
| 389 |
+
report.append("")
|
| 390 |
+
|
| 391 |
+
report.append("FINAL TEST SET PERFORMANCE")
|
| 392 |
+
report.append("-" * 40)
|
| 393 |
+
for model_name, result in final_results.items():
|
| 394 |
+
metrics = result["metrics"]
|
| 395 |
+
p5 = metrics.get("precision@k", {}).get(5, 0)
|
| 396 |
+
ndcg5 = metrics.get("ndcg@k", {}).get(5, 0)
|
| 397 |
+
report.append(f"{model_name.upper()}:")
|
| 398 |
+
report.append(f" Precision@5: {p5:.4f}")
|
| 399 |
+
report.append(f" NDCG@5: {ndcg5:.4f}")
|
| 400 |
+
report.append("")
|
| 401 |
+
|
| 402 |
+
# Improvement analysis
|
| 403 |
+
if "baseline" in final_results and "neural" in final_results:
|
| 404 |
+
baseline_p5 = final_results["baseline"]["metrics"].get("precision@k", {}).get(5, 0)
|
| 405 |
+
neural_p5 = final_results["neural"]["metrics"].get("precision@k", {}).get(5, 0)
|
| 406 |
+
if baseline_p5 > 0:
|
| 407 |
+
improvement = (neural_p5 - baseline_p5) / baseline_p5 * 100
|
| 408 |
+
report.append(f"Neural improvement over baseline: {improvement:+.1f}%")
|
| 409 |
+
|
| 410 |
+
report_text = "\n".join(report)
|
| 411 |
+
|
| 412 |
+
# Save report
|
| 413 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 414 |
+
with open(output_dir / "tuning_report.txt", "w") as f:
|
| 415 |
+
f.write(report_text)
|
| 416 |
+
|
| 417 |
+
return report_text
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def main():
|
| 421 |
+
"""Main entry point for hyperparameter tuning."""
|
| 422 |
+
parser = argparse.ArgumentParser(
|
| 423 |
+
description="Tune BrewMatch model hyperparameters with Optuna"
|
| 424 |
+
)
|
| 425 |
+
parser.add_argument(
|
| 426 |
+
"--models",
|
| 427 |
+
nargs="+",
|
| 428 |
+
choices=["neural", "classical", "all"],
|
| 429 |
+
default=["all"],
|
| 430 |
+
help="Which models to tune",
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
"--neural-trials",
|
| 434 |
+
type=int,
|
| 435 |
+
default=50,
|
| 436 |
+
help="Number of trials for neural network tuning",
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
"--classical-trials",
|
| 440 |
+
type=int,
|
| 441 |
+
default=30,
|
| 442 |
+
help="Number of trials for classical ML tuning",
|
| 443 |
+
)
|
| 444 |
+
parser.add_argument(
|
| 445 |
+
"--cv-folds",
|
| 446 |
+
type=int,
|
| 447 |
+
default=3,
|
| 448 |
+
help="Number of cross-validation folds",
|
| 449 |
+
)
|
| 450 |
+
parser.add_argument(
|
| 451 |
+
"--device",
|
| 452 |
+
type=str,
|
| 453 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 454 |
+
help="Device for neural network training",
|
| 455 |
+
)
|
| 456 |
+
parser.add_argument(
|
| 457 |
+
"--output-dir",
|
| 458 |
+
type=str,
|
| 459 |
+
default=str(TUNING_DIR),
|
| 460 |
+
help="Directory to save tuning results",
|
| 461 |
+
)
|
| 462 |
+
args = parser.parse_args()
|
| 463 |
+
|
| 464 |
+
output_dir = Path(args.output_dir)
|
| 465 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 466 |
+
|
| 467 |
+
models_to_tune = args.models
|
| 468 |
+
if "all" in models_to_tune:
|
| 469 |
+
models_to_tune = ["neural", "classical"]
|
| 470 |
+
|
| 471 |
+
print("HYPERPARAMETER TUNING WITH OPTUNA")
|
| 472 |
+
print("=" * 60)
|
| 473 |
+
print(f"Models to tune: {models_to_tune}")
|
| 474 |
+
print(f"Neural trials: {args.neural_trials}")
|
| 475 |
+
print(f"Classical trials: {args.classical_trials}")
|
| 476 |
+
print(f"CV folds: {args.cv_folds}")
|
| 477 |
+
print(f"Device: {args.device}")
|
| 478 |
+
print(f"Output: {output_dir}")
|
| 479 |
+
|
| 480 |
+
# Load data
|
| 481 |
+
print("\nLoading data...")
|
| 482 |
+
data = load_processed_data()
|
| 483 |
+
train_df = data["train_df"]
|
| 484 |
+
test_df = data["test_df"]
|
| 485 |
+
print(f"Train: {len(train_df)}, Test: {len(test_df)}")
|
| 486 |
+
|
| 487 |
+
# Tune models
|
| 488 |
+
neural_results = None
|
| 489 |
+
classical_results = None
|
| 490 |
+
|
| 491 |
+
if "neural" in models_to_tune:
|
| 492 |
+
neural_results = tune_neural(
|
| 493 |
+
train_df=train_df,
|
| 494 |
+
n_trials=args.neural_trials,
|
| 495 |
+
n_folds=args.cv_folds,
|
| 496 |
+
device=args.device,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# Save neural results
|
| 500 |
+
with open(output_dir / "neural_tuning.json", "w") as f:
|
| 501 |
+
json.dump(neural_results, f, indent=2)
|
| 502 |
+
|
| 503 |
+
if "classical" in models_to_tune:
|
| 504 |
+
classical_results = tune_classical(
|
| 505 |
+
train_df=train_df,
|
| 506 |
+
n_trials=args.classical_trials,
|
| 507 |
+
n_folds=args.cv_folds,
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
# Save classical results
|
| 511 |
+
with open(output_dir / "classical_tuning.json", "w") as f:
|
| 512 |
+
json.dump(classical_results, f, indent=2)
|
| 513 |
+
|
| 514 |
+
# Train final models with best params
|
| 515 |
+
final_results = train_with_best_params(
|
| 516 |
+
train_df=train_df,
|
| 517 |
+
test_df=test_df,
|
| 518 |
+
neural_params=neural_results["best_params"] if neural_results else None,
|
| 519 |
+
classical_params=classical_results["best_params"] if classical_results else None,
|
| 520 |
+
device=args.device,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Save final results
|
| 524 |
+
with open(output_dir / "final_results.json", "w") as f:
|
| 525 |
+
# Convert metrics to JSON-serializable format
|
| 526 |
+
json_results = {}
|
| 527 |
+
for model_name, result in final_results.items():
|
| 528 |
+
json_results[model_name] = {
|
| 529 |
+
"params": result.get("params", {}),
|
| 530 |
+
"metrics": {
|
| 531 |
+
k: {str(kk): vv for kk, vv in v.items()} if isinstance(v, dict) else v
|
| 532 |
+
for k, v in result["metrics"].items()
|
| 533 |
+
},
|
| 534 |
+
}
|
| 535 |
+
json.dump(json_results, f, indent=2)
|
| 536 |
+
|
| 537 |
+
# Generate report
|
| 538 |
+
report = generate_tuning_report(
|
| 539 |
+
neural_results=neural_results,
|
| 540 |
+
classical_results=classical_results,
|
| 541 |
+
final_results=final_results,
|
| 542 |
+
output_dir=output_dir,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
print("\n" + report)
|
| 546 |
+
print(f"\nResults saved to {output_dir}")
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
if __name__ == "__main__":
|
| 550 |
+
main()
|
src/brewmatch/utils.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for BrewMatch."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import pickle
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, Optional, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def save_pickle(obj: Any, path: Union[str, Path]) -> None:
|
| 13 |
+
"""Save object to pickle file."""
|
| 14 |
+
path = Path(path)
|
| 15 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 16 |
+
with open(path, "wb") as f:
|
| 17 |
+
pickle.dump(obj, f)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_pickle(path: Union[str, Path]) -> Any:
|
| 21 |
+
"""Load object from pickle file."""
|
| 22 |
+
with open(path, "rb") as f:
|
| 23 |
+
return pickle.load(f)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def save_json(obj: Dict, path: Union[str, Path], indent: int = 2) -> None:
|
| 27 |
+
"""Save dict to JSON file."""
|
| 28 |
+
path = Path(path)
|
| 29 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
with open(path, "w") as f:
|
| 31 |
+
json.dump(obj, f, indent=indent)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_json(path: Union[str, Path]) -> Dict:
|
| 35 |
+
"""Load dict from JSON file."""
|
| 36 |
+
with open(path, "r") as f:
|
| 37 |
+
return json.load(f)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def set_seed(seed: int) -> None:
|
| 41 |
+
"""Set random seeds for reproducibility."""
|
| 42 |
+
np.random.seed(seed)
|
| 43 |
+
torch.manual_seed(seed)
|
| 44 |
+
if torch.cuda.is_available():
|
| 45 |
+
torch.cuda.manual_seed(seed)
|
| 46 |
+
torch.cuda.manual_seed_all(seed)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 50 |
+
"""
|
| 51 |
+
Compute cosine similarity between vectors.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
a: Query vector(s) of shape (d,) or (n, d)
|
| 55 |
+
b: Reference vectors of shape (m, d)
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Similarity scores of shape (n, m) or (m,)
|
| 59 |
+
"""
|
| 60 |
+
if a.ndim == 1:
|
| 61 |
+
a = a.reshape(1, -1)
|
| 62 |
+
squeeze = True
|
| 63 |
+
else:
|
| 64 |
+
squeeze = False
|
| 65 |
+
|
| 66 |
+
# Normalize
|
| 67 |
+
a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-8)
|
| 68 |
+
b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-8)
|
| 69 |
+
|
| 70 |
+
# Compute similarity
|
| 71 |
+
sim = np.dot(a_norm, b_norm.T)
|
| 72 |
+
|
| 73 |
+
if squeeze:
|
| 74 |
+
sim = sim.squeeze(0)
|
| 75 |
+
|
| 76 |
+
return sim
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 80 |
+
"""
|
| 81 |
+
Compute Euclidean distance between vectors.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
a: Query vector(s) of shape (d,) or (n, d)
|
| 85 |
+
b: Reference vectors of shape (m, d)
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Distance scores of shape (n, m) or (m,)
|
| 89 |
+
"""
|
| 90 |
+
if a.ndim == 1:
|
| 91 |
+
a = a.reshape(1, -1)
|
| 92 |
+
squeeze = True
|
| 93 |
+
else:
|
| 94 |
+
squeeze = False
|
| 95 |
+
|
| 96 |
+
# Compute distances using broadcasting
|
| 97 |
+
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b
|
| 98 |
+
a_sq = np.sum(a ** 2, axis=1, keepdims=True)
|
| 99 |
+
b_sq = np.sum(b ** 2, axis=1).reshape(1, -1)
|
| 100 |
+
dist_sq = a_sq + b_sq - 2 * np.dot(a, b.T)
|
| 101 |
+
dist = np.sqrt(np.maximum(dist_sq, 0))
|
| 102 |
+
|
| 103 |
+
if squeeze:
|
| 104 |
+
dist = dist.squeeze(0)
|
| 105 |
+
|
| 106 |
+
return dist
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def normalize_preferences(
|
| 110 |
+
preferences: Dict[str, float],
|
| 111 |
+
feature_names: list,
|
| 112 |
+
scaler: Optional[Any] = None,
|
| 113 |
+
) -> np.ndarray:
|
| 114 |
+
"""
|
| 115 |
+
Convert user preferences dict to normalized feature vector.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
preferences: Dict mapping feature names to values (0-10 scale)
|
| 119 |
+
feature_names: List of feature names in order
|
| 120 |
+
scaler: Optional sklearn scaler for normalization
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Normalized feature vector
|
| 124 |
+
"""
|
| 125 |
+
# Build vector in correct order
|
| 126 |
+
vector = np.array([preferences.get(name, 5.0) for name in feature_names])
|
| 127 |
+
vector = vector.reshape(1, -1)
|
| 128 |
+
|
| 129 |
+
if scaler is not None:
|
| 130 |
+
vector = scaler.transform(vector)
|
| 131 |
+
|
| 132 |
+
return vector.squeeze()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def format_recommendations(
|
| 136 |
+
indices: np.ndarray,
|
| 137 |
+
similarities: np.ndarray,
|
| 138 |
+
metadata: Any,
|
| 139 |
+
feature_names: list,
|
| 140 |
+
features: np.ndarray,
|
| 141 |
+
) -> list:
|
| 142 |
+
"""
|
| 143 |
+
Format recommendation results for API response.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
indices: Array of recommended coffee indices
|
| 147 |
+
similarities: Similarity scores for recommendations
|
| 148 |
+
metadata: DataFrame or dict with coffee metadata
|
| 149 |
+
feature_names: List of taste feature names
|
| 150 |
+
features: Feature matrix for coffees
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
List of recommendation dicts
|
| 154 |
+
"""
|
| 155 |
+
recommendations = []
|
| 156 |
+
|
| 157 |
+
for idx, sim in zip(indices, similarities):
|
| 158 |
+
rec = {
|
| 159 |
+
"id": int(idx),
|
| 160 |
+
"similarity": float(sim),
|
| 161 |
+
"scores": {
|
| 162 |
+
name: float(features[idx, i])
|
| 163 |
+
for i, name in enumerate(feature_names)
|
| 164 |
+
},
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
# Add metadata if available
|
| 168 |
+
if hasattr(metadata, "iloc"):
|
| 169 |
+
row = metadata.iloc[idx]
|
| 170 |
+
rec["country"] = str(row.get("Country of Origin", "Unknown"))
|
| 171 |
+
rec["processing_method"] = str(row.get("Processing Method", "Unknown"))
|
| 172 |
+
rec["total_cup_points"] = float(row.get("Total Cup Points", 0))
|
| 173 |
+
elif isinstance(metadata, dict):
|
| 174 |
+
rec["country"] = metadata.get("countries", ["Unknown"] * len(indices))[idx]
|
| 175 |
+
rec["processing_method"] = metadata.get(
|
| 176 |
+
"processing_methods", ["Unknown"] * len(indices)
|
| 177 |
+
)[idx]
|
| 178 |
+
|
| 179 |
+
recommendations.append(rec)
|
| 180 |
+
|
| 181 |
+
return recommendations
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|