Commit ff65bb4 · 0 parent(s)

Initial commit: OBD Logger with RLHF training system
- FastAPI-based OBD-II data processing
- Real-time data ingestion and cleaning
- Firebase and MongoDB integration
- RLHF training pipeline with versioned models
- Docker deployment ready
- Security: No hardcoded tokens
- .DS_Store +0 -0
- .gitattributes +35 -0
- .gitignore +11 -0
- Dockerfile +43 -0
- GOOGLE_DRIVE_SETUP.md +336 -0
- MONGODB_SETUP.md +133 -0
- OBD/obd_analyzer.py +215 -0
- OBD/obd_logger.py +374 -0
- README.md +199 -0
- app.py +802 -0
- data.json +26 -0
- data/drive_saver.py +110 -0
- data/firebase_saver.py +315 -0
- data/mongo_saver.py +362 -0
- organization.py +76 -0
- organze.py +184 -0
- requirements.txt +37 -0
- static/check.png +0 -0
- static/edit.png +0 -0
- static/icon.png +0 -0
- static/index.html +16 -0
- static/script.js +230 -0
- static/styles.css +135 -0
- train/README.md +140 -0
- train/__init__.py +8 -0
- train/loader.py +370 -0
- train/rlhf.py +420 -0
- train/saver.py +381 -0
- utils/download.py +154 -0
- utils/mount_drive.py +44 -0
- utils/ul_label.py +206 -0

.DS_Store ADDED
Binary file (6.15 kB)

.gitattributes ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text

.gitignore ADDED
@@ -0,0 +1,11 @@
+.env
+token.json
+service.json
+firebase.json
+
+# Security - prevent token leaks
+*.token
+*.key
+*secret*
+*credential*
+*password*

Dockerfile ADDED
@@ -0,0 +1,43 @@
+FROM python:3.11-slim
+
+# ── Create and switch to non-root user ──
+RUN useradd -m -u 1000 user
+USER user
+
+# ── Set environment and working directory ──
+ENV HOME=/home/user
+WORKDIR $HOME/app
+
+# ── Upgrade pip and install dependencies ──
+COPY --chown=user requirements.txt .
+RUN pip install --upgrade pip
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Install latest versions for UL model inference
+RUN pip install --no-cache-dir huggingface_hub xgboost joblib scikit-learn
+
+# ── Pre-mount GDrive (no-op if creds not found) ──
+COPY --chown=user utils/mount_drive.py .
+RUN python mount_drive.py || true
+
+# ── Copy application source ──
+COPY --chown=user . .
+
+# ── Create required folders ──
+RUN mkdir -p $HOME/app/logs \
+    $HOME/app/cache \
+    $HOME/app/cache/obd_data \
+    $HOME/app/cache/obd_data/plots \
+    $HOME/app/models/ul
+
+# ── Environment variables for HuggingFace model ──
+ENV MODEL_DIR=$HOME/app/models/ul
+ENV HF_MODEL_REPO=BinKhoaLe1812/Driver_Behavior_OBD
+
+# ── Models will be downloaded at runtime when app starts ──
+
+# ── Default port ──
+EXPOSE 7860
+
+# ── Start app ──
+CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
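
The Dockerfile intentionally defers model download to container start; it only exports `MODEL_DIR` and `HF_MODEL_REPO`. The sketch below shows one way such a runtime pull could be wired up with `huggingface_hub`. The real logic lives in `utils/download.py` / `app.py`, which are not shown in this excerpt, so the function name and defaults here are illustrative assumptions.

```python
# Hypothetical sketch only - the repo's actual download code is in utils/download.py.
import os
from huggingface_hub import snapshot_download

def ensure_ul_model() -> str:
    """Fetch the UL model repo into MODEL_DIR (idempotent across restarts)."""
    model_dir = os.environ.get("MODEL_DIR", "models/ul")
    repo_id = os.environ.get("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
    os.makedirs(model_dir, exist_ok=True)
    # Files already present in model_dir are reused, so calling this on every
    # startup is cheap after the first pull.
    return snapshot_download(repo_id=repo_id, local_dir=model_dir)
```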

GOOGLE_DRIVE_SETUP.md ADDED
@@ -0,0 +1,336 @@
+# Google Drive Integration Setup Guide
+
+This guide explains how to set up Google Drive integration for the OBD Logger application.
+
+## Prerequisites
+
+1. **Google Cloud Platform Account**: You need a Google Cloud Platform account
+2. **Google Drive API**: Enable the Google Drive API in your project
+3. **Service Account**: Create a service account with appropriate permissions
+4. **Python Dependencies**: Install the required packages
+
+## Installation
+
+### 1. Install Dependencies
+
+The required packages are already included in `requirements.txt`:
+
+```bash
+pip install -r requirements.txt
+```
+
+Required packages:
+- `google-auth`
+- `google-auth-httplib2`
+- `google-auth-oauthlib`
+- `google-api-python-client`
+
+### 2. Environment Variables
+
+Create a `.env` file in your project root with the following variables:
+
+```bash
+# Google Drive Configuration
+GDRIVE_CREDENTIALS_JSON={"type":"service_account","project_id":"your-project","private_key_id":"...","private_key":"...","client_email":"...","client_id":"...","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_x509_cert_url":"..."}
+
+# Optional: Custom Google Drive Folder ID
+GDRIVE_FOLDER_ID=1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P
+```
+
+## Google Cloud Platform Setup
+
+### 1. Create a New Project
+
+1. Go to [Google Cloud Console](https://console.cloud.google.com/)
+2. Click "Select a project" → "New Project"
+3. Enter a project name (e.g., "OBD-Logger-Drive")
+4. Click "Create"
+
+### 2. Enable Google Drive API
+
+1. In your project, go to "APIs & Services" → "Library"
+2. Search for "Google Drive API"
+3. Click on "Google Drive API"
+4. Click "Enable"
+
+### 3. Create Service Account
+
+1. Go to "APIs & Services" → "Credentials"
+2. Click "Create Credentials" → "Service Account"
+3. Fill in the service account details:
+   - **Name**: `obd-logger-drive`
+   - **Description**: `Service account for OBD Logger Google Drive operations`
+4. Click "Create and Continue"
+5. For roles, select "Editor" (or create a custom role with minimal permissions)
+6. Click "Continue" → "Done"
+
+### 4. Generate Service Account Key
+
+1. In the service accounts list, click on your newly created service account
+2. Go to the "Keys" tab
+3. Click "Add Key" → "Create New Key"
+4. Choose "JSON" format
+5. Click "Create" - this will download a JSON file
+6. **Important**: Keep this file secure and never commit it to version control
+
+### 5. Share Google Drive Folder
+
+1. Go to [Google Drive](https://drive.google.com/)
+2. Create a new folder or use an existing one
+3. Right-click the folder → "Share"
+4. Add your service account email (found in the JSON file under `client_email`)
+5. Give it "Editor" permissions
+6. Copy the folder ID from the URL (the long string after `/folders/`)
+
+## Configuration
+
+### 1. Set Up Credentials
+
+Copy the contents of your downloaded JSON file and set it as the `GDRIVE_CREDENTIALS_JSON` environment variable:
+
+```bash
+export GDRIVE_CREDENTIALS_JSON='{"type":"service_account","project_id":"your-project",...}'
+```
+
+Or add it to your `.env` file.
+
+### 2. Configure Folder ID
+
+Set the `GDRIVE_FOLDER_ID` environment variable to your target folder ID:
+
+```bash
+export GDRIVE_FOLDER_ID="your_folder_id_here"
+```
+
+## Usage
+
+### Automatic Saving
+
+The application automatically uploads cleaned CSV files to Google Drive after processing.
+
+### Manual Operations
+
+#### Initialize Drive Service
+
+```python
+from drive_saver import DriveSaver
+
+# Create instance
+drive_saver = DriveSaver()
+
+# Check if service is available
+if drive_saver.is_service_available():
+    print("✅ Google Drive service ready")
+else:
+    print("❌ Google Drive service not available")
+```
+
+#### Upload CSV File
+
+```python
+# Upload to default folder
+success = drive_saver.upload_csv_to_drive("path/to/your/file.csv")
+
+# Upload to specific folder
+success = drive_saver.upload_csv_to_drive("path/to/your/file.csv", "custom_folder_id")
+```
+
+#### Configuration Management
+
+```python
+# Get current folder ID
+current_folder = drive_saver.get_folder_id()
+
+# Set new folder ID
+drive_saver.set_folder_id("new_folder_id")
+```
+
+### Legacy Functions (Backward Compatibility)
+
+The module maintains backward compatibility with existing code:
+
+```python
+from drive_saver import get_drive_service, upload_to_folder
+
+# Legacy usage
+service = get_drive_service()
+result = upload_to_folder(service, "file.csv", "folder_id")
+```
+
+## File Management
+
+### Supported File Types
+
+- **CSV files**: Primary format for OBD data
+- **Text files**: Other data formats
+- **Binary files**: Limited support
+
+### File Naming
+
+Files are uploaded with their original names. The system automatically:
+- Preserves file extensions
+- Maintains original timestamps
+- Creates unique names if conflicts exist
+
+### Storage Organization
+
+- **Default folder**: All files go to the configured default folder
+- **Custom folders**: Specify different folders for different data types
+- **Session-based**: Files are organized by processing sessions
+
+## Error Handling
+
+### Common Issues
+
+1. **Authentication Errors**
+   - Check service account credentials
+   - Verify API is enabled
+   - Ensure service account has proper permissions
+
+2. **Permission Errors**
+   - Verify folder sharing settings
+   - Check service account email is added to folder
+   - Ensure "Editor" or higher permissions
+
+3. **Quota Exceeded**
+   - Monitor Google Drive storage usage
+   - Check API quotas in Google Cloud Console
+   - Consider upgrading storage plan
+
+### Troubleshooting
+
+#### Check Service Status
+
+```python
+from drive_saver import DriveSaver
+
+saver = DriveSaver()
+print(f"Service available: {saver.is_service_available()}")
+print(f"Current folder: {saver.get_folder_id()}")
+```
+
+#### Test Connection
+
+```python
+# Try uploading a small test file
+test_success = drive_saver.upload_csv_to_drive("test.csv")
+if test_success:
+    print("✅ Connection test successful")
+else:
+    print("❌ Connection test failed")
+```
+
+## Security Best Practices
+
+### Credential Management
+
+- **Never commit** service account JSON to version control
+- **Use environment variables** for sensitive data
+- **Rotate keys** regularly
+- **Limit permissions** to minimum required
+
+### Access Control
+
+- **Restrict folder access** to necessary users only
+- **Monitor access logs** in Google Drive
+- **Use organization policies** for additional security
+- **Consider VPC Service Controls** for production
+
+### Network Security
+
+- **HTTPS only** for all API communications
+- **Firewall rules** to restrict access if needed
+- **Audit logs** for suspicious activity
+
+## Performance Optimization
+
+### Upload Strategies
+
+- **Batch uploads** for multiple files
+- **Compression** for large CSV files
+- **Async processing** for non-blocking operations
+
+### Monitoring
+
+- **Track upload success rates**
+- **Monitor file sizes and upload times**
+- **Set up alerts** for failures
+
+## Integration with OBD Logger
+
+### Automatic Uploads
+
+The system automatically uploads files after:
+1. Data processing completion
+2. CSV cleaning and validation
+3. Feature engineering
+4. Quality checks
+
+### File Naming Convention
+
+Uploaded files follow the pattern:
+```
+cleaned_{timestamp}.csv
+```
+
+Where `{timestamp}` is the normalized timestamp from the processing session.
+
+### Error Recovery
+
+If uploads fail:
+- Files remain in local storage
+- Errors are logged for debugging
+- Processing continues without interruption
+- Manual retry options available
+
+## Advanced Configuration
+
+### Custom Scopes
+
+Modify the authentication scopes in `drive_saver.py`:
+
+```python
+scopes = [
+    "https://www.googleapis.com/auth/drive",
+    "https://www.googleapis.com/auth/drive.file"  # More restrictive
+]
+```
+
+### Retry Logic
+
+The system includes automatic retry logic for:
+- Network timeouts
+- Rate limiting
+- Temporary service unavailability
+
+### Logging
+
+Comprehensive logging includes:
+- Upload success/failure
+- File details and metadata
+- Performance metrics
+- Error details for debugging
+
+## Support and Maintenance
+
+### Regular Tasks
+
+1. **Monitor storage usage** in Google Drive
+2. **Check API quotas** in Google Cloud Console
+3. **Review access logs** for security
+4. **Update service account keys** as needed
+
+### Troubleshooting Resources
+
+- [Google Drive API Documentation](https://developers.google.com/drive/api)
+- [Google Cloud Console](https://console.cloud.google.com/)
+- [Google Drive Help](https://support.google.com/drive/)
+- Application logs and error messages
+
+### Getting Help
+
+For issues with the OBD Logger integration:
+1. Check application logs
+2. Verify environment variables
+3. Test with simple file uploads
+4. Review Google Cloud Console for errors
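
The guide above drives everything through the repository's `DriveSaver` wrapper, whose source (`data/drive_saver.py`) is not reproduced in this excerpt. For orientation only, a minimal service-account upload with `google-api-python-client` typically looks like the sketch below; the function name and module-level structure are assumptions, not the actual DriveSaver implementation.

```python
# Minimal sketch, assuming GDRIVE_CREDENTIALS_JSON and GDRIVE_FOLDER_ID are set as
# described above. Not the repository's DriveSaver code.
import json
import os

from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload

def upload_csv(path: str) -> str:
    creds = service_account.Credentials.from_service_account_info(
        json.loads(os.environ["GDRIVE_CREDENTIALS_JSON"]),
        scopes=["https://www.googleapis.com/auth/drive"],
    )
    service = build("drive", "v3", credentials=creds)
    metadata = {"name": os.path.basename(path),
                "parents": [os.environ["GDRIVE_FOLDER_ID"]]}
    media = MediaFileUpload(path, mimetype="text/csv")
    created = service.files().create(body=metadata, media_body=media,
                                     fields="id").execute()
    return created["id"]  # Drive file ID of the uploaded CSV
```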

MONGODB_SETUP.md ADDED
@@ -0,0 +1,133 @@
+# MongoDB Integration Setup Guide
+
+This guide explains how to set up MongoDB integration for the OBD Logger application.
+
+## Prerequisites
+
+1. **MongoDB Atlas Account**: You need a MongoDB Atlas account (free tier available)
+2. **Python Dependencies**: Install the required packages
+
+## Installation
+
+### 1. Install Dependencies
+
+```bash
+pip install pymongo
+```
+
+Or update your requirements.txt and run:
+```bash
+pip install -r requirements.txt
+```
+
+### 2. Environment Variables
+
+Create a `.env` file in your project root with the following variables:
+
+```bash
+# Google Drive Configuration
+GDRIVE_CREDENTIALS_JSON={"type":"service_account","project_id":"your-project",...}
+
+# MongoDB Atlas Connection String
+MONGO_URI=mongodb+srv://username:password@cluster.mongodb.net/obd_logger?retryWrites=true&w=majority
+
+# Optional: Custom Google Drive Folder ID
+GDRIVE_FOLDER_ID=1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P
+```
+
+## MongoDB Atlas Setup
+
+### 1. Create Cluster
+1. Go to [MongoDB Atlas](https://cloud.mongodb.com/)
+2. Create a free cluster
+3. Choose your preferred cloud provider and region
+
+### 2. Database Access
+1. Go to "Database Access" in the left sidebar
+2. Click "Add New Database User"
+3. Choose "Password" authentication
+4. Set username and password (save these!)
+5. Set privileges to "Read and write to any database"
+
+### 3. Network Access
+1. Go to "Network Access" in the left sidebar
+2. Click "Add IP Address"
+3. For development: Click "Allow Access from Anywhere" (0.0.0.0/0)
+4. For production: Add your specific IP addresses
+
+### 4. Get Connection String
+1. Go to "Clusters" in the left sidebar
+2. Click "Connect" on your cluster
+3. Choose "Connect your application"
+4. Copy the connection string
+5. Replace `<username>`, `<password>`, and `<dbname>` with your values
+
+## Usage
+
+### Automatic Saving
+The application now automatically saves cleaned data to both Google Drive and MongoDB after processing.
+
+### Manual Operations
+
+#### Check MongoDB Status
+```bash
+GET /mongo/status
+```
+
+#### Get Session Summary
+```bash
+GET /mongo/sessions
+```
+
+#### Query Data
+```bash
+GET /mongo/query?session_id=session_20231201_120000&driving_style=aggressive&limit=100
+```
+
+#### Save CSV Directly to MongoDB
+```bash
+POST /mongo/save-csv
+# Upload CSV file with optional session_id parameter
+```
+
+## Data Structure
+
+Each document in MongoDB contains:
+- All OBD sensor data from the original CSV
+- `session_id`: Unique identifier for the data session
+- `imported_at`: Timestamp when data was imported
+- `record_index`: Original row index from CSV
+- `timestamp`: OBD data timestamp (converted to datetime)
+- `driving_style`: Driving style classification
+
+## Performance Features
+
+- **Indexes**: Automatic creation of indexes on timestamp, driving_style, and session_id
+- **Connection Pooling**: Efficient connection management
+- **Batch Operations**: Bulk insert for better performance
+- **Error Handling**: Graceful fallback if MongoDB is unavailable
+
+## Troubleshooting
+
+### Connection Issues
+1. Check your MongoDB URI format
+2. Verify network access settings in Atlas
+3. Check username/password credentials
+4. Ensure cluster is running
+
+### Data Import Issues
+1. Check CSV file format
+2. Verify data types in your CSV
+3. Check application logs for specific error messages
+
+### Performance Issues
+1. Monitor database indexes
+2. Check connection pool settings
+3. Consider data partitioning for large datasets
+
+## Security Notes
+
+- Never commit your `.env` file to version control
+- Use strong passwords for database users
+- Restrict network access to necessary IP addresses only
+- Consider using VPC peering for production deployments
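
`data/mongo_saver.py` implements the behaviour documented above; its source is not shown in this excerpt. As a rough sketch, the documented document shape (`session_id`, `imported_at`, `record_index`) and the indexed fields map onto plain `pymongo` calls roughly as follows; the collection and function names here are assumptions.

```python
# Illustrative sketch only - mongo_saver.py is the real implementation.
import os
from datetime import datetime, timezone

import pandas as pd
from pymongo import MongoClient

def save_csv_to_mongo(csv_path: str, session_id: str) -> int:
    client = MongoClient(os.environ["MONGO_URI"])
    coll = client["obd_logger"]["obd_records"]  # assumed database/collection names

    # Indexes listed under "Performance Features"
    coll.create_index("timestamp")
    coll.create_index("driving_style")
    coll.create_index("session_id")

    df = pd.read_csv(csv_path)
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
    df = df.where(pd.notnull(df), None)  # NaN/NaT -> None so BSON can encode them

    docs = df.to_dict(orient="records")
    now = datetime.now(timezone.utc)
    for i, doc in enumerate(docs):  # per-document metadata from the guide
        doc.update(session_id=session_id, imported_at=now, record_index=i)

    return len(coll.insert_many(docs).inserted_ids)  # bulk insert
```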

OBD/obd_analyzer.py ADDED
@@ -0,0 +1,215 @@
+import pandas as pd
+import numpy as np
+import argparse
+import os
+
+
+DRIVING_STYLE_PASSIVE = "Passive"
+DRIVING_STYLE_MODERATE = "Moderate"
+DRIVING_STYLE_AGGRESSIVE = "Aggressive"
+DRIVING_STYLE_UNKNOWN = "UNKNOWN_STYLE"
+
+ROAD_TYPE_LOCAL = "Local"
+ROAD_TYPE_MAIN = "Main"
+ROAD_TYPE_HIGHWAY = "Highway"
+ROAD_TYPE_UNKNOWN = "UNKNOWN_ROAD"
+
+TRAFFIC_CONDITION_LIGHT = "Light"
+TRAFFIC_CONDITION_MODERATE = "Moderate"
+TRAFFIC_CONDITION_HEAVY = "Heavy"
+TRAFFIC_CONDITION_UNKNOWN = "UNKNOWN_TRAFFIC"
+
+
+KPH_TO_MPS = 1 / 3.6
+G_ACCELERATION = 9.80665
+MIN_MOVING_SPEED_KPH = 2  # have to be moving
+
+AGGRESSIVE_RPM_ENTRY_THRESHOLD = 2700
+AGGRESSIVE_THROTTLE_ENTRY_THRESHOLD = 40
+AGGRESSIVE_RPM_HOLD_THRESHOLD = 2300
+HARSH_BRAKING_THRESHOLD_G = -0.25
+
+# roc
+AGGRESSIVE_RPM_ROC_THRESHOLD = 500
+AGGRESSIVE_THROTTLE_ROC_THRESHOLD = 45
+POSITIVE_ACCEL_FOR_ROC_CHECK_G = 0.1
+
+MODERATE_RPM_THRESHOLD = 2100
+MODERATE_THROTTLE_THRESHOLD = 25
+
+MIN_DATA_POINTS_FOR_ROC = 2
+
+def load_and_preprocess_data(csv_filepath):
+    """Loads OBD data from CSV and preprocesses it."""
+    if not os.path.exists(csv_filepath):
+        print(f"Error: File not found at {csv_filepath}")
+        return None
+
+    try:
+        df = pd.read_csv(csv_filepath)
+    except Exception as e:
+        print(f"Error loading CSV {csv_filepath}: {e}")
+        return None
+
+    print(f"Successfully loaded {csv_filepath} with {len(df)} rows.")
+
+    if 'timestamp' not in df.columns:
+        print("Error: 'timestamp' column is missing from the CSV.")
+        return None
+
+    df['timestamp'] = pd.to_datetime(df['timestamp'])
+    df = df.sort_values(by='timestamp').reset_index(drop=True)
+
+    df['delta_time_s'] = df['timestamp'].diff().dt.total_seconds()
+    if not df.empty:
+        df.loc[0, 'delta_time_s'] = 0
+    else:
+        # Handle empty DataFrame after potential filtering or if it was empty to begin with
+        return df  # Or handle error appropriately
+
+    numeric_cols = ['SPEED', 'RPM', 'THROTTLE_POS']
+    for col in numeric_cols:
+        if col in df.columns:
+            df[col] = pd.to_numeric(df[col], errors='coerce')
+        else:
+            print(f"Warning: Column {col} not found. It will be filled with NaN.")
+            df[col] = np.nan
+
+    df[numeric_cols] = df[numeric_cols].fillna(method='ffill').fillna(0)
+
+    if 'SPEED' in df.columns:
+        df['SPEED_mps'] = df['SPEED'] * KPH_TO_MPS
+    else:
+        df['SPEED_mps'] = 0
+
+    if len(df) >= MIN_DATA_POINTS_FOR_ROC:
+        df['acceleration_mps2'] = df['SPEED_mps'].diff() / df['delta_time_s']
+        df['acceleration_mps2'] = df['acceleration_mps2'].replace([np.inf, -np.inf], 0).fillna(0)
+        if not df.empty: df.loc[0, 'acceleration_mps2'] = 0
+        df['acceleration_g'] = df['acceleration_mps2'] / G_ACCELERATION
+        if not df.empty: df.loc[0, 'acceleration_g'] = 0
+        df['acceleration_g'] = df['acceleration_g'].fillna(0)
+
+        if 'RPM' in df.columns:
+            df['RPM_roc'] = df['RPM'].diff() / df['delta_time_s']
+            df['RPM_roc'] = df['RPM_roc'].replace([np.inf, -np.inf], 0).fillna(0)
+            if not df.empty: df.loc[0, 'RPM_roc'] = 0
+        else:
+            df['RPM_roc'] = 0
+
+        if 'THROTTLE_POS' in df.columns:
+            df['THROTTLE_roc'] = df['THROTTLE_POS'].diff() / df['delta_time_s']
+            df['THROTTLE_roc'] = df['THROTTLE_roc'].replace([np.inf, -np.inf], 0).fillna(0)
+            if not df.empty: df.loc[0, 'THROTTLE_roc'] = 0
+        else:
+            df['THROTTLE_roc'] = 0
+    else:
+        # Not enough data for RoC calculations, fill with 0 or handle as error
+        df['acceleration_mps2'] = 0
+        df['acceleration_g'] = 0
+        df['RPM_roc'] = 0
+        df['THROTTLE_roc'] = 0
+        print("Warning: Not enough data points for full RoC calculations. Output might be limited.")
+
+    print("Preprocessing complete.")
+    return df
+
+def classify_driving_style_stateful(df):
+    if df.empty or not all(col in df.columns for col in ['RPM', 'THROTTLE_POS', 'SPEED', 'acceleration_g']):
+        print("Warning: Missing one or more required columns for stateful classification (RPM, THROTTLE_POS, SPEED, acceleration_g).")
+        return pd.Series([DRIVING_STYLE_UNKNOWN] * len(df), index=df.index, dtype=str)
+
+    driving_styles = [DRIVING_STYLE_UNKNOWN] * len(df)
+    current_style = DRIVING_STYLE_PASSIVE
+
+    for i in range(len(df)):
+        rpm = df.loc[i, 'RPM']
+        throttle = df.loc[i, 'THROTTLE_POS']
+        speed_kph = df.loc[i, 'SPEED']
+        accel_g = df.loc[i, 'acceleration_g']
+        rpm_roc = df.loc[i, 'RPM_roc']
+        throttle_roc = df.loc[i, 'THROTTLE_roc']
+
+        row_style = DRIVING_STYLE_PASSIVE
+        is_moving = speed_kph > MIN_MOVING_SPEED_KPH
+
+        is_hard_braking_trigger = accel_g < HARSH_BRAKING_THRESHOLD_G and is_moving
+
+        is_high_abs_rpm_throttle_trigger = (rpm > AGGRESSIVE_RPM_ENTRY_THRESHOLD and
+                                            throttle > AGGRESSIVE_THROTTLE_ENTRY_THRESHOLD and
+                                            is_moving)
+
+        is_actively_accelerating = accel_g > POSITIVE_ACCEL_FOR_ROC_CHECK_G
+
+        is_high_roc_trigger = (is_moving and
+                               is_actively_accelerating and
+                               (rpm_roc > AGGRESSIVE_RPM_ROC_THRESHOLD or
+                                throttle_roc > AGGRESSIVE_THROTTLE_ROC_THRESHOLD))
+
+        is_currently_aggressive_event = is_hard_braking_trigger or is_high_abs_rpm_throttle_trigger or is_high_roc_trigger
+
+        if current_style == DRIVING_STYLE_AGGRESSIVE:
+            if is_currently_aggressive_event:
+                row_style = DRIVING_STYLE_AGGRESSIVE
+            elif rpm > AGGRESSIVE_RPM_HOLD_THRESHOLD and is_moving:
+                row_style = DRIVING_STYLE_AGGRESSIVE
+            else:
+                if (rpm > MODERATE_RPM_THRESHOLD or throttle > MODERATE_THROTTLE_THRESHOLD) and is_moving:
+                    row_style = DRIVING_STYLE_MODERATE
+                else:
+                    row_style = DRIVING_STYLE_PASSIVE
+        else:
+            if is_currently_aggressive_event:
+                row_style = DRIVING_STYLE_AGGRESSIVE
+            else:
+                if (rpm > MODERATE_RPM_THRESHOLD or throttle > MODERATE_THROTTLE_THRESHOLD) and is_moving:
+                    row_style = DRIVING_STYLE_MODERATE
+                else:
+                    row_style = DRIVING_STYLE_PASSIVE
+
+        driving_styles[i] = row_style
+        current_style = row_style
+
+    print("Stateful driving style classification complete.")
+    return pd.Series(driving_styles, index=df.index)
+
+def main():
+    parser = argparse.ArgumentParser(description="Analyze OBD CSV log data for driving behavior (stateful).")
+    parser.add_argument("csv_filepath", help="Path to the OBD log CSV file.")
+    parser.add_argument("--output_csv", help="Path to save the analyzed data CSV file.", default=None)
+    args = parser.parse_args()
+
+    df = load_and_preprocess_data(args.csv_filepath)
+
+    if df is None or df.empty:
+        print("No data to process after loading or preprocessing.")
+        return
+
+    df['driving_style_analyzed'] = classify_driving_style_stateful(df)
+
+    print("\n--- Analysis Summary ---")
+    print("Driving Style Distribution (Analyzed):")
+    counts = df['driving_style_analyzed'].value_counts(dropna=False)
+    percentages = df['driving_style_analyzed'].value_counts(normalize=True, dropna=False) * 100
+    summary_df = pd.DataFrame({'Count': counts, 'Percentage': percentages})
+    print(summary_df)
+
+    if args.output_csv:
+        try:
+            output_path = args.output_csv
+            output_dir = os.path.dirname(output_path)
+            if output_dir and not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+            df.to_csv(output_path, index=False)
+            print(f"\nAnalyzed data saved to {output_path}")
+        except Exception as e:
+            print(f"Error saving output CSV to {args.output_csv}: {e}")
+    else:
+        print("\n--- First 20 Rows of Analyzed Data (showing key fields) ---")
+        display_cols = ['timestamp', 'SPEED', 'RPM', 'THROTTLE_POS', 'acceleration_g', 'driving_style_analyzed']
+        display_cols = [col for col in display_cols if col in df.columns]
+        if display_cols: print(df[display_cols].head(20))
+        else: print("Key display columns not found in DataFrame.")
+
+if __name__ == "__main__":
+    main()
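
To make the harsh-braking threshold concrete: `acceleration_g` is derived from consecutive SPEED readings and compared against `HARSH_BRAKING_THRESHOLD_G = -0.25`. A tiny worked example (not part of the commit) with two samples 0.3 s apart:

```python
# Worked example of the acceleration_g computation used in obd_analyzer.py.
KPH_TO_MPS = 1 / 3.6
G_ACCELERATION = 9.80665

speed_prev, speed_now = 60, 57                     # km/h, samples 0.3 s apart
delta_v = (speed_now - speed_prev) * KPH_TO_MPS    # ≈ -0.83 m/s
accel_g = (delta_v / 0.3) / G_ACCELERATION         # ≈ -0.28 g

print(accel_g < -0.25)  # True -> the harsh-braking trigger fires for this sample
```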
OBD/obd_logger.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import obd
|
| 2 |
+
import time
|
| 3 |
+
import datetime
|
| 4 |
+
import csv
|
| 5 |
+
import os
|
| 6 |
+
from collections import deque
|
| 7 |
+
import numpy as np
|
| 8 |
+
import shutil
|
| 9 |
+
import subprocess
|
| 10 |
+
|
| 11 |
+
DRIVING_STYLE_PASSIVE = "Passive"
|
| 12 |
+
DRIVING_STYLE_MODERATE = "Moderate"
|
| 13 |
+
DRIVING_STYLE_AGGRESSIVE = "Aggressive"
|
| 14 |
+
DRIVING_STYLE_UNKNOWN = "UNKNOWN_STYLE"
|
| 15 |
+
|
| 16 |
+
ROAD_TYPE_LOCAL = "Local"
|
| 17 |
+
ROAD_TYPE_MAIN = "Main"
|
| 18 |
+
ROAD_TYPE_HIGHWAY = "Highway"
|
| 19 |
+
ROAD_TYPE_UNKNOWN = "UNKNOWN_ROAD"
|
| 20 |
+
|
| 21 |
+
TRAFFIC_CONDITION_LIGHT = "Light"
|
| 22 |
+
TRAFFIC_CONDITION_MODERATE = "Moderate"
|
| 23 |
+
TRAFFIC_CONDITION_HEAVY = "Heavy"
|
| 24 |
+
TRAFFIC_CONDITION_UNKNOWN = "UNKNOWN_TRAFFIC"
|
| 25 |
+
|
| 26 |
+
# Rolling Average Configuration
|
| 27 |
+
ROLLING_WINDOW_SIZE = 20 # 6 seconds
|
| 28 |
+
MIN_SAMPLES_FOR_CLASSIFICATION = 10
|
| 29 |
+
|
| 30 |
+
# ROC needs tuning
|
| 31 |
+
SHORT_ROC_WINDOW_SIZE = 3
|
| 32 |
+
MIN_SAMPLES_FOR_ROC_CHECK = SHORT_ROC_WINDOW_SIZE
|
| 33 |
+
ROC_THROTTLE_AGGRESSIVE_THRESHOLD = 25.0
|
| 34 |
+
ROC_RPM_AGGRESSIVE_THRESHOLD = 700.0
|
| 35 |
+
ROC_SPEED_AGGRESSIVE_THRESHOLD = 8.0
|
| 36 |
+
MIN_RPM_FOR_AGGRESSIVE_TRIGGER = 1000.0
|
| 37 |
+
AGGRESSIVE_EVENT_COOLDOWN_SAMPLES = 15
|
| 38 |
+
|
| 39 |
+
HIGH_FREQUENCY_PIDS = [
|
| 40 |
+
obd.commands.RPM,
|
| 41 |
+
obd.commands.THROTTLE_POS,
|
| 42 |
+
obd.commands.SPEED,
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
LOW_FREQUENCY_PIDS_POOL = [
|
| 46 |
+
obd.commands.FUEL_PRESSURE,
|
| 47 |
+
obd.commands.ENGINE_LOAD,
|
| 48 |
+
obd.commands.COOLANT_TEMP,
|
| 49 |
+
obd.commands.INTAKE_TEMP,
|
| 50 |
+
obd.commands.TIMING_ADVANCE,
|
| 51 |
+
obd.commands.MAF,
|
| 52 |
+
obd.commands.INTAKE_PRESSURE,
|
| 53 |
+
obd.commands.SHORT_FUEL_TRIM_1,
|
| 54 |
+
obd.commands.LONG_FUEL_TRIM_1,
|
| 55 |
+
obd.commands.SHORT_FUEL_TRIM_2,
|
| 56 |
+
obd.commands.LONG_FUEL_TRIM_2,
|
| 57 |
+
obd.commands.COMMANDED_EQUIV_RATIO,
|
| 58 |
+
obd.commands.O2_B1S2,
|
| 59 |
+
obd.commands.O2_B2S2,
|
| 60 |
+
obd.commands.O2_S1_WR_VOLTAGE,
|
| 61 |
+
obd.commands.COMMANDED_EGR,
|
| 62 |
+
]
|
| 63 |
+
|
| 64 |
+
ALL_PIDS_TO_LOG = HIGH_FREQUENCY_PIDS + LOW_FREQUENCY_PIDS_POOL
|
| 65 |
+
|
| 66 |
+
CSV_FILENAME_BASE = "obd_data_log"
|
| 67 |
+
# Define new structured log directories relative to the OBD_Logger/OBD directory
|
| 68 |
+
LOGS_BASE_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "logs") # Corrected: Up two levels to Base, then into logs
|
| 69 |
+
ORIGINAL_CSV_DIR = os.path.join(LOGS_BASE_DIR, "OriginalCSV")
|
| 70 |
+
DUPLICATE_CSV_DIR = os.path.join(LOGS_BASE_DIR, "DuplicateCSV")
|
| 71 |
+
|
| 72 |
+
WIFI_ADAPTER_HOST = "192.168.0.10"
|
| 73 |
+
WIFI_ADAPTER_PORT = 35000
|
| 74 |
+
|
| 75 |
+
WIFI_PROTOCOL = "6"
|
| 76 |
+
USE_WIFI_SETTINGS = False # using socat to mimic serial connection
|
| 77 |
+
|
| 78 |
+
def get_pid_value(connection, pid_command):
|
| 79 |
+
"""Queries a PID and returns its value, or None if not available or error."""
|
| 80 |
+
try:
|
| 81 |
+
response = connection.query(pid_command, force=True)
|
| 82 |
+
if response.is_null() or response.value is None:
|
| 83 |
+
return None
|
| 84 |
+
if hasattr(response.value, 'magnitude'):
|
| 85 |
+
return response.value.magnitude
|
| 86 |
+
return response.value
|
| 87 |
+
except Exception as e:
|
| 88 |
+
print(f"Error querying {pid_command.name}: {e}")
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
def perform_logging_session():
|
| 92 |
+
connection = None
|
| 93 |
+
print("Starting OBD-II Data Logger...")
|
| 94 |
+
print("Classifications (Style, Road, Traffic) will be determined automatically.")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
initial_driving_style = ""
|
| 98 |
+
initial_road_type = ""
|
| 99 |
+
initial_traffic_condition = ""
|
| 100 |
+
|
| 101 |
+
BASE_LOG_INTERVAL = .3 # for high frequency data
|
| 102 |
+
LOW_FREQUENCY_GROUP_POLL_INTERVAL = 90.0 # Interval in seconds to poll one group of LF PIDs
|
| 103 |
+
NUM_LOW_FREQUENCY_GROUPS = 3
|
| 104 |
+
|
| 105 |
+
# Prepare Low-Frequency PID groups
|
| 106 |
+
low_frequency_pid_groups = []
|
| 107 |
+
if LOW_FREQUENCY_PIDS_POOL:
|
| 108 |
+
chunk_size = (len(LOW_FREQUENCY_PIDS_POOL) + NUM_LOW_FREQUENCY_GROUPS - 1) // NUM_LOW_FREQUENCY_GROUPS
|
| 109 |
+
for i in range(0, len(LOW_FREQUENCY_PIDS_POOL), chunk_size):
|
| 110 |
+
low_frequency_pid_groups.append(LOW_FREQUENCY_PIDS_POOL[i:i + chunk_size])
|
| 111 |
+
|
| 112 |
+
if not low_frequency_pid_groups: # Handle case with no LF PIDs
|
| 113 |
+
low_frequency_pid_groups.append([])
|
| 114 |
+
NUM_LOW_FREQUENCY_GROUPS = 1
|
| 115 |
+
|
| 116 |
+
last_low_frequency_group_poll_time = time.monotonic()
|
| 117 |
+
current_low_frequency_group_index = 0
|
| 118 |
+
|
| 119 |
+
current_pid_values = {pid.name: '' for pid in ALL_PIDS_TO_LOG}
|
| 120 |
+
|
| 121 |
+
# Create log directories
|
| 122 |
+
for dir_path in [ORIGINAL_CSV_DIR, DUPLICATE_CSV_DIR]: # Add ANALYZED_OUTPUT_DIR if used
|
| 123 |
+
try:
|
| 124 |
+
os.makedirs(dir_path, exist_ok=True)
|
| 125 |
+
print(f"Ensured directory exists: {dir_path}")
|
| 126 |
+
except OSError as e:
|
| 127 |
+
print(f"Error creating directory {dir_path}: {e}. Attempting to use current directory.")
|
| 128 |
+
# Fallback logic may be needed if creation fails critically
|
| 129 |
+
if dir_path == ORIGINAL_CSV_DIR: # Critical for saving original log
|
| 130 |
+
print("Cannot create original log directory. Exiting.")
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
current_session_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 134 |
+
csv_file_name_only = f"{CSV_FILENAME_BASE}_{current_session_timestamp}.csv"
|
| 135 |
+
original_csv_filepath = os.path.join(ORIGINAL_CSV_DIR, csv_file_name_only)
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
if USE_WIFI_SETTINGS:
|
| 139 |
+
print(f"Attempting to connect to WiFi adapter at {WIFI_ADAPTER_HOST}:{WIFI_ADAPTER_PORT} using protocol {WIFI_PROTOCOL}...")
|
| 140 |
+
connection = obd.OBD(protocol=WIFI_PROTOCOL,
|
| 141 |
+
host=WIFI_ADAPTER_HOST,
|
| 142 |
+
port=WIFI_ADAPTER_PORT,
|
| 143 |
+
fast=False,
|
| 144 |
+
timeout=30)
|
| 145 |
+
else:
|
| 146 |
+
print("Attempting to connect via socat PTY /dev/ttys011...")
|
| 147 |
+
connection = obd.OBD("/dev/ttys086", fast=True, timeout=30) # Auto-scan for USB/Bluetooth
|
| 148 |
+
|
| 149 |
+
if not connection.is_connected():
|
| 150 |
+
print("Failed to connect to OBD-II adapter.")
|
| 151 |
+
print(f"Connection status: {connection.status()}")
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
print(f"Successfully connected to OBD-II adapter: {connection.port_name()}")
|
| 155 |
+
print(f"Adapter status: {connection.status()}")
|
| 156 |
+
print(f"Supported PIDs (sample):")
|
| 157 |
+
supported_commands = connection.supported_commands
|
| 158 |
+
for i, cmd in enumerate(supported_commands):
|
| 159 |
+
print(f" - {cmd.name}")
|
| 160 |
+
if not supported_commands:
|
| 161 |
+
print("No commands")
|
| 162 |
+
|
| 163 |
+
# Creating initial full PID sample to have fully populated rows from beginning
|
| 164 |
+
print("\nPerforming initial full PID sample...")
|
| 165 |
+
initial_log_entry = {
|
| 166 |
+
'timestamp': datetime.datetime.now().isoformat(),
|
| 167 |
+
'driving_style': initial_driving_style,
|
| 168 |
+
'road_type': initial_road_type,
|
| 169 |
+
'traffic_condition': initial_traffic_condition
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
print("Polling initial High-Frequency PIDs...")
|
| 173 |
+
for pid_command in HIGH_FREQUENCY_PIDS:
|
| 174 |
+
value = get_pid_value(connection, pid_command)
|
| 175 |
+
current_pid_values[pid_command.name] = value if value is not None else ''
|
| 176 |
+
initial_log_entry[pid_command.name] = current_pid_values[pid_command.name]
|
| 177 |
+
|
| 178 |
+
print("Polling initial Low-Frequency PIDs (all groups)...")
|
| 179 |
+
if low_frequency_pid_groups and low_frequency_pid_groups[0]: # Check if there are any LF PIDs
|
| 180 |
+
for group in low_frequency_pid_groups:
|
| 181 |
+
for pid_command in group:
|
| 182 |
+
value = get_pid_value(connection, pid_command)
|
| 183 |
+
current_pid_values[pid_command.name] = value if value is not None else ''
|
| 184 |
+
initial_log_entry[pid_command.name] = current_pid_values[pid_command.name]
|
| 185 |
+
else:
|
| 186 |
+
print("No Low-Frequency PIDs to poll for initial sample.")
|
| 187 |
+
|
| 188 |
+
for pid_obj in ALL_PIDS_TO_LOG:
|
| 189 |
+
if pid_obj.name not in initial_log_entry:
|
| 190 |
+
initial_log_entry[pid_obj.name] = '' # Default to empty if somehow missed
|
| 191 |
+
|
| 192 |
+
except Exception as e:
|
| 193 |
+
print(f"An error occurred during connection or initial PID sample: {e}")
|
| 194 |
+
if connection and connection.is_connected():
|
| 195 |
+
connection.close()
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
file_exists = os.path.isfile(original_csv_filepath)
|
| 199 |
+
try:
|
| 200 |
+
with open(original_csv_filepath, 'a', newline='') as csvfile:
|
| 201 |
+
# Add new columns for analyzer output, they will be empty initially from logger
|
| 202 |
+
header_names = ['timestamp',
|
| 203 |
+
'driving_style', 'road_type', 'traffic_condition', # Original placeholder columns
|
| 204 |
+
'driving_style_analyzed', 'road_type_analyzed', 'traffic_condition_analyzed' # For analyzer
|
| 205 |
+
] + [pid.name for pid in ALL_PIDS_TO_LOG]
|
| 206 |
+
|
| 207 |
+
# Remove duplicates if any PID name is already in the first part
|
| 208 |
+
processed_headers = []
|
| 209 |
+
for item in header_names:
|
| 210 |
+
if item not in processed_headers:
|
| 211 |
+
processed_headers.append(item)
|
| 212 |
+
header_names = processed_headers
|
| 213 |
+
|
| 214 |
+
writer = csv.DictWriter(csvfile, fieldnames=header_names)
|
| 215 |
+
|
| 216 |
+
if not file_exists or os.path.getsize(original_csv_filepath) == 0:
|
| 217 |
+
writer.writeheader()
|
| 218 |
+
print(f"Created new CSV file: {original_csv_filepath} with headers: {header_names}")
|
| 219 |
+
|
| 220 |
+
if initial_log_entry:
|
| 221 |
+
# Add placeholder columns for analyzer to the initial entry
|
| 222 |
+
initial_log_entry['driving_style_analyzed'] = ''
|
| 223 |
+
initial_log_entry['road_type_analyzed'] = ''
|
| 224 |
+
initial_log_entry['traffic_condition_analyzed'] = ''
|
| 225 |
+
writer.writerow(initial_log_entry)
|
| 226 |
+
csvfile.flush()
|
| 227 |
+
print(f"Logged initial full sample. Style: {initial_driving_style}, Road: {initial_road_type}, Traffic: {initial_traffic_condition}.")
|
| 228 |
+
|
| 229 |
+
last_low_frequency_group_poll_time = time.monotonic()
|
| 230 |
+
current_low_frequency_group_index = 0
|
| 231 |
+
|
| 232 |
+
print(f"\nLogging high-frequency data every {BASE_LOG_INTERVAL} second(s).")
|
| 233 |
+
print(f"Polling one group of low-frequency PIDs every {LOW_FREQUENCY_GROUP_POLL_INTERVAL} second(s).")
|
| 234 |
+
print(f"Low-frequency PIDs divided into {len(low_frequency_pid_groups)} groups.")
|
| 235 |
+
|
| 236 |
+
log_count = 0
|
| 237 |
+
while True:
|
| 238 |
+
loop_start_time = time.monotonic()
|
| 239 |
+
current_datetime = datetime.datetime.now()
|
| 240 |
+
timestamp_iso = current_datetime.isoformat()
|
| 241 |
+
|
| 242 |
+
hf_reads = 0
|
| 243 |
+
for pid_command in HIGH_FREQUENCY_PIDS:
|
| 244 |
+
value = get_pid_value(connection, pid_command)
|
| 245 |
+
current_pid_values[pid_command.name] = value if value is not None else ''
|
| 246 |
+
if value is not None:
|
| 247 |
+
hf_reads += 1
|
| 248 |
+
|
| 249 |
+
lf_reads_this_cycle = 0
|
| 250 |
+
lf_group_polled_this_cycle = "None"
|
| 251 |
+
if low_frequency_pid_groups and (time.monotonic() - last_low_frequency_group_poll_time) >= LOW_FREQUENCY_GROUP_POLL_INTERVAL:
|
| 252 |
+
group_to_poll = low_frequency_pid_groups[current_low_frequency_group_index]
|
| 253 |
+
lf_group_polled_this_cycle = f"Group {current_low_frequency_group_index + 1}/{len(low_frequency_pid_groups)}"
|
| 254 |
+
|
| 255 |
+
for pid_command in group_to_poll:
|
| 256 |
+
value = get_pid_value(connection, pid_command)
|
| 257 |
+
current_pid_values[pid_command.name] = value if value is not None else ''
|
| 258 |
+
if value is not None:
|
| 259 |
+
lf_reads_this_cycle +=1
|
| 260 |
+
else:
|
| 261 |
+
print(f"Warning: Could not read LF PID {pid_command.name}")
|
| 262 |
+
|
| 263 |
+
last_low_frequency_group_poll_time = time.monotonic()
|
| 264 |
+
current_low_frequency_group_index = (current_low_frequency_group_index + 1) % len(low_frequency_pid_groups)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
final_log_entry = {
|
| 268 |
+
'timestamp': timestamp_iso,
|
| 269 |
+
'driving_style': initial_driving_style,
|
| 270 |
+
'road_type': initial_road_type,
|
| 271 |
+
'traffic_condition': initial_traffic_condition,
|
| 272 |
+
'driving_style_analyzed': '',
|
| 273 |
+
'road_type_analyzed': '',
|
| 274 |
+
'traffic_condition_analyzed': ''
|
| 275 |
+
}
|
| 276 |
+
# Add all PID values for this cycle from current_pid_values
|
| 277 |
+
for pid_obj in ALL_PIDS_TO_LOG:
|
| 278 |
+
final_log_entry[pid_obj.name] = current_pid_values.get(pid_obj.name, '')
|
| 279 |
+
|
| 280 |
+
writer.writerow(final_log_entry)
|
| 281 |
+
csvfile.flush()
|
| 282 |
+
|
| 283 |
+
log_count += 1
|
| 284 |
+
if log_count % 10 == 0:
|
| 285 |
+
status_msg = f"Logged entry {log_count} - HF PIDs Read: {hf_reads}/{len(HIGH_FREQUENCY_PIDS)}"
|
| 286 |
+
if lf_reads_this_cycle > 0 or lf_group_polled_this_cycle != "None":
|
| 287 |
+
status_msg += f" - LF PIDs ({lf_group_polled_this_cycle}) Read: {lf_reads_this_cycle}/unknown_total_for_group_easily"
|
| 288 |
+
print(status_msg)
|
| 289 |
+
|
| 290 |
+
elapsed_time_in_loop = time.monotonic() - loop_start_time
|
| 291 |
+
sleep_duration = max(0, BASE_LOG_INTERVAL - elapsed_time_in_loop)
|
| 292 |
+
time.sleep(sleep_duration)
|
| 293 |
+
|
| 294 |
+
except KeyboardInterrupt:
|
| 295 |
+
print("\nStopping data logging due to user interruption (Ctrl+C).")
|
| 296 |
+
except Exception as e:
|
| 297 |
+
print(f"An error occurred during logging: {e}")
|
| 298 |
+
finally:
|
| 299 |
+
if connection and connection.is_connected():
|
| 300 |
+
print("Closing OBD-II connection.")
|
| 301 |
+
connection.close()
|
| 302 |
+
    print(f"Data logging stopped. Original CSV file '{original_csv_filepath}' saved.")

    return original_csv_filepath


def duplicate_csv(original_filepath):
    if not original_filepath or not os.path.exists(original_filepath):
        print(f"Error: Original CSV not found for duplication: {original_filepath}")
        return None

    # Ensure DUPLICATE_CSV_DIR exists (it should have been created by perform_logging_session)
    os.makedirs(DUPLICATE_CSV_DIR, exist_ok=True)

    # Get just the filename from the original path
    original_filename = os.path.basename(original_filepath)
    base, ext = os.path.splitext(original_filename)

    # Construct new filename for the duplicate
    duplicate_filename = f"{base}_to_analyze{ext}"  # Suffix to distinguish
    duplicate_filepath = os.path.join(DUPLICATE_CSV_DIR, duplicate_filename)

    try:
        shutil.copy2(original_filepath, duplicate_filepath)
        print(f"Successfully duplicated CSV to: {duplicate_filepath}")
        return duplicate_filepath
    except Exception as e:
        print(f"Error duplicating CSV {original_filepath} to {duplicate_filepath}: {e}")
        return None


def run_analyzer_on_csv(csv_to_analyze_path):
    if not csv_to_analyze_path or not os.path.exists(csv_to_analyze_path):
        print(f"Error: Analyzer input CSV not found: {csv_to_analyze_path}")
        return

    # Analyzer script is in the same directory as this logger script
    analyzer_script_path = os.path.join(os.path.dirname(__file__), "obd_analyzer.py")

    if not os.path.exists(analyzer_script_path):
        print(f"CRITICAL Error: Analyzer script not found at {analyzer_script_path}")
        return

    analyzed_file_basename = os.path.basename(csv_to_analyze_path).replace("_to_analyze.csv", "_final_analyzed.csv")
    final_output_path = os.path.join(DUPLICATE_CSV_DIR, analyzed_file_basename)

    command = [
        "python",
        analyzer_script_path,
        csv_to_analyze_path,
        "--output_csv",
        final_output_path
    ]

    print(f"Running analyzer: {' '.join(command)}")
    try:
        process = subprocess.run(command, check=True, capture_output=True, text=True, cwd=os.path.dirname(__file__))
        print("Analyzer Output:\n", process.stdout)
        if process.stderr:
            print("Analyzer Errors:\n", process.stderr)
        print(f"Analyzer finished. Output saved to {final_output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error running analyzer: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}")
    except FileNotFoundError:
        print(f"Error: 'python' or analyzer script not found ({analyzer_script_path}).")


if __name__ == "__main__":
    original_log_file = perform_logging_session()

    if original_log_file and os.path.exists(original_log_file):
        duplicated_log_file = duplicate_csv(original_log_file)

        if duplicated_log_file:
            run_analyzer_on_csv(duplicated_log_file)
            print(f"Process complete. Original log: {original_log_file}, Analyzed log copy: {duplicated_log_file}")
    else:
        print("OBD logging did not produce a valid CSV file. Skipping analysis.")
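# For reference, the subprocess assembled by run_analyzer_on_csv is equivalent to running, for example:
#   python obd_analyzer.py <session>_to_analyze.csv --output_csv <session>_final_analyzed.csv
# (the paths shown are illustrative; the real ones are built from DUPLICATE_CSV_DIR at runtime).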
README.md
ADDED
@@ -0,0 +1,199 @@
---
title: OBD Logger
emoji: 🚗
colorFrom: gray
colorTo: purple
sdk: docker
pinned: false
license: apache-2.0
short_description: OBD-logging FastAPI server with data processing pipelines
---

# OBD Logger

A comprehensive OBD-II data logging and processing system built with FastAPI, featuring advanced data cleaning, Google Drive integration, MongoDB storage capabilities, and **Reinforcement Learning from Human Feedback (RLHF)** for driver behavior classification.

## Features

- **Real-time OBD-II Data Ingestion**: Stream and process OBD sensor data in real-time
- **Advanced Data Cleaning**: Intelligent gap detection, KNN imputation, and outlier handling
- **Multi-Storage Architecture**:
  - Google Drive integration for CSV storage
  - Firebase for structured data storage and querying
  - MongoDB Atlas for structured data storage and querying
- **Driver Behavior Classification**: XGBoost-based ML model for driving style prediction
- **RLHF Training System**: Continuous model improvement through human feedback
- **Data Visualization**: Automatic generation of correlation heatmaps and trend plots
- **RESTful API**: Comprehensive endpoints for data management and retrieval
- **Web Dashboard**: User-friendly interface for monitoring and control
- **Model Versioning**: Semantic versioning (1.0, 1.1, 1.2, etc.) with Hugging Face integration

## Architecture

The application is structured into modular components:

- **`app.py`**: Main FastAPI application with data processing pipeline and RLHF endpoints
- **`data/`**: Storage and persistence modules
  - **`drive_saver.py`**: Google Drive operations and file management
  - **`mongo_saver.py`**: MongoDB operations and data persistence
  - **`firebase_saver.py`**: Firebase operations and data persistence
- **`train/`**: RLHF training system
  - **`loader.py`**: Load labeled data from Firebase storage with original dataset tracking
  - **`saver.py`**: Save trained models to Hugging Face Hub with semantic versioning
  - **`rlhf.py`**: Main RLHF training pipeline for continuous model improvement
- **`OBD/`**: OBD-specific modules for data analysis and logging
- **`utils/`**: Utility modules for model management and data processing

## Quick Start

1. **Install Dependencies**:
   ```bash
   pip install -r requirements.txt
   ```

2. **Set Environment Variables**:
   - `GDRIVE_CREDENTIALS_JSON`: Google Service Account credentials
   - `FIREBASE_SERVICE_ACCOUNT_JSON`: Firebase service account credentials
   - `FIREBASE_ADMIN_JSON`: Firebase Admin SDK credentials
   - `HF_TOKEN`: Hugging Face authentication token
   - `HF_MODEL_REPO`: Hugging Face model repository (default: `BinKhoaLe1812/Driver_Behavior_OBD`)
   - `MODEL_DIR`: Local model directory (default: `/app/models/ul`)

3. **Run the Application**:
   ```bash
   uvicorn app:app --reload
   ```

4. **Access the Dashboard**:
   - Web UI: `http://localhost:8000/ui`
   - API Docs: `http://localhost:8000/docs`

## Data Processing Pipeline

1. **Ingestion**: Real-time streaming or bulk CSV upload
2. **Cleaning**: Automatic gap detection and KNN imputation
3. **Feature Engineering**: Derived metrics and sensor combinations
4. **Storage**: Simultaneous save to Google Drive, Firebase, and MongoDB
5. **Driver Behavior Classification**: XGBoost model prediction on processed data
6. **RLHF Training**: Continuous model improvement through human feedback
7. **Visualization**: Correlation analysis and trend plots

## API Endpoints

### Data Ingestion
- `POST /ingest`: Stream OBD data
- `POST /upload-csv/`: Bulk CSV upload

### Data Retrieval
- `GET /download/{filename}`: Download cleaned CSV
- `GET /events`: Get processing status

### MongoDB Operations
- `GET /mongo/status`: Check MongoDB connection
- `GET /mongo/sessions`: Get data session summaries
- `GET /mongo/query`: Query data with filters
- `POST /mongo/save-csv`: Direct CSV to MongoDB

### RLHF Training System
- `POST /rlhf/train`: Trigger RLHF training session
- `GET /rlhf/status`: Get RLHF system status and available labeled data
- `GET /rlhf/trained-datasets`: List datasets already used for training

### Firebase Storage
- Structured data storage with automatic versioning
- **`skyledge/raw/`**: Original OBD data files
- **`skyledge/processed/`**: Cleaned and processed data
- **`skyledge/labeled/`**: Human-labeled data for RLHF training
- **`skyledge/labeled/trained.txt`**: Tracks processed datasets to avoid retraining

### Hugging Face Hub
- **Model Repository**: `BinKhoaLe1812/Driver_Behavior_OBD`
- **Semantic Versioning**: v1.0, v1.1, v1.2, ..., v2.0, etc.
- **Model Components**: XGBoost model, label encoder, scaler
- **Metadata**: Training logs, performance metrics, dataset information

## RLHF Training System

### Overview
The Reinforcement Learning from Human Feedback (RLHF) system enables continuous improvement of the driver behavior classification model through human-labeled data.

### Key Features
- **Original Dataset Tracking**: Automatically links labeled data to original datasets
- **Preference Learning**: Learns from differences between model predictions and human labels
- **Semantic Versioning**: Automatic model versioning (1.0 → 1.1 → 1.2 → 2.0)
- **Hugging Face Integration**: Saves models to HF Hub with metadata
- **Training Tracking**: Prevents retraining on the same datasets

### Usage Examples

#### Trigger RLHF Training
```bash
curl -X POST "http://localhost:8000/rlhf/train" \
  -H "Content-Type: application/json" \
  -d '{
    "max_datasets": 5,
    "force_retrain": false
  }'
```

#### Check Training Status
```bash
curl -X GET "http://localhost:8000/rlhf/status"
```

#### List Trained Datasets
```bash
curl -X GET "http://localhost:8000/rlhf/trained-datasets"
```

### Data Flow
1. **Human Labeling**: Data labeled and stored in `skyledge/labeled/`
2. **Filename Convention**: `001_raw-002_2025-09-19-labelled.csv`
3. **Original Dataset**: Automatically loads `skyledge/raw/002_2025-09-19-raw.csv` (see the sketch after this list)
4. **RLHF Training**: Compares model predictions vs human labels
5. **Model Update**: Trains new model with preference learning
6. **Versioning**: Saves as v1.0, v1.1, etc. to Hugging Face Hub
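
The exact parsing lives in `train/loader.py`; the following is only a minimal sketch of how a labeled filename maps back to its raw counterpart under the convention above (the function name and regex are illustrative, not the repository's API):

```python
import re

def raw_path_for_labeled(labeled_name: str) -> str:
    """Map e.g. '001_raw-002_2025-09-19-labelled.csv' -> 'skyledge/raw/002_2025-09-19-raw.csv'."""
    # Illustrative pattern: <label_idx>_raw-<raw_idx>_<date>-labelled.csv
    m = re.match(r"^\d{3}_raw-(\d{3})_(\d{4}-\d{2}-\d{2})-labelled\.csv$", labeled_name)
    if not m:
        raise ValueError(f"Unexpected labeled filename: {labeled_name}")
    raw_idx, date = m.groups()
    return f"skyledge/raw/{raw_idx}_{date}-raw.csv"

assert raw_path_for_labeled("001_raw-002_2025-09-19-labelled.csv") == "skyledge/raw/002_2025-09-19-raw.csv"
```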

## Documentation

- **MongoDB Setup**: See `MONGODB_SETUP.md` for detailed configuration
- **Google Drive Setup**: See `GOOGLE_DRIVE_SETUP.md` for configuration
- **RLHF Training**: See `train/README.md` for detailed RLHF documentation
- **API Reference**: Interactive docs at `/docs` endpoint
- **Code Structure**: Modular design for easy maintenance

## Development

The codebase follows clean architecture principles:
- **Separation of concerns**: Between storage, processing, API, and ML layers
- **Comprehensive error handling**: Graceful fallbacks for service unavailability
- **Type hints and documentation**: Full type annotations and docstrings
- **Modular design**: Easy to extend and maintain
- **RLHF Integration**: Seamless integration of machine learning with data processing
- **Version control**: Semantic versioning for model artifacts
- **Testing**: Comprehensive test coverage for all components

## Model Management

### Driver Behavior Classification
- **Model Type**: XGBoost Classifier
- **Labels**: Aggressive, Normal, Conservative
- **Features**: OBD sensor data (speed, RPM, throttle, etc.)
- **Training**: RLHF with human feedback integration

### Model Artifacts
- **XGBoost Model**: `xgb_drivestyle_ul.pkl`
- **Label Encoder**: `label_encoder_ul.pkl`
- **Feature Scaler**: `scaler_ul.pkl`
- **Metadata**: Training logs and performance metrics

### Versioning Strategy
- **Semantic Versioning**: 1.0 → 1.1 → 1.2 → 2.0 (see the sketch after this list)
- **Automatic Detection**: Checks existing versions in HF repo
- **Fallback**: Timestamp-based versioning if HF unavailable
- **Local Backup**: Saves to local `/app/models/ul/v{version}/`
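
A minimal sketch of the minor-version bump and timestamp fallback described above; the actual rules (including when the major version rolls over to 2.0) live in `train/saver.py`, and the function name here is illustrative:

```python
from datetime import datetime

def next_version(existing: list[str]) -> str:
    """Bump the minor version, e.g. ['1.0', '1.1'] -> '1.2'; fall back to a timestamp tag."""
    try:
        if not existing:
            return "1.0"
        major, minor = max(tuple(map(int, v.lstrip("v").split("."))) for v in existing)
        return f"{major}.{minor + 1}"
    except Exception:
        # Fallback when the Hugging Face repo cannot be queried or versions are malformed
        return datetime.now().strftime("%Y%m%d-%H%M%S")

assert next_version(["1.0", "1.1"]) == "1.2"
```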

## License

Apache 2.0 License
app.py
ADDED
@@ -0,0 +1,802 @@
# Access: https://binkhoale1812-obd-logger.hf.space/ui
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# ───────────── Installation ─────────────
|
| 5 |
+
# Router
|
| 6 |
+
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
|
| 7 |
+
from fastapi.responses import FileResponse, HTMLResponse
|
| 8 |
+
from fastapi.staticfiles import StaticFiles
|
| 9 |
+
from fastapi.templating import Jinja2Templates
|
| 10 |
+
from fastapi.requests import Request
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
# ML/DL
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 18 |
+
from sklearn.impute import KNNImputer
|
| 19 |
+
# Utils
|
| 20 |
+
import os, datetime, json, logging, re
|
| 21 |
+
from datetime import timedelta
|
| 22 |
+
import pathlib
|
| 23 |
+
|
| 24 |
+
# Drive
|
| 25 |
+
from data.drive_saver import DriveSaver, get_drive_service, upload_to_folder
|
| 26 |
+
|
| 27 |
+
# Database
|
| 28 |
+
from data.mongo_saver import MongoSaver, save_csv_to_mongo, save_dataframe_to_mongo, MONGODB_AVAILABLE
|
| 29 |
+
from data.firebase_saver import FirebaseSaver, save_csv_increment, save_dataframe_increment
|
| 30 |
+
|
| 31 |
+
# UL Model
|
| 32 |
+
from utils.ul_label import ULLabeler
|
| 33 |
+
|
| 34 |
+
# RLHF Training
|
| 35 |
+
from train import RLHFTrainer
|
| 36 |
+
|
| 37 |
+
# ───────────── Logging Setup ─────────────
|
| 38 |
+
logger = logging.getLogger("obd-logger")
|
| 39 |
+
logger.setLevel(logging.INFO)
|
| 40 |
+
fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
|
| 41 |
+
handler = logging.StreamHandler()
|
| 42 |
+
handler.setFormatter(fmt)
|
| 43 |
+
logger.addHandler(handler)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ───────────── FastAPI Init ─────────────
|
| 47 |
+
app = FastAPI(title="OBD-II Logging & Processing API")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ───────────── Directory Paths ─────────────
|
| 51 |
+
APP_ROOT = pathlib.Path(__file__).parent.resolve() # Absolute base dir
|
| 52 |
+
BASE_DIR = os.path.join(APP_ROOT, './cache/obd_data')
|
| 53 |
+
CLEANED_DIR = os.path.join(BASE_DIR, "cleaned")
|
| 54 |
+
PLOT_DIR = os.path.join(BASE_DIR, "plots")
|
| 55 |
+
RAW_CSV = os.path.join(BASE_DIR, "raw_logs.csv")
|
| 56 |
+
os.makedirs(BASE_DIR, exist_ok=True)
|
| 57 |
+
os.makedirs(CLEANED_DIR, exist_ok=True)
|
| 58 |
+
os.makedirs(PLOT_DIR, exist_ok=True)
|
| 59 |
+
|
| 60 |
+
DRIVE_STYLE = [] # latest UL predictions (string labels) — overwritten each run
|
| 61 |
+
|
| 62 |
+
# Init temp empty file
|
| 63 |
+
if not os.path.exists(RAW_CSV):
|
| 64 |
+
pd.DataFrame(columns=["timestamp", "driving_style"]).to_csv(RAW_CSV, index=False)
|
| 65 |
+
|
| 66 |
+
PIPELINE_EVENTS = {}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ───────────── Drive & Database Services ─────────────
|
| 70 |
+
# Initialize services
|
| 71 |
+
drive_saver = DriveSaver()
|
| 72 |
+
mongo_saver = MongoSaver()
|
| 73 |
+
firebase_saver = FirebaseSaver()
|
| 74 |
+
|
| 75 |
+
# ───────────── Model Download on Startup ─────────────
|
| 76 |
+
@app.on_event("startup")
|
| 77 |
+
async def startup_event():
|
| 78 |
+
"""Download models on app startup"""
|
| 79 |
+
try:
|
| 80 |
+
logger.info("🚀 Starting model download...")
|
| 81 |
+
from utils.download import download_latest_models
|
| 82 |
+
|
| 83 |
+
# Load .env file if it exists
|
| 84 |
+
env_path = pathlib.Path(".env")
|
| 85 |
+
if env_path.exists():
|
| 86 |
+
logger.info("📄 Loading .env file...")
|
| 87 |
+
with open(env_path, 'r') as f:
|
| 88 |
+
for line in f:
|
| 89 |
+
line = line.strip()
|
| 90 |
+
if line and not line.startswith('#') and '=' in line:
|
| 91 |
+
key, value = line.split('=', 1)
|
| 92 |
+
os.environ[key] = value
|
| 93 |
+
|
| 94 |
+
# Download models
|
| 95 |
+
success = download_latest_models()
|
| 96 |
+
if success:
|
| 97 |
+
logger.info("✅ Models downloaded successfully on startup")
|
| 98 |
+
else:
|
| 99 |
+
logger.warning("⚠️ Model download failed on startup - some features may not work")
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"❌ Startup model download failed: {e}")
|
| 103 |
+
logger.warning("⚠️ Continuing without models - some features may not work")
|
| 104 |
+
|
| 105 |
+
# ───────────── Render Dashboard UI ──────────────
|
| 106 |
+
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 107 |
+
app.mount("/plots", StaticFiles(directory=str(PLOT_DIR)), name="plots")
|
| 108 |
+
templates = Jinja2Templates(directory="static")
|
| 109 |
+
# Endpoint
|
| 110 |
+
@app.get("/ui", response_class=HTMLResponse)
|
| 111 |
+
def dashboard(request: Request):
|
| 112 |
+
return templates.TemplateResponse("index.html", {"request": request})
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ───────────── Streamed Entry Ingest ─────────────
|
| 116 |
+
class OBDEntry(BaseModel):
|
| 117 |
+
timestamp: str
|
| 118 |
+
driving_style: str
|
| 119 |
+
data: dict
|
| 120 |
+
status: str = None # Optional for control signal (start/end streaming)
|
| 121 |
+
|
| 122 |
+
# Direct centralized timestamp format
|
| 123 |
+
def normalize_timestamp(ts):
|
| 124 |
+
return ts.replace(":", "-").replace(".", "-").replace(" ", "T").replace("/", "-")
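# Example (illustrative): normalize_timestamp("2025-05-15 10:00:00.123")
# -> "2025-05-15T10-00-00-123", i.e. a filesystem/URL-safe form of the timestamp.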
|
| 125 |
+
|
| 126 |
+
# Real time endpoint
|
| 127 |
+
@app.post("/ingest")
|
| 128 |
+
def ingest(entry: OBDEntry, background_tasks: BackgroundTasks):
|
| 129 |
+
norm_ts = normalize_timestamp(entry.timestamp)
|
| 130 |
+
logger.info(f"Ingest received: {norm_ts} | Status: {entry.status}")
|
| 131 |
+
# Start logging
|
| 132 |
+
if entry.status == "start":
|
| 133 |
+
PIPELINE_EVENTS[norm_ts] = {"status": "started", "time": norm_ts}
|
| 134 |
+
return {"status": "started"}
|
| 135 |
+
# End logging, start processing
|
| 136 |
+
if entry.status == "end":
|
| 137 |
+
background_tasks.add_task(process_data, norm_ts)
|
| 138 |
+
return {"status": "processed"}
|
| 139 |
+
# Normal row append
|
| 140 |
+
try:
|
| 141 |
+
df = pd.read_csv(RAW_CSV)
|
| 142 |
+
row = {"timestamp": norm_ts, "driving_style": entry.driving_style}
|
| 143 |
+
row.update(entry.data)
|
| 144 |
+
df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
|
| 145 |
+
df.to_csv(RAW_CSV, index=False)
|
| 146 |
+
return {"status": "row appended"}
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logger.error(f"Streaming ingest failed: {e}")
|
| 149 |
+
raise HTTPException(status_code=500, detail="Ingest error")
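# The /ingest endpoint accepts three payload shapes (sketch; field values are illustrative,
# drawn from the OBDEntry model and the sample in data.json):
#   {"timestamp": "...", "driving_style": "", "data": {}, "status": "start"}                 -> registers a session
#   {"timestamp": "...", "driving_style": "aggressive", "data": {"RPM": 3200, "SPEED": 110}} -> appends a row
#   {"timestamp": "...", "driving_style": "", "data": {}, "status": "end"}                   -> triggers process_data()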
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ───────────── Bulk CSV Upload ───────────────────
|
| 153 |
+
@app.post("/upload-csv/")
|
| 154 |
+
async def upload_csv(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
|
| 155 |
+
ts = datetime.datetime.now().isoformat()
|
| 156 |
+
norm_ts = normalize_timestamp(ts)
|
| 157 |
+
path = os.path.join(BASE_DIR, file.filename)
|
| 158 |
+
PIPELINE_EVENTS[norm_ts] = {"status": "started", "time": norm_ts}
|
| 159 |
+
with open(path, "wb") as f:
|
| 160 |
+
f.write(await file.read())
|
| 161 |
+
logger.info(f"CSV uploaded: {path}")
|
| 162 |
+
background_tasks.add_task(process_uploaded_csv, path, norm_ts)
|
| 163 |
+
return {"status": "processing started", "file": file.filename}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ───────────── Data Processing ──────────────────
|
| 167 |
+
# Bulk CSV
|
| 168 |
+
def process_uploaded_csv(path, norm_ts):
|
| 169 |
+
try:
|
| 170 |
+
df = pd.read_csv(path, parse_dates=["timestamp"])
|
| 171 |
+
PIPELINE_EVENTS[norm_ts] = {
|
| 172 |
+
"status": "processed",
|
| 173 |
+
"time": norm_ts
|
| 174 |
+
}
|
| 175 |
+
_process_and_save(df, norm_ts)
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.error(f"CSV processing failed: {e}")
|
| 178 |
+
|
| 179 |
+
# Process streaming
|
| 180 |
+
def process_data(norm_ts):
|
| 181 |
+
try:
|
| 182 |
+
df = pd.read_csv(RAW_CSV, parse_dates=["timestamp"])
|
| 183 |
+
PIPELINE_EVENTS[norm_ts] = {
|
| 184 |
+
"status": "processed",
|
| 185 |
+
"time": norm_ts
|
| 186 |
+
}
|
| 187 |
+
_process_and_save(df, norm_ts)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
logger.error(f"Streamed data processing failed: {e}")
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# All processing pipeline
|
| 193 |
+
def _process_and_save(df, norm_ts):
|
| 194 |
+
"""
|
| 195 |
+
Gap-aware, multi-sensor backfill for OBD-II streams with unknown cadence.
|
| 196 |
+
- Infers sampling interval from data (robust).
|
| 197 |
+
- Inserts placeholder rows for gaps using the inferred interval.
|
| 198 |
+
- Flags only corrupted values (NaN/inf/sentinels); does NOT trim 'extreme but plausible' outliers.
|
| 199 |
+
- Backfills ALL numeric sensors with KNNImputer (+ time as a feature).
|
| 200 |
+
- Keeps your plotting, Drive upload, and PIPELINE_EVENTS wiring intact.
|
| 201 |
+
"""
|
| 202 |
+
logger.info("🔧 Cleaning started (auto-interval, KNN for all sensors)")
|
| 203 |
+
|
| 204 |
+
# ----------------------- helpers (scoped locally) -----------------------
|
| 205 |
+
protected_cols = {"timestamp", "driving_style"}
|
| 206 |
+
SENTINELS = {-22, -40, 255}
|
| 207 |
+
|
| 208 |
+
def _to_dt(_df: pd.DataFrame) -> pd.DataFrame:
|
| 209 |
+
_df = _df.copy()
|
| 210 |
+
_df["timestamp"] = pd.to_datetime(_df["timestamp"], errors="coerce", utc=True)
|
| 211 |
+
_df = _df.dropna(subset=["timestamp"]).sort_values("timestamp").reset_index(drop=True)
|
| 212 |
+
# drop exact duplicate timestamps (keep first)
|
| 213 |
+
_df = _df[~_df["timestamp"].duplicated(keep="first")].reset_index(drop=True)
|
| 214 |
+
return _df
|
| 215 |
+
|
| 216 |
+
def _drop_dead_weight(_df: pd.DataFrame) -> pd.DataFrame:
|
| 217 |
+
_df = _df.copy()
|
| 218 |
+
# drop all-NaN or constant columns (except protected)
|
| 219 |
+
drop_cols = [c for c in _df.columns
|
| 220 |
+
if c not in protected_cols and (_df[c].nunique(dropna=True) <= 1 or _df[c].isna().all())]
|
| 221 |
+
if drop_cols:
|
| 222 |
+
_df.drop(columns=drop_cols, inplace=True, errors="ignore")
|
| 223 |
+
# drop duplicate columns
|
| 224 |
+
_df = _df.loc[:, ~_df.T.duplicated()]
|
| 225 |
+
# drop duplicate rows
|
| 226 |
+
_df.drop_duplicates(inplace=True)
|
| 227 |
+
return _df
|
| 228 |
+
|
| 229 |
+
def _normalize_corruption(_df: pd.DataFrame) -> pd.DataFrame:
|
| 230 |
+
_df = _df.copy()
|
| 231 |
+
# normalize obvious corruptions: NaN/inf/sentinels → NaN
|
| 232 |
+
_df.replace(list(SENTINELS), np.nan, inplace=True)
|
| 233 |
+
num_cols = _df.select_dtypes(include=[np.number]).columns
|
| 234 |
+
for c in num_cols:
|
| 235 |
+
s = _df[c]
|
| 236 |
+
s = s.astype(float)
|
| 237 |
+
s[~np.isfinite(s)] = np.nan
|
| 238 |
+
_df[c] = s
|
| 239 |
+
return _df
|
| 240 |
+
|
| 241 |
+
def _light_row_col_filters(_df: pd.DataFrame) -> pd.DataFrame:
|
| 242 |
+
_df = _df.copy()
|
| 243 |
+
# keep rows with <=80% NaN (excluding timestamp)
|
| 244 |
+
if "timestamp" in _df.columns and _df.shape[1] > 1:
|
| 245 |
+
keep = _df.drop(columns=["timestamp"]).isna().mean(axis=1) <= 0.8
|
| 246 |
+
_df = _df[keep]
|
| 247 |
+
# prune columns with >80% NaN (except protected)
|
| 248 |
+
na_frac = _df.isna().mean(numeric_only=False)
|
| 249 |
+
high_na = [c for c in na_frac.index if na_frac[c] > 0.8 and c not in protected_cols]
|
| 250 |
+
if high_na:
|
| 251 |
+
_df.drop(columns=high_na, inplace=True, errors="ignore")
|
| 252 |
+
# keep rows that have >1 observed value across non-timestamp columns
|
| 253 |
+
if "timestamp" in _df.columns and _df.shape[1] > 1:
|
| 254 |
+
valid = _df.drop(columns=["timestamp"]).notna().sum(axis=1) > 1
|
| 255 |
+
_df = _df[valid]
|
| 256 |
+
return _df
|
| 257 |
+
|
| 258 |
+
def _infer_base_interval_seconds(ts: pd.Series) -> float:
|
| 259 |
+
"""
|
| 260 |
+
Robustly infer base cadence from timestamp diffs.
|
| 261 |
+
Strategy:
|
| 262 |
+
- take positive diffs
|
| 263 |
+
- winsorize to 5–95% to reduce impact of long gaps
|
| 264 |
+
- compute a 'rounded mode' on 10ms grid; fall back to median if needed
|
| 265 |
+
"""
|
| 266 |
+
if ts.size < 2:
|
| 267 |
+
return 1.0 # fallback
|
| 268 |
+
diffs = ts.sort_values().diff().dropna().dt.total_seconds()
|
| 269 |
+
diffs = diffs[diffs > 0]
|
| 270 |
+
if diffs.empty:
|
| 271 |
+
return 1.0
|
| 272 |
+
q05, q95 = diffs.quantile([0.05, 0.95])
|
| 273 |
+
core = diffs[(diffs >= q05) & (diffs <= q95)]
|
| 274 |
+
if core.empty:
|
| 275 |
+
core = diffs
|
| 276 |
+
# round to 10ms and take the most frequent bin
|
| 277 |
+
rounded = (core / 0.01).round() * 0.01
|
| 278 |
+
mode = rounded.mode()
|
| 279 |
+
if not mode.empty:
|
| 280 |
+
est = float(mode.iloc[0])
|
| 281 |
+
else:
|
| 282 |
+
est = float(core.median())
|
| 283 |
+
# guardrails
|
| 284 |
+
if est <= 0:
|
| 285 |
+
est = float(core.median())
|
| 286 |
+
logger.info(f"⏱️ Inferred base interval ≈ {est:.3f}s")
|
| 287 |
+
return est
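# Worked example (illustrative): with diffs of roughly [1.0, 1.0, 1.02, 5.0] seconds, the 5 s gap
# falls outside the 5-95% core, the 10 ms-rounded mode of the remaining diffs is 1.0,
# so the inferred base interval is ~1.0 s despite the long gap.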
|
| 288 |
+
|
| 289 |
+
def _insert_time_gaps(_df: pd.DataFrame, base_sec: float) -> pd.DataFrame:
|
| 290 |
+
"""
|
| 291 |
+
Insert placeholder rows at multiples of inferred base_sec when gaps exceed ~1.5× base.
|
| 292 |
+
All numeric columns are NaN in inserted rows; non-numeric are forward-filled (except protected).
|
| 293 |
+
"""
|
| 294 |
+
if _df.empty:
|
| 295 |
+
return _df
|
| 296 |
+
_df = _df.copy()
|
| 297 |
+
_df = _to_dt(_df)
|
| 298 |
+
expected = timedelta(seconds=base_sec)
|
| 299 |
+
# tolerance ~ half interval to avoid jittery inserts
|
| 300 |
+
tol = timedelta(seconds=0.5 * base_sec)
|
| 301 |
+
# Normalize data
|
| 302 |
+
num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
|
| 303 |
+
non_num_cols = [c for c in _df.columns if c not in num_cols]
|
| 304 |
+
# Missing detection on interval expectation
|
| 305 |
+
rows = [_df.iloc[0].copy()]
|
| 306 |
+
for i in range(1, len(_df)):
|
| 307 |
+
prev = _df.iloc[i - 1]
|
| 308 |
+
curr = _df.iloc[i]
|
| 309 |
+
dt = curr["timestamp"] - prev["timestamp"]
|
| 310 |
+
if dt > expected * 1.5 + tol:
|
| 311 |
+
n_missing = int(round(dt / expected)) - 1
|
| 312 |
+
if n_missing > 0:
|
| 313 |
+
for j in range(1, n_missing + 1):
|
| 314 |
+
gap = prev.copy()
|
| 315 |
+
gap["timestamp"] = prev["timestamp"] + j * expected
|
| 316 |
+
# numeric sensors left as NaN to be imputed
|
| 317 |
+
for c in num_cols:
|
| 318 |
+
if c not in protected_cols:
|
| 319 |
+
gap[c] = np.nan
|
| 320 |
+
# for non-numeric, keep last known (except protected)
|
| 321 |
+
for c in non_num_cols:
|
| 322 |
+
if c not in protected_cols:
|
| 323 |
+
gap[c] = prev[c]
|
| 324 |
+
rows.append(gap)
|
| 325 |
+
rows.append(curr.copy())
|
| 326 |
+
# Sorting
|
| 327 |
+
out = pd.DataFrame(rows).sort_values("timestamp").reset_index(drop=True)
|
| 328 |
+
return out
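# Worked example (illustrative): with an inferred base of 1.0 s, a 4 s jump between consecutive
# rows exceeds 1.5x base plus the tolerance, so round(4/1) - 1 = 3 placeholder rows are inserted
# at 1 s spacing; their numeric sensors stay NaN until the KNN imputation step below.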
|
| 329 |
+
|
| 330 |
+
def _knn_impute_all(_df: pd.DataFrame) -> pd.DataFrame:
|
| 331 |
+
"""
|
| 332 |
+
Backfill ALL numeric sensors jointly with KNN, using time (ts_sec) as an additional feature.
|
| 333 |
+
"""
|
| 334 |
+
_df = _df.copy()
|
| 335 |
+
_df["ts_sec"] = (_df["timestamp"] - _df["timestamp"].min()).dt.total_seconds()
|
| 336 |
+
# Normalize data
|
| 337 |
+
num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
|
| 338 |
+
# ensure ts_sec included
|
| 339 |
+
if "ts_sec" not in num_cols:
|
| 340 |
+
num_cols = num_cols + ["ts_sec"]
|
| 341 |
+
# Build imputation frame and remember order
|
| 342 |
+
X = _df[num_cols].copy()
|
| 343 |
+
non_missing_rows = X.dropna().shape[0]
|
| 344 |
+
k = min(5, max(1, non_missing_rows))
|
| 345 |
+
logger.info(f"🤝 KNNImputer n_neighbors={k} on {len(num_cols)} features")
|
| 346 |
+
# Impute and backfill data using KNN
|
| 347 |
+
imputer = KNNImputer(n_neighbors=k)
|
| 348 |
+
X_imp = imputer.fit_transform(X)
|
| 349 |
+
X_imp = pd.DataFrame(X_imp, columns=num_cols, index=_df.index)
|
| 350 |
+
# Write back (excluding ts_sec)
|
| 351 |
+
for c in num_cols:
|
| 352 |
+
if c == "ts_sec":
|
| 353 |
+
continue
|
| 354 |
+
_df[c] = X_imp[c]
|
| 355 |
+
|
| 356 |
+
_df.drop(columns=["ts_sec"], inplace=True)
|
| 357 |
+
return _df
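# Note: n_neighbors is capped at min(5, number of fully observed rows), so with e.g. only
# 3 complete rows the imputer runs with k=3 rather than failing.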
|
| 358 |
+
|
| 359 |
+
# Copy data from selective sensor types for Feature Engineering
|
| 360 |
+
def _feature_engineering(_df: pd.DataFrame) -> pd.DataFrame:
|
| 361 |
+
_df = _df.copy()
|
| 362 |
+
if {"ENGINE_LOAD", "ABSOLUTE_LOAD"}.issubset(_df.columns):
|
| 363 |
+
_df["AVG_ENGINE_LOAD"] = _df[["ENGINE_LOAD", "ABSOLUTE_LOAD"]].mean(axis=1)
|
| 364 |
+
if {"INTAKE_TEMP", "OIL_TEMP", "COOLANT_TEMP"}.issubset(_df.columns):
|
| 365 |
+
_df["TEMP_MEAN"] = _df[["INTAKE_TEMP", "OIL_TEMP", "COOLANT_TEMP"]].mean(axis=1)
|
| 366 |
+
if {"MAF", "RPM"}.issubset(_df.columns):
|
| 367 |
+
_df["AIRFLOW_PER_RPM"] = _df["MAF"] / _df["RPM"].replace(0, np.nan)
|
| 368 |
+
return _df
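# Worked example (illustrative): with MAF = 12.5 and RPM = 3200, AIRFLOW_PER_RPM = 12.5 / 3200 ≈ 0.0039;
# RPM values of 0 are replaced with NaN first to avoid division by zero.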
|
| 369 |
+
|
| 370 |
+
# Apply MinMaxScaler to fit data frame
|
| 371 |
+
def _scale_numeric(_df: pd.DataFrame) -> pd.DataFrame:
|
| 372 |
+
_df = _df.copy()
|
| 373 |
+
num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
|
| 374 |
+
for c in list(protected_cols):
|
| 375 |
+
if c in num_cols:
|
| 376 |
+
num_cols.remove(c)
|
| 377 |
+
if num_cols:
|
| 378 |
+
scaler = MinMaxScaler()
|
| 379 |
+
_df[num_cols] = scaler.fit_transform(_df[num_cols])
|
| 380 |
+
return _df
|
| 381 |
+
|
| 382 |
+
# Correlation heatmap plotter
|
| 383 |
+
def _plot_corr(_df: pd.DataFrame, _id: str):
|
| 384 |
+
try:
|
| 385 |
+
num = _df.select_dtypes(include=[np.number])
|
| 386 |
+
if num.shape[1] < 2:
|
| 387 |
+
return
|
| 388 |
+
plt.figure(figsize=(12, 10))
|
| 389 |
+
sns.heatmap(num.corr(), annot=True, fmt=".2f", cmap="coolwarm")
|
| 390 |
+
plt.title("Correlation Between Numeric OBD-II Variables")
|
| 391 |
+
plt.tight_layout()
|
| 392 |
+
plt.savefig(os.path.join(PLOT_DIR, f"heatmap_{_id}.png"))
|
| 393 |
+
plt.close()
|
| 394 |
+
except Exception as e:
|
| 395 |
+
logger.error(f"Heatmap generation failed: {e}")
|
| 396 |
+
|
| 397 |
+
# Sensor trend plotter
|
| 398 |
+
def _plot_trend(_df: pd.DataFrame, _id: str):
|
| 399 |
+
try:
|
| 400 |
+
plt.figure(figsize=(15, 6))
|
| 401 |
+
for col in ['RPM', 'ENGINE_LOAD', 'ABSOLUTE_LOAD', 'COOLANT_TEMP',
|
| 402 |
+
'INTAKE_TEMP', 'OIL_TEMP', 'INTAKE_PRESSURE', 'BAROMETRIC_PRESSURE',
|
| 403 |
+
'CONTROL_MODULE_VOLTAGE']:
|
| 404 |
+
if col in _df.columns:
|
| 405 |
+
plt.plot(_df.index, _df[col], label=col)
|
| 406 |
+
plt.title("Sensor Trends (Index-Based, No Time Gaps)")
|
| 407 |
+
plt.xlabel("Sample Index")
|
| 408 |
+
plt.ylabel("Sensor Value")
|
| 409 |
+
plt.legend()
|
| 410 |
+
plt.grid(True)
|
| 411 |
+
plt.tight_layout()
|
| 412 |
+
plt.savefig(os.path.join(PLOT_DIR, f"trend_{_id}.png"))
|
| 413 |
+
plt.close()
|
| 414 |
+
except Exception as e:
|
| 415 |
+
logger.error(f"Trend plot failed: {e}")
|
| 416 |
+
|
| 417 |
+
# ----------------------- pipeline -----------------------
|
| 418 |
+
df = df.copy()
|
| 419 |
+
# 0) Basic tidy
|
| 420 |
+
df = _drop_dead_weight(df)
|
| 421 |
+
df = _to_dt(df)
|
| 422 |
+
# 1) Corruption-only normalization (no outlier trimming)
|
| 423 |
+
df = _normalize_corruption(df)
|
| 424 |
+
# 2) Light row/column filtering for extreme sparsity
|
| 425 |
+
df = _light_row_col_filters(df)
|
| 426 |
+
# 3) Auto infer base interval & insert gap rows
|
| 427 |
+
base_sec = _infer_base_interval_seconds(df["timestamp"])
|
| 428 |
+
df = _insert_time_gaps(df, base_sec)
|
| 429 |
+
# 4) KNN backfill all numeric sensors (time-aware)
|
| 430 |
+
df = _knn_impute_all(df)
|
| 431 |
+
# 5) Feature engineering AFTER imputation
|
| 432 |
+
df = _feature_engineering(df)
|
| 433 |
+
# 6) Final sort / index
|
| 434 |
+
df.sort_values("timestamp", inplace=True)
|
| 435 |
+
df.reset_index(drop=True, inplace=True)
|
| 436 |
+
# 7) Scaling after impute (kept from original)
|
| 437 |
+
if not df.select_dtypes(include=["number"]).empty:
|
| 438 |
+
df = _scale_numeric(df)
|
| 439 |
+
# 8) Save
|
| 440 |
+
out_path = os.path.join(CLEANED_DIR, f"cleaned_{norm_ts}.csv")
|
| 441 |
+
df.to_csv(out_path, index=False)
|
| 442 |
+
logger.info(f"✅ Cleaned saved: {out_path}")
|
| 443 |
+
# 9) UL drivestyle predictions
|
| 444 |
+
df_for_persist = df
|
| 445 |
+
labeled_path = None
|
| 446 |
+
try:
|
| 447 |
+
ul = ULLabeler.get()
|
| 448 |
+
preds = ul.predict_df(df)
|
| 449 |
+
# update global DRIVE_STYLE (overwrite if already exists)
|
| 450 |
+
global DRIVE_STYLE
|
| 451 |
+
DRIVE_STYLE = [str(p) for p in preds]
|
| 452 |
+
# write labeled CSV (driving_style column)
|
| 453 |
+
df_labeled = df.copy()
|
| 454 |
+
df_labeled["driving_style"] = DRIVE_STYLE
|
| 455 |
+
labeled_path = os.path.join(CLEANED_DIR, f"cleaned_{norm_ts}_labeled.csv")
|
| 456 |
+
df_labeled.to_csv(labeled_path, index=False)
|
| 457 |
+
df_for_persist = df_labeled
|
| 458 |
+
# Update the global DRIVE_STYLE list
|
| 459 |
+
logger.info(f"✅ UL labels generated ({len(DRIVE_STYLE)}) → {labeled_path}")
|
| 460 |
+
except Exception as e:
|
| 461 |
+
logger.error(f"❌ UL labeling failed: {e}")
|
| 462 |
+
# 10) Plots
|
| 463 |
+
_plot_corr(df, norm_ts)
|
| 464 |
+
_plot_trend(df, norm_ts)
|
| 465 |
+
# 11) Update event
|
| 466 |
+
try:
|
| 467 |
+
PIPELINE_EVENTS[norm_ts]["status"] = "done"
|
| 468 |
+
except Exception:
|
| 469 |
+
pass
|
| 470 |
+
# 12) Upload to Drive
|
| 471 |
+
try:
|
| 472 |
+
if drive_saver.is_service_available():
|
| 473 |
+
if labeled_path and os.path.exists(labeled_path):
|
| 474 |
+
drive_saver.upload_csv_to_drive(labeled_path)
|
| 475 |
+
logger.info("✅ Uploaded labeled to Google Drive")
|
| 476 |
+
else:
|
| 477 |
+
drive_saver.upload_csv_to_drive(out_path)
|
| 478 |
+
logger.info("✅ Uploaded default to Google Drive")
|
| 479 |
+
else:
|
| 480 |
+
logger.warning("⚠️ Google Drive service not available")
|
| 481 |
+
except Exception as e:
|
| 482 |
+
logger.error(f"❌ Drive upload error: {e}")
|
| 483 |
+
# 13) Save to MongoDB
|
| 484 |
+
try:
|
| 485 |
+
if mongo_saver.is_connected():
|
| 486 |
+
# Save the cleaned DataFrame directly to MongoDB
|
| 487 |
+
session_id = f"session_{norm_ts}"
|
| 488 |
+
if mongo_saver.save_dataframe_to_mongo(df_for_persist, session_id):
|
| 489 |
+
logger.info("✅ Saved to MongoDB")
|
| 490 |
+
else:
|
| 491 |
+
logger.warning("⚠️ MongoDB save failed")
|
| 492 |
+
else:
|
| 493 |
+
logger.warning("⚠️ MongoDB not connected")
|
| 494 |
+
except Exception as e:
|
| 495 |
+
logger.error(f"❌ MongoDB save error: {e}")
|
| 496 |
+
# 14) Save to Firebase Storage (incremented NNN_YYYY-MM-DD_processed.csv at fixed path)
|
| 497 |
+
try:
|
| 498 |
+
if firebase_saver and firebase_saver.is_available():
|
| 499 |
+
# Choose the final artifact to persist
|
| 500 |
+
if labeled_path and os.path.exists(labeled_path):
|
| 501 |
+
target_path = labeled_path
|
| 502 |
+
else:
|
| 503 |
+
target_path = out_path
|
| 504 |
+
# Optional: use the acquisition date if norm_ts starts with YYYY-MM-DD, else let saver use AUS/Melbourne "today"
|
| 505 |
+
date_str = None
|
| 506 |
+
try:
|
| 507 |
+
date_str = str(norm_ts)[:10] if norm_ts and len(str(norm_ts)) >= 10 else None
|
| 508 |
+
except Exception:
|
| 509 |
+
date_str = None
|
| 510 |
+
# Upload with auto-incremented name: NNN_YYYY-MM-DD_processed.csv under skyledge/processed
|
| 511 |
+
gs_url = firebase_saver.upload_file_with_increment(target_path, date_str=date_str)
|
| 512 |
+
# Save to Firebase Storage (incremented NNN_YYYY-MM-DD_processed.csv at fixed path)
|
| 513 |
+
if gs_url:
|
| 514 |
+
logger.info(f"✅ Saved to Firebase Storage: {gs_url}")
|
| 515 |
+
else:
|
| 516 |
+
logger.warning("⚠️ Firebase Storage upload returned empty URL")
|
| 517 |
+
else:
|
| 518 |
+
logger.warning("⚠️ Firebase Storage not available")
|
| 519 |
+
except Exception as e:
|
| 520 |
+
logger.error(f"❌ Firebase Storage save error: {e}")
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# ───────────── Health Check ──────────────────────
|
| 525 |
+
@app.get("/health")
|
| 526 |
+
def health():
|
| 527 |
+
return {"status": "ok"}
|
| 528 |
+
|
| 529 |
+
@app.get("/models/status")
|
| 530 |
+
def models_status():
|
| 531 |
+
"""Check if models are loaded and available"""
|
| 532 |
+
try:
|
| 533 |
+
model_dir = pathlib.Path(os.getenv("MODEL_DIR", "/app/models/ul"))
|
| 534 |
+
required_files = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]
|
| 535 |
+
|
| 536 |
+
available_files = []
|
| 537 |
+
missing_files = []
|
| 538 |
+
|
| 539 |
+
for file in required_files:
|
| 540 |
+
file_path = model_dir / file
|
| 541 |
+
if file_path.exists():
|
| 542 |
+
available_files.append(file)
|
| 543 |
+
else:
|
| 544 |
+
missing_files.append(file)
|
| 545 |
+
|
| 546 |
+
status = "ready" if len(available_files) == len(required_files) else "loading"
|
| 547 |
+
|
| 548 |
+
return {
|
| 549 |
+
"status": status,
|
| 550 |
+
"model_directory": str(model_dir),
|
| 551 |
+
"available_files": available_files,
|
| 552 |
+
"missing_files": missing_files,
|
| 553 |
+
"total_files": len(required_files),
|
| 554 |
+
"loaded_files": len(available_files)
|
| 555 |
+
}
|
| 556 |
+
except Exception as e:
|
| 557 |
+
return {
|
| 558 |
+
"status": "error",
|
| 559 |
+
"error": str(e),
|
| 560 |
+
"timestamp": datetime.now().isoformat()
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# ─────── Send status to frontend ─────────────────
|
| 565 |
+
@app.get("/events")
|
| 566 |
+
def get_events():
|
| 567 |
+
return PIPELINE_EVENTS
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# ────── Delete event from dashboard ──────────────
|
| 571 |
+
@app.delete("/events/remove/{timestamp}")
|
| 572 |
+
def remove_event(timestamp: str):
|
| 573 |
+
if timestamp in PIPELINE_EVENTS:
|
| 574 |
+
del PIPELINE_EVENTS[timestamp]
|
| 575 |
+
return {"status": "deleted"}
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
# ───────────── Download Cleaned ──────────────────
|
| 579 |
+
@app.get("/download/{filename}")
|
| 580 |
+
def download_file(filename: str):
|
| 581 |
+
path = os.path.join(CLEANED_DIR, filename)
|
| 582 |
+
if not os.path.exists(path):
|
| 583 |
+
raise HTTPException(status_code=404, detail="Not found")
|
| 584 |
+
return FileResponse(path, media_type='text/csv', filename=filename)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# ───────────── MongoDB Operations ──────────────────
|
| 588 |
+
@app.get("/mongo/status")
|
| 589 |
+
def mongo_status():
|
| 590 |
+
"""Check MongoDB connection status"""
|
| 591 |
+
return {
|
| 592 |
+
"connected": mongo_saver.is_connected(),
|
| 593 |
+
"available": MONGODB_AVAILABLE if 'MONGODB_AVAILABLE' in globals() else False
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
@app.get("/mongo/sessions")
|
| 598 |
+
def get_mongo_sessions():
|
| 599 |
+
"""Get summary of all MongoDB sessions"""
|
| 600 |
+
if not mongo_saver.is_connected():
|
| 601 |
+
raise HTTPException(status_code=503, detail="MongoDB not connected")
|
| 602 |
+
|
| 603 |
+
sessions = mongo_saver.get_session_summary()
|
| 604 |
+
return {"sessions": sessions}
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
@app.get("/mongo/query")
|
| 608 |
+
def query_mongo_data(
|
| 609 |
+
session_id: str = None,
|
| 610 |
+
driving_style: str = None,
|
| 611 |
+
start_time: str = None,
|
| 612 |
+
end_time: str = None,
|
| 613 |
+
limit: int = 1000
|
| 614 |
+
):
|
| 615 |
+
"""Query data from MongoDB with filters"""
|
| 616 |
+
if not mongo_saver.is_connected():
|
| 617 |
+
raise HTTPException(status_code=503, detail="MongoDB not connected")
|
| 618 |
+
|
| 619 |
+
# Parse datetime strings if provided
|
| 620 |
+
start_dt = None
|
| 621 |
+
end_dt = None
|
| 622 |
+
|
| 623 |
+
if start_time:
|
| 624 |
+
try:
|
| 625 |
+
start_dt = pd.to_datetime(start_time)
|
| 626 |
+
except Exception:
|
| 627 |
+
raise HTTPException(status_code=400, detail="Invalid start_time format")
|
| 628 |
+
|
| 629 |
+
if end_time:
|
| 630 |
+
try:
|
| 631 |
+
end_dt = pd.to_datetime(end_time)
|
| 632 |
+
except Exception:
|
| 633 |
+
raise HTTPException(status_code=400, detail="Invalid end_time format")
|
| 634 |
+
|
| 635 |
+
results = mongo_saver.query_data(
|
| 636 |
+
session_id=session_id,
|
| 637 |
+
driving_style=driving_style,
|
| 638 |
+
start_time=start_dt,
|
| 639 |
+
end_time=end_dt,
|
| 640 |
+
limit=limit
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
return {"results": results, "count": len(results)}
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
@app.post("/mongo/save-csv")
|
| 647 |
+
async def save_csv_to_mongo_endpoint(
|
| 648 |
+
file: UploadFile = File(...),
|
| 649 |
+
session_id: str = None
|
| 650 |
+
):
|
| 651 |
+
"""Save uploaded CSV directly to MongoDB"""
|
| 652 |
+
if not mongo_saver.is_connected():
|
| 653 |
+
raise HTTPException(status_code=503, detail="MongoDB not connected")
|
| 654 |
+
|
| 655 |
+
try:
|
| 656 |
+
# Save uploaded file temporarily
|
| 657 |
+
temp_path = os.path.join(BASE_DIR, f"temp_{file.filename}")
|
| 658 |
+
with open(temp_path, "wb") as f:
|
| 659 |
+
f.write(await file.read())
|
| 660 |
+
|
| 661 |
+
# Save to MongoDB
|
| 662 |
+
success = mongo_saver.save_csv_to_mongo(temp_path, session_id)
|
| 663 |
+
|
| 664 |
+
# Clean up temp file
|
| 665 |
+
if os.path.exists(temp_path):
|
| 666 |
+
os.remove(temp_path)
|
| 667 |
+
|
| 668 |
+
if success:
|
| 669 |
+
return {"status": "success", "message": "CSV saved to MongoDB"}
|
| 670 |
+
else:
|
| 671 |
+
raise HTTPException(status_code=500, detail="Failed to save to MongoDB")
|
| 672 |
+
|
| 673 |
+
except Exception as e:
|
| 674 |
+
logger.error(f"CSV to MongoDB save failed: {e}")
|
| 675 |
+
raise HTTPException(status_code=500, detail=f"Save failed: {str(e)}")
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# ───────────── RLHF Training Endpoints ─────────────
|
| 679 |
+
|
| 680 |
+
class RLHFTrainingRequest(BaseModel):
|
| 681 |
+
max_datasets: int = 10
|
| 682 |
+
force_retrain: bool = False
|
| 683 |
+
|
| 684 |
+
class RLHFTrainingResponse(BaseModel):
|
| 685 |
+
status: str
|
| 686 |
+
model_version: str = None
|
| 687 |
+
datasets_processed: int = 0
|
| 688 |
+
samples_processed: int = 0
|
| 689 |
+
performance_metrics: dict = None
|
| 690 |
+
error: str = None
|
| 691 |
+
timestamp: str = None
|
| 692 |
+
|
| 693 |
+
@app.post("/rlhf/train", response_model=RLHFTrainingResponse)
|
| 694 |
+
async def trigger_rlhf_training(
|
| 695 |
+
request: RLHFTrainingRequest,
|
| 696 |
+
background_tasks: BackgroundTasks
|
| 697 |
+
):
|
| 698 |
+
"""
|
| 699 |
+
Trigger RLHF (Reinforcement Learning from Human Feedback) training session.
|
| 700 |
+
|
| 701 |
+
This endpoint:
|
| 702 |
+
1. Loads human-labeled data from Firebase storage (skyledge/labeled)
|
| 703 |
+
2. Combines it with existing model predictions for RLHF
|
| 704 |
+
3. Retrains the XGBoost model with the combined dataset
|
| 705 |
+
4. Saves the new model to Hugging Face Hub
|
| 706 |
+
"""
|
| 707 |
+
try:
|
| 708 |
+
logger.info(f"🚀 RLHF training requested with max_datasets={request.max_datasets}")
|
| 709 |
+
|
| 710 |
+
# Initialize trainer
|
| 711 |
+
trainer = RLHFTrainer()
|
| 712 |
+
|
| 713 |
+
# Run training
|
| 714 |
+
result = trainer.train(max_datasets=request.max_datasets)
|
| 715 |
+
|
| 716 |
+
if result["status"] == "success":
|
| 717 |
+
logger.info(f"✅ RLHF training completed: v{result['model_version']}")
|
| 718 |
+
return RLHFTrainingResponse(
|
| 719 |
+
status="success",
|
| 720 |
+
model_version=result["model_version"],
|
| 721 |
+
datasets_processed=result["datasets_processed"],
|
| 722 |
+
samples_processed=result["samples_processed"],
|
| 723 |
+
performance_metrics=result["performance_metrics"],
|
| 724 |
+
timestamp=datetime.datetime.now().isoformat()
|
| 725 |
+
)
|
| 726 |
+
elif result["status"] == "no_data":
|
| 727 |
+
logger.info("ℹ️ No new data available for RLHF training")
|
| 728 |
+
return RLHFTrainingResponse(
|
| 729 |
+
status="no_data",
|
| 730 |
+
timestamp=datetime.datetime.now().isoformat()
|
| 731 |
+
)
|
| 732 |
+
else:
|
| 733 |
+
logger.error(f"❌ RLHF training failed: {result.get('error', 'Unknown error')}")
|
| 734 |
+
return RLHFTrainingResponse(
|
| 735 |
+
status="error",
|
| 736 |
+
error=result.get("error", "Unknown error"),
|
| 737 |
+
timestamp=datetime.datetime.now().isoformat()
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
except Exception as e:
|
| 741 |
+
logger.error(f"❌ RLHF training endpoint failed: {e}")
|
| 742 |
+
raise HTTPException(
|
| 743 |
+
status_code=500,
|
| 744 |
+
detail=f"RLHF training failed: {str(e)}"
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
@app.get("/rlhf/status")
|
| 748 |
+
async def get_rlhf_status():
|
| 749 |
+
"""
|
| 750 |
+
Get status of RLHF training system and available labeled data.
|
| 751 |
+
"""
|
| 752 |
+
try:
|
| 753 |
+
from train import LabeledDataLoader
|
| 754 |
+
|
| 755 |
+
loader = LabeledDataLoader()
|
| 756 |
+
datasets = loader.list_labeled_datasets()
|
| 757 |
+
|
| 758 |
+
return {
|
| 759 |
+
"status": "available",
|
| 760 |
+
"labeled_datasets_count": len(datasets),
|
| 761 |
+
"datasets": [
|
| 762 |
+
{
|
| 763 |
+
"name": d["name"],
|
| 764 |
+
"size": d["size"],
|
| 765 |
+
"created": d["created"]
|
| 766 |
+
} for d in datasets[:10] # Limit to first 10 for response size
|
| 767 |
+
],
|
| 768 |
+
"firebase_bucket": "skyledge-36b56.firebasestorage.app",
|
| 769 |
+
"labeled_path": "skyledge/labeled",
|
| 770 |
+
"timestamp": datetime.now().isoformat()
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
except Exception as e:
|
| 774 |
+
logger.error(f"❌ RLHF status check failed: {e}")
|
| 775 |
+
raise HTTPException(
|
| 776 |
+
status_code=500,
|
| 777 |
+
detail=f"Status check failed: {str(e)}"
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
@app.get("/rlhf/trained-datasets")
|
| 781 |
+
async def get_trained_datasets():
|
| 782 |
+
"""
|
| 783 |
+
Get list of datasets that have already been used for training.
|
| 784 |
+
"""
|
| 785 |
+
try:
|
| 786 |
+
from train import LabeledDataLoader
|
| 787 |
+
|
| 788 |
+
loader = LabeledDataLoader()
|
| 789 |
+
trained_datasets = loader._get_trained_datasets()
|
| 790 |
+
|
| 791 |
+
return {
|
| 792 |
+
"trained_datasets_count": len(trained_datasets),
|
| 793 |
+
"trained_datasets": trained_datasets,
|
| 794 |
+
"timestamp": datetime.now().isoformat()
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
except Exception as e:
|
| 798 |
+
logger.error(f"❌ Failed to get trained datasets: {e}")
|
| 799 |
+
raise HTTPException(
|
| 800 |
+
status_code=500,
|
| 801 |
+
detail=f"Failed to get trained datasets: {str(e)}"
|
| 802 |
+
)
|
data.json
ADDED
@@ -0,0 +1,26 @@
{
  "timestamp": "2025-05-15T10:00:00",
  "driving_style": "aggressive",
  "data": {
    "RPM": 3200,
    "THROTTLE_POS": 75,
    "SPEED": 110,
    "FUEL_PRESSURE": 290,
    "ENGINE_LOAD": 45,
    "COOLANT_TEMP": 85,
    "INTAKE_TEMP": 30,
    "TIMING_ADVANCE": 10,
    "MAF": 12.5,
    "INTAKE_PRESSURE": 28,
    "SHORT_FUEL_TRIM_1": 3.1,
    "LONG_FUEL_TRIM_1": 6.2,
    "SHORT_FUEL_TRIM_2": 2.5,
    "LONG_FUEL_TRIM_2": 5.0,
    "COMMANDED_EQUIV_RATIO": 1.0,
    "O2_B1S2": 0.74,
    "O2_B2S2": 0.68,
    "O2_S1_WR_VOLTAGE": 0.85,
    "COMMANDED_EGR": 10
  }
}
data/drive_saver.py
ADDED
@@ -0,0 +1,110 @@
# Google Drive Operations for OBD Logger
# Handles authentication and file uploads to Google Drive

import os
import json
import logging
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload

# ───────────── Logging Setup ─────────────
logger = logging.getLogger("drive-saver")
logger.setLevel(logging.INFO)
fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(fmt)
logger.addHandler(handler)


class DriveSaver:
    """Handles Google Drive operations for saving OBD data"""

    def __init__(self):
        self.service = None
        self.folder_id = "1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P"  # Default folder ID
        self._initialize_service()

    def _initialize_service(self):
        """Initialize Google Drive service with credentials"""
        try:
            creds_dict = json.loads(os.getenv("GDRIVE_CREDENTIALS_JSON"))
            creds = service_account.Credentials.from_service_account_info(
                creds_dict,
                scopes=["https://www.googleapis.com/auth/drive"]
            )
            self.service = build("drive", "v3", credentials=creds)
            logger.info("✅ Google Drive service initialized successfully")
        except Exception as e:
            logger.error(f"❌ Drive initialization failed: {e}")
            self.service = None

    def upload_csv_to_drive(self, file_path: str, folder_id: str = None) -> bool:
        """
        Upload a CSV file to Google Drive

        Args:
            file_path (str): Path to the CSV file to upload
            folder_id (str, optional): Target folder ID. Uses default if not specified.

        Returns:
            bool: True if upload successful, False otherwise
        """
        if not self.service:
            logger.error("❌ Drive service not initialized")
            return False

        target_folder = folder_id or self.folder_id

        try:
            file_name = os.path.basename(file_path)
            media = MediaFileUpload(file_path, mimetype='text/csv')
            metadata = {"name": file_name, "parents": [target_folder]}

            result = self.service.files().create(
                body=metadata,
                media_body=media,
                fields="id"
            ).execute()

            logger.info(f"✅ File uploaded to Drive successfully: {file_name} (ID: {result.get('id')})")
            return True

        except Exception as e:
            logger.error(f"❌ Drive upload failed: {e}")
            return False

    def is_service_available(self) -> bool:
        """Check if Drive service is available"""
        return self.service is not None

    def get_folder_id(self) -> str:
        """Get the default folder ID"""
        return self.folder_id

    def set_folder_id(self, folder_id: str):
        """Set a new default folder ID"""
        self.folder_id = folder_id
        logger.info(f"📁 Default folder ID updated to: {folder_id}")


# Convenience function for backward compatibility
def get_drive_service():
    """Legacy function - returns DriveSaver instance"""
    return DriveSaver()


def upload_to_folder(service, file_path, folder_id):
    """Legacy function - uploads file to specified folder"""
    if isinstance(service, DriveSaver):
        return service.upload_csv_to_drive(file_path, folder_id)
    else:
        # Handle legacy service object
        try:
            file_name = os.path.basename(file_path)
            media = MediaFileUpload(file_path, mimetype='text/csv')
            metadata = {"name": file_name, "parents": [folder_id]}
            return service.files().create(body=metadata, media_body=media, fields="id").execute()
        except Exception as e:
            logger.error(f"❌ Legacy upload failed: {e}")
            return None
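# Typical usage (sketch; the CSV path is illustrative, the folder ID comes from DriveSaver's default
# unless overridden via set_folder_id or the folder_id argument):
#   saver = DriveSaver()
#   if saver.is_service_available():
#       saver.upload_csv_to_drive("/tmp/cleaned_2025-05-15.csv")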
data/firebase_saver.py
ADDED
|
@@ -0,0 +1,315 @@
# firebase_saver.py
import os
import io
import re
import json
import logging
from datetime import datetime
from typing import Optional, Tuple, List

import pandas as pd

logger = logging.getLogger("firebase-saver")
logger.setLevel(logging.INFO)
if not logger.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
    logger.addHandler(_h)

# ---------- Constants (fixed as requested) ----------
FIXED_BUCKET = "skyledge-36b56.firebasestorage.app"
FIXED_PREFIX = "skyledge/processed"  # no trailing slash

# Pattern: NNN_YYYY-MM-DD_processed.csv
FILENAME_RE = re.compile(r"^(?P<num>\d{3})_(?P<date>\d{4}-\d{2}-\d{2})_processed\.csv$")


def _parse_gs_uri(uri: Optional[str]):
    if not uri or not uri.startswith("gs://"):
        return None, None
    path = uri[len("gs://"):]
    parts = path.split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    return bucket, prefix


def _maybe_default_firebase_bucket(name: Optional[str]) -> Optional[str]:
    # If user passed a project ID (no dot), convert to <project>.appspot.com
    if name and "." not in name:
        return f"{name}.appspot.com"
    return name


# -------------------- Low-level clients --------------------

class _AdminClient:
    """Firebase Admin SDK storage client."""
    def __init__(self, bucket: str):
        import firebase_admin
        from firebase_admin import credentials, storage as fb_storage

        raw = os.getenv("FIREBASE_ADMIN_JSON")
        if not raw:
            raise RuntimeError("FIREBASE_ADMIN_JSON not set")
        info = json.loads(raw)
        client_email = info.get("client_email")
        cred = credentials.Certificate(info)

        if not firebase_admin._apps:
            firebase_admin.initialize_app(cred, {"storageBucket": bucket})

        # fb_storage.bucket returns a google.cloud.storage.bucket.Bucket
        self.bucket = fb_storage.bucket(bucket)
        self._bucket_name = bucket
        logger.info(f"✅ Firebase Admin initialized | bucket={bucket} as {client_email}")

    # Uploads
    def upload_from_filename(self, local_path: str, dest_path: str, content_type: str):
        blob = self.bucket.blob(dest_path)
        blob.cache_control = "no-store"
        blob.upload_from_filename(local_path, content_type=content_type)

    def upload_from_bytes(self, data: bytes, dest_path: str, content_type: str):
        blob = self.bucket.blob(dest_path)
        blob.cache_control = "no-store"
        blob.upload_from_string(data, content_type=content_type)

    # Listing (needs storage.objects.list permission)
    def list_names(self, prefix: str) -> List[str]:
        # Bucket.list_blobs works via the underlying GCS client
        blobs = self.bucket.list_blobs(prefix=prefix)
        return [b.name for b in blobs]

    # Existence check (for collision-safe retry)
    def blob_exists(self, path: str) -> bool:
        blob = self.bucket.blob(path)
        return blob.exists()


class _GCSClient:
    """google-cloud-storage client."""
    def __init__(self, bucket: str):
        from google.cloud import storage
        from google.oauth2 import service_account

        raw = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON")
        if not raw:
            raise RuntimeError("FIREBASE_SERVICE_ACCOUNT_JSON not set")
        info = json.loads(raw)
        client_email = info.get("client_email")
        creds = service_account.Credentials.from_service_account_info(info)
        project_id = info.get("project_id")

        self.client = storage.Client(credentials=creds, project=project_id)
        self.bucket = self.client.bucket(bucket)
        self._bucket_name = bucket
        logger.info(f"✅ GCS client initialized | bucket={bucket} as {client_email}")

    def upload_from_filename(self, local_path: str, dest_path: str, content_type: str):
        blob = self.bucket.blob(dest_path)
        blob.cache_control = "no-store"
        blob.upload_from_filename(local_path, content_type=content_type)

    def upload_from_bytes(self, data: bytes, dest_path: str, content_type: str):
        blob = self.bucket.blob(dest_path)
        blob.cache_control = "no-store"
        blob.upload_from_string(data, content_type=content_type)

    def list_names(self, prefix: str) -> List[str]:
        blobs = self.client.list_blobs(self._bucket_name, prefix=prefix)
        return [b.name for b in blobs]

    def blob_exists(self, path: str) -> bool:
        blob = self.bucket.blob(path)
        return blob.exists(self.client)


# -------------------- Saver (high level) --------------------

class FirebaseSaver:
    """
    Fixed target:
        Bucket: skyledge-36b56.firebasestorage.app
        Prefix: skyledge/processed
    Filename convention: NNN_YYYY-MM-DD_processed.csv (NNN is 001-based, zero-padded).
    Auto-increments by listing current objects and picking max+1.
    """

    def __init__(self):
        # Force fixed location regardless of env (as requested)
        bucket_name = FIXED_BUCKET
        self.prefix = FIXED_PREFIX

        # Try Admin SDK first; fallback to GCS client
        self.client = None
        self.mode = None
        try:
            if os.getenv("FIREBASE_ADMIN_JSON"):
                self.client = _AdminClient(bucket_name)
                self.mode = "admin"
        except Exception as e:
            logger.warning(f"⚠️ Admin SDK init failed: {e}")

        if self.client is None:
            try:
                self.client = _GCSClient(bucket_name)
                self.mode = "gcs"
            except Exception as e:
                logger.error(f"❌ GCS client init failed: {e}")
                raise

        logger.info(f"📦 FirebaseSaver ready | mode={self.mode} bucket={bucket_name} prefix={self.prefix}")

    def is_available(self) -> bool:
        return self.client is not None

    # ---------- Incremental naming helpers ----------

    def _list_existing_filenames(self) -> List[str]:
        """List object names under the fixed prefix, return just basenames under that folder."""
        names = self.client.list_names(prefix=self.prefix + "/")
        # keep only items immediately under prefix (not subfolders) & matching our filename pattern
        base_names = []
        for full in names:
            # full looks like 'skyledge/processed/NNN_YYYY-MM-DD_processed.csv'
            if not full.startswith(self.prefix + "/"):
                continue
            base = full[len(self.prefix) + 1:]  # strip 'prefix/'
            if "/" in base:
                # skip nested items (none expected)
                continue
            if FILENAME_RE.match(base):
                base_names.append(base)
        return base_names

    def _max_existing_id(self) -> int:
        """Return max NNN found under prefix, or 0 if none."""
        try:
            base_names = self._list_existing_filenames()
        except Exception as e:
            logger.warning(f"⚠️ Unable to list existing objects; defaulting max_id=0: {e}")
            return 0

        max_id = 0
        for name in base_names:
            m = FILENAME_RE.match(name)
            if not m:
                continue
            try:
                num = int(m.group("num"))
                if num > max_id:
                    max_id = num
            except ValueError:
                continue
        return max_id

    @staticmethod
    def _format_id(n: int) -> str:
        return f"{n:03d}"

    @staticmethod
    def _today_au() -> str:
        # Use Australia/Melbourne local date; if zoneinfo unavailable, fall back to UTC date.
        try:
            from zoneinfo import ZoneInfo
            dt = datetime.now(ZoneInfo("Australia/Melbourne"))
        except Exception:
            dt = datetime.utcnow()
        return dt.strftime("%Y-%m-%d")

    def _build_filename(self, n_int: int, date_str: Optional[str] = None) -> str:
        date_val = (date_str or self._today_au())
        return f"{self._format_id(n_int)}_{date_val}_processed.csv"

    def _dest_path(self, filename: str) -> str:
        return f"{self.prefix}/{filename}"

    def _next_available_name(self, date_str: Optional[str] = None, max_retries: int = 5) -> Tuple[str, str]:
        """
        Compute the next file name by listing existing ones and incrementing.
        Includes a collision check (exists) and retries if necessary.
        Returns: (filename, full_gcs_path)
        """
        start = self._max_existing_id() + 1
        n = start
        for _ in range(max_retries):
            candidate = self._build_filename(n, date_str=date_str)
            dest_path = self._dest_path(candidate)
            # collision check
            if not self.client.blob_exists(dest_path):
                return candidate, dest_path
            n += 1

        # As a final fallback, return the last tried (very unlikely to collide repeatedly)
        candidate = self._build_filename(n, date_str=date_str)
        return candidate, self._dest_path(candidate)

    # ---------- Public save methods (incremental) ----------

    def upload_file_with_increment(
        self,
        local_path: str,
        date_str: Optional[str] = None,
        content_type: str = "text/csv",
    ) -> str:
        """
        Upload a local file using the next incremental name.
        Returns the gs:// URL of the uploaded object (string) or "" on failure.
        """
        if not self.is_available():
            logger.warning("⚠️ Firebase saver unavailable")
            return ""
        try:
            filename, dest_path = self._next_available_name(date_str=date_str)
            self.client.upload_from_filename(local_path, dest_path, content_type)
            logger.info(f"✅ Uploaded file to gs://{FIXED_BUCKET}/{dest_path}")
            return f"gs://{FIXED_BUCKET}/{dest_path}"
        except Exception as e:
            logger.error(f"❌ Firebase upload failed: {e}")
            return ""

    def upload_dataframe_with_increment(
        self,
        df: pd.DataFrame,
        date_str: Optional[str] = None,
        content_type: str = "text/csv",
    ) -> str:
        """
        Upload a DataFrame (as CSV) using the next incremental name.
        Returns the gs:// URL of the uploaded object (string) or "" on failure.
        """
        if not self.is_available():
            logger.warning("⚠️ Firebase saver unavailable")
            return ""
        try:
            buf = io.StringIO()
            df.to_csv(buf, index=False)
            data = buf.getvalue().encode("utf-8")

            filename, dest_path = self._next_available_name(date_str=date_str)
            self.client.upload_from_bytes(data, dest_path, content_type)
            logger.info(f"✅ Uploaded DataFrame to gs://{FIXED_BUCKET}/{dest_path}")
            return f"gs://{FIXED_BUCKET}/{dest_path}"
        except Exception as e:
            logger.error(f"❌ Firebase DF upload failed: {e}")
            return ""


# ---------- Convenience free functions ----------

def save_csv_increment(csv_path: str, date_str: Optional[str] = None) -> str:
    """
    Upload local CSV with auto-incremented name 'NNN_YYYY-MM-DD_processed.csv'.
    Returns gs:// URL or "".
    """
    saver = FirebaseSaver()
    return saver.upload_file_with_increment(csv_path, date_str=date_str)


def save_dataframe_increment(df: pd.DataFrame, date_str: Optional[str] = None) -> str:
    """
    Upload DataFrame with auto-incremented name 'NNN_YYYY-MM-DD_processed.csv'.
    Returns gs:// URL or "".
    """
    saver = FirebaseSaver()
    return saver.upload_dataframe_with_increment(df, date_str=date_str)
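A small illustration of the incremental-naming contract above (not part of firebase_saver.py; it assumes FIREBASE_ADMIN_JSON or FIREBASE_SERVICE_ACCOUNT_JSON is set, and the sample rows are invented):

```python
# Hedged sketch: requires valid Firebase Admin or GCS credentials in the environment.
import pandas as pd
from data.firebase_saver import save_dataframe_increment

df = pd.DataFrame({"speed_kmh": [42, 55], "rpm": [1500, 2100]})  # made-up OBD rows
url = save_dataframe_increment(df)
# On success, url is something like
# gs://skyledge-36b56.firebasestorage.app/skyledge/processed/004_2025-01-01_processed.csv
print(url or "upload failed")
```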
data/mongo_saver.py
ADDED
|
@@ -0,0 +1,362 @@
# MongoDB Operations for OBD Logger
# Handles data restructuring and saving to MongoDB Atlas

import os
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np

# MongoDB dependencies
try:
    from pymongo import MongoClient
    from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
    MONGODB_AVAILABLE = True
except ImportError:
    MONGODB_AVAILABLE = False
    print("⚠️ PyMongo not available. Install with: pip install pymongo")

# ───────────── Logging Setup ─────────────
logger = logging.getLogger("mongo-saver")
logger.setLevel(logging.INFO)
fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(fmt)
logger.addHandler(handler)


class MongoSaver:
    """Handles MongoDB operations for saving OBD data"""

    def __init__(self, mongo_uri: str = None):
        self.client = None
        self.db = None
        self.collection = None
        self.mongo_uri = mongo_uri or os.getenv("MONGO_URI")
        self._initialize_connection()

    def _initialize_connection(self):
        """Initialize MongoDB connection"""
        if not MONGODB_AVAILABLE:
            logger.error("❌ PyMongo not available. Cannot connect to MongoDB")
            return

        if not self.mongo_uri:
            logger.error("❌ MongoDB URI not provided. Set MONGO_URI environment variable")
            return

        try:
            # Connect with timeout and retry settings
            self.client = MongoClient(
                self.mongo_uri,
                serverSelectionTimeoutMS=5000,   # 5 second timeout
                connectTimeoutMS=10000,          # 10 second connection timeout
                socketTimeoutMS=10000            # 10 second socket timeout
            )

            # Test connection
            self.client.admin.command('ping')

            # Set up database and collection
            self.db = self.client.obd_logger
            self.collection = self.db.obd_data

            # Create indexes for better performance
            self._create_indexes()

            logger.info("✅ MongoDB connection established successfully")

        except (ConnectionFailure, ServerSelectionTimeoutError) as e:
            logger.error(f"❌ MongoDB connection failed: {e}")
            self.client = None
            self.db = None
            self.collection = None
        except Exception as e:
            logger.error(f"❌ MongoDB initialization error: {e}")
            self.client = None
            self.db = None
            self.collection = None

    def _create_indexes(self):
        """Create database indexes for better query performance"""
        try:
            # Index on timestamp for time-based queries
            self.collection.create_index("timestamp")

            # Index on driving_style for filtering
            self.collection.create_index("driving_style")

            # Compound index for common queries
            self.collection.create_index([("timestamp", -1), ("driving_style", 1)])

            # Index on session_id for session-based queries
            self.collection.create_index("session_id")

            logger.info("✅ Database indexes created successfully")

        except Exception as e:
            logger.warning(f"⚠️ Index creation failed: {e}")

    def is_connected(self) -> bool:
        """Check if MongoDB connection is active"""
        if not self.client:
            return False

        try:
            # Ping the database
            self.client.admin.command('ping')
            return True
        except Exception:
            return False

    def save_csv_to_mongo(self, csv_file_path: str, session_id: str = None) -> bool:
        """
        Read CSV file and save data to MongoDB

        Args:
            csv_file_path (str): Path to the CSV file
            session_id (str, optional): Unique identifier for this data session

        Returns:
            bool: True if save successful, False otherwise
        """
        if not self.is_connected():
            logger.error("❌ MongoDB not connected")
            return False

        try:
            # Read CSV file
            df = pd.read_csv(csv_file_path)

            if df.empty:
                logger.warning("⚠️ CSV file is empty")
                return False

            # Generate session ID if not provided
            if not session_id:
                session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            # Convert DataFrame to MongoDB documents
            documents = self._dataframe_to_documents(df, session_id)

            # Insert documents into MongoDB
            if documents:
                result = self.collection.insert_many(documents)
                logger.info(f"✅ Saved {len(result.inserted_ids)} records to MongoDB (Session: {session_id})")
                return True
            else:
                logger.warning("⚠️ No valid documents to save")
                return False

        except Exception as e:
            logger.error(f"❌ Failed to save CSV to MongoDB: {e}")
            return False

    def save_dataframe_to_mongo(self, df: pd.DataFrame, session_id: str = None) -> bool:
        """
        Save pandas DataFrame directly to MongoDB

        Args:
            df (pd.DataFrame): DataFrame to save
            session_id (str, optional): Unique identifier for this data session

        Returns:
            bool: True if save successful, False otherwise
        """
        if not self.is_connected():
            logger.error("❌ MongoDB not connected")
            return False

        try:
            if df.empty:
                logger.warning("⚠️ DataFrame is empty")
                return False

            # Generate session ID if not provided
            if not session_id:
                session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            # Convert DataFrame to MongoDB documents
            documents = self._dataframe_to_documents(df, session_id)

            # Insert documents into MongoDB
            if documents:
                result = self.collection.insert_many(documents)
                logger.info(f"✅ Saved {len(result.inserted_ids)} records to MongoDB (Session: {session_id})")
                return True
            else:
                logger.warning("⚠️ No valid documents to save")
                return False

        except Exception as e:
            logger.error(f"❌ Failed to save DataFrame to MongoDB: {e}")
            return False

    def _dataframe_to_documents(self, df: pd.DataFrame, session_id: str) -> List[Dict[str, Any]]:
        """
        Convert pandas DataFrame to MongoDB documents

        Args:
            df (pd.DataFrame): Input DataFrame
            session_id (str): Session identifier

        Returns:
            List[Dict[str, Any]]: List of MongoDB documents
        """
        documents = []

        for index, row in df.iterrows():
            try:
                # Convert row to dictionary
                doc = row.to_dict()

                # Add metadata
                doc['session_id'] = session_id
                doc['imported_at'] = datetime.utcnow()
                doc['record_index'] = index

                # Handle timestamp conversion
                if 'timestamp' in doc and pd.notna(doc['timestamp']):
                    try:
                        # Try to parse timestamp
                        if isinstance(doc['timestamp'], str):
                            doc['timestamp'] = pd.to_datetime(doc['timestamp'])
                        # Convert to datetime object
                        doc['timestamp'] = doc['timestamp'].to_pydatetime()
                    except Exception:
                        # Keep as string if parsing fails
                        pass

                # Convert numeric types and handle NaN values
                for key, value in doc.items():
                    if pd.isna(value):
                        doc[key] = None
                    elif isinstance(value, (np.integer, np.floating)):
                        doc[key] = value.item()
                    elif isinstance(value, np.bool_):
                        doc[key] = bool(value)

                documents.append(doc)

            except Exception as e:
                logger.warning(f"⚠️ Failed to process row {index}: {e}")
                continue

        return documents

    def query_data(self,
                   session_id: str = None,
                   driving_style: str = None,
                   start_time: datetime = None,
                   end_time: datetime = None,
                   limit: int = 1000) -> List[Dict[str, Any]]:
        """
        Query data from MongoDB

        Args:
            session_id (str, optional): Filter by session ID
            driving_style (str, optional): Filter by driving style
            start_time (datetime, optional): Start time filter
            end_time (datetime, optional): End time filter
            limit (int): Maximum number of records to return

        Returns:
            List[Dict[str, Any]]: Query results
        """
        if not self.is_connected():
            logger.error("❌ MongoDB not connected")
            return []

        try:
            # Build query filter
            query_filter = {}

            if session_id:
                query_filter['session_id'] = session_id

            if driving_style:
                query_filter['driving_style'] = driving_style

            if start_time or end_time:
                time_filter = {}
                if start_time:
                    time_filter['$gte'] = start_time
                if end_time:
                    time_filter['$lte'] = end_time
                query_filter['timestamp'] = time_filter

            # Execute query
            cursor = self.collection.find(query_filter).limit(limit)
            results = list(cursor)

            logger.info(f"✅ Query returned {len(results)} records")
            return results

        except Exception as e:
            logger.error(f"❌ Query failed: {e}")
            return []

    def get_session_summary(self) -> List[Dict[str, Any]]:
        """
        Get summary of all data sessions

        Returns:
            List[Dict[str, Any]]: Session summaries
        """
        if not self.is_connected():
            logger.error("❌ MongoDB not connected")
            return []

        try:
            pipeline = [
                {
                    '$group': {
                        '_id': '$session_id',
                        'count': {'$sum': 1},
                        'driving_styles': {'$addToSet': '$driving_style'},
                        'first_record': {'$min': '$timestamp'},
                        'last_record': {'$max': '$timestamp'},
                        'imported_at': {'$first': '$imported_at'}
                    }
                },
                {
                    '$sort': {'imported_at': -1}
                }
            ]

            results = list(self.collection.aggregate(pipeline))
            logger.info(f"✅ Retrieved summary for {len(results)} sessions")
            return results

        except Exception as e:
            logger.error(f"❌ Session summary failed: {e}")
            return []

    def close_connection(self):
        """Close MongoDB connection"""
        if self.client:
            self.client.close()
            logger.info("✅ MongoDB connection closed")

    def __enter__(self):
        """Context manager entry"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit"""
        self.close_connection()


# Convenience functions
def save_csv_to_mongo(csv_file_path: str, session_id: str = None) -> bool:
    """Convenience function to save CSV to MongoDB"""
    with MongoSaver() as saver:
        return saver.save_csv_to_mongo(csv_file_path, session_id)


def save_dataframe_to_mongo(df: pd.DataFrame, session_id: str = None) -> bool:
    """Convenience function to save DataFrame to MongoDB"""
    with MongoSaver() as saver:
        return saver.save_dataframe_to_mongo(df, session_id)
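A short usage sketch for the saver above (illustrative, not part of mongo_saver.py; MONGO_URI must point at a reachable cluster, and the file path and session ID are hypothetical):

```python
# Hedged sketch: assumes MONGO_URI is exported; path and session id are made up.
from data.mongo_saver import MongoSaver

with MongoSaver() as saver:
    if saver.is_connected():
        saver.save_csv_to_mongo("logs/2025-01-01_processed.csv", session_id="demo_session")
        recent = saver.query_data(driving_style="aggressive", limit=10)
        print(f"{len(recent)} aggressive records")
```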
organization.py
ADDED
|
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""
Script to reorganize existing models in HF repo to versioned structure.
This will move the current 3 .pkl files from root to v1.0 folder.
"""

import os
import sys
import tempfile
import json
from pathlib import Path

# Load environment variables from .env file
def load_env():
    """Load environment variables from .env file"""
    env_path = Path(__file__).parent / '.env'
    if env_path.exists():
        with open(env_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    os.environ[key] = value
        print(f"✅ Loaded environment variables from {env_path}")
    else:
        print("⚠️ No .env file found")

# Load environment variables
load_env()

# Add train directory to path
sys.path.append(os.path.join(os.path.dirname(__file__), 'train'))

def main():
    """Main function to reorganize models"""
    print("🔄 Reorganizing models in Hugging Face repository...")
    print("=" * 60)

    # Check if HF_TOKEN is set
    if not os.getenv("HF_TOKEN"):
        print("❌ Error: HF_TOKEN environment variable not set")
        print("Please set your Hugging Face token:")
        print("export HF_TOKEN=your_token_here")
        return 1

    # Check if we're in the right directory
    if not os.path.exists("train/rlhf.py"):
        print("❌ Error: Please run this script from the OBD_Logger root directory")
        return 1

    try:
        # Import and run the reorganization
        from train.move_models_to_v1 import move_models_to_v1

        print("📥 Starting model reorganization...")
        move_models_to_v1()

        print("\n✅ Model reorganization completed successfully!")
        print("📁 Your models are now organized in the v1.0 folder")
        print("🔄 Future RLHF training will create v1.1, v1.2, etc.")
        print("\nNext steps:")
        print("1. Verify the models are in the v1.0 folder on Hugging Face")
        print("2. Test the RLHF training with: curl -X POST 'http://localhost:8000/rlhf/train'")

        return 0

    except Exception as e:
        print(f"❌ Reorganization failed: {e}")
        print("\nTroubleshooting:")
        print("1. Make sure HF_TOKEN is set correctly")
        print("2. Check that you have write access to the repository")
        print("3. Verify the repository name is correct")
        return 1

if __name__ == "__main__":
    exit(main())
organze.py
ADDED
|
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Simple script to reorganize existing models in HF repo to versioned structure.
This will move the current 3 .pkl files from root to v1.0 folder.
"""

import os
import tempfile
import json
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, upload_file

def load_env():
    """Load environment variables from .env file"""
    env_path = Path(__file__).parent / '.env'
    if env_path.exists():
        with open(env_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    os.environ[key] = value
        print(f"✅ Loaded environment variables from {env_path}")
    else:
        print("⚠️ No .env file found")

def main():
    """Main function to reorganize models"""
    print("🔄 Reorganizing models in Hugging Face repository...")
    print("=" * 60)

    # Load environment variables
    load_env()

    # Check if HF_TOKEN is set
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        print("❌ Error: HF_TOKEN not found in environment")
        return 1

    print(f"✅ HF_TOKEN loaded: {hf_token[:10]}...")

    # Configuration
    repo_id = "BinKhoaLe1812/Driver_Behavior_OBD"
    model_files = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]

    print(f"📦 Target repository: {repo_id}")
    print(f"📁 Model files to move: {model_files}")

    # Initialize HF API
    hf_api = HfApi(token=hf_token)

    try:
        # Create temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            print(f"📁 Using temporary directory: {temp_path}")

            # Download existing model files
            downloaded_files = []
            for file in model_files:
                try:
                    print(f"📥 Downloading {file}...")
                    local_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=file,
                        repo_type="model",
                        token=hf_token
                    )
                    downloaded_files.append((file, local_path))
                    print(f"✅ Downloaded {file}")
                except Exception as e:
                    print(f"⚠️ Could not download {file}: {e}")

            if not downloaded_files:
                print("⚠️ No model files found to move")
                return 1

            # Create v1.0 directory structure
            v1_dir = temp_path / "v1.0"
            v1_dir.mkdir(exist_ok=True)
            print(f"📁 Created v1.0 directory: {v1_dir}")

            # Copy files to v1.0 directory
            for filename, local_path in downloaded_files:
                dest_path = v1_dir / filename
                import shutil
                shutil.copy2(local_path, dest_path)
                print(f"📦 Prepared {filename} for v1.0/")

            # Create metadata.json for v1.0
            metadata = {
                "version": "1.0",
                "model_type": "xgboost_classifier",
                "created_at": "2024-12-01T00:00:00",
                "description": "Initial model version - moved from root directory",
                "framework": "xgboost",
                "task": "driver_behavior_classification",
                "labels": ["aggressive", "normal", "conservative"],
                "features": "obd_sensor_data",
                "files": [f[0] for f in downloaded_files]
            }

            metadata_path = v1_dir / "metadata.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            print("📝 Created metadata.json for v1.0")

            # Create README.md for v1.0
            readme_content = """---
license: mit
tags:
- driver-behavior
- obd-data
- xgboost
- version-1.0
---

# Driver Behavior Classification Model v1.0

Initial version of the driver behavior classification model.

## Files

- `xgb_drivestyle_ul.pkl`: Main XGBoost model
- `label_encoder_ul.pkl`: Label encoder for behavior categories
- `scaler_ul.pkl`: Feature scaler
- `metadata.json`: Model metadata

## Usage

```python
import joblib

# Load the model
model = joblib.load('xgb_drivestyle_ul.pkl')
label_encoder = joblib.load('label_encoder_ul.pkl')
scaler = joblib.load('scaler_ul.pkl')

# Make predictions
predictions = model.predict(scaled_data)
behavior_labels = label_encoder.inverse_transform(predictions)
```
"""

            readme_path = v1_dir / "README.md"
            with open(readme_path, 'w') as f:
                f.write(readme_content)
            print("📖 Created README.md for v1.0")

            # Upload files to v1.0 directory in HF repo
            print("🚀 Uploading files to Hugging Face Hub...")
            for file_path in v1_dir.iterdir():
                if file_path.is_file():
                    hf_filename = f"v1.0/{file_path.name}"
                    print(f"📤 Uploading {file_path.name} to {hf_filename}...")
                    upload_file(
                        path_or_fileobj=str(file_path),
                        path_in_repo=hf_filename,
                        repo_id=repo_id,
                        repo_type="model",
                        token=hf_token,
                        commit_message=f"Add {file_path.name} to v1.0 directory"
                    )
                    print(f"✅ Uploaded {file_path.name} to v1.0/")

            print("\n✅ Successfully moved models to v1.0 structure!")
            print(f"📁 Models now located at: {repo_id}/v1.0/")
            print("\nNext steps:")
            print("1. Verify the models are in the v1.0 folder on Hugging Face")
            print("2. Test the RLHF training with: curl -X POST 'http://localhost:8000/rlhf/train'")

            return 0

    except Exception as e:
        print(f"❌ Reorganization failed: {e}")
        print("\nTroubleshooting:")
        print("1. Make sure HF_TOKEN is set correctly")
        print("2. Check that you have write access to the repository")
        print("3. Verify the repository name is correct")
        return 1

if __name__ == "__main__":
    exit(main())
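After the script runs, the new layout can be spot-checked from Python; this is a sketch under the same assumptions as the script (HF_TOKEN exported, same repo ID):

```python
# Hedged sketch: downloads one reorganised file to confirm it now lives under v1.0/.
import os
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="BinKhoaLe1812/Driver_Behavior_OBD",
    filename="v1.0/xgb_drivestyle_ul.pkl",
    repo_type="model",
    token=os.getenv("HF_TOKEN"),
)
print(f"v1.0 model cached at {path}")
```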
requirements.txt
ADDED
|
@@ -0,0 +1,37 @@
# Server
fastapi
uvicorn[standard]
python-multipart
jinja2

# Data
pandas
numpy
scikit-learn

# ML Models
xgboost
joblib

# Drive
gspread
oauth2client
google-auth
google-auth-httplib2
google-auth-oauthlib
google-api-python-client

# Database
pymongo
google-cloud-storage
firebase-admin

# Visualize
matplotlib
seaborn

# HuggingFace
huggingface_hub==0.25.2

# Additional dependencies for RLHF training
pyarrow  # For parquet file support
static/check.png
ADDED
|
static/edit.png
ADDED
|
static/icon.png
ADDED
|
|
static/index.html
ADDED
|
@@ -0,0 +1,16 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>OBD-II Data Dashboard</title>
    <link rel="icon" type="image/png" href="/static/icon.png">
    <link rel="stylesheet" href="/static/styles.css">
</head>
<body>
    <h1>OBD-II Data Pipeline Monitor</h1>
    <div id="log-container"></div>
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
    <script src="/static/script.js?v=2"></script>
</body>
</html>
static/script.js
ADDED
|
@@ -0,0 +1,230 @@
const expandedItems = JSON.parse(localStorage.getItem("expandedItems") || "{}");
const renamedLabels = JSON.parse(localStorage.getItem("renamedLabels") || "{}"); // Allow cards to be renamed (original identified by ts)
let previousKeys = [];
let previousEvents = {}; // Track event status to avoid redundant updates

// ─────────────────────────────────────────
// Refresh event per interval
// ─────────────────────────────────────────
async function fetchEvents() {
  const res = await fetch('/events');
  const data = await res.json();
  renderEvents(data);
}

// ─────────────────────────────────────────
// Update or Create new card
// ─────────────────────────────────────────
function renderEvents(events) {
  const container = document.getElementById('log-container');
  const currentKeys = Object.keys(events).sort();
  const newlyAdded = currentKeys.find(k => !previousKeys.includes(k));
  previousKeys = currentKeys;

  currentKeys.forEach(key => {
    const event = events[key];
    const existing = document.getElementById(`card-${key}`);
    const prevStatus = previousEvents[key]?.status;

    if (!existing) {
      const card = createCard(key, event);
      container.appendChild(card);
      if (key === newlyAdded && event.status === 'done') {
        setTimeout(() => card.scrollIntoView({ behavior: 'smooth', block: 'center' }), 300);
      }
    } else if (event.status !== prevStatus) {
      updateCard(key, event); // Only update if status changed
    }

    previousEvents[key] = { status: event.status }; // Cache latest status
  });
}

// ─────────────────────────────────────────
// Create new card on unmatched key
// ─────────────────────────────────────────
function createCard(key, event) {
  const readable = renamedLabels[key] || formatTimestamp(key);
  const safeKey = key.replace(/[:.]/g, "-");
  const card = document.createElement('div');
  card.id = `card-${key}`;
  card.className = 'card';

  const removeBtn = document.createElement('button');
  removeBtn.className = 'btn-remove';
  removeBtn.textContent = 'X';
  removeBtn.onclick = () => removeItem(key);

  const tsDiv = document.createElement('div');
  tsDiv.className = 'timestamp';
  tsDiv.innerHTML = `<span class="label-text">${readable}</span>`;

  const editIcon = document.createElement('img');
  editIcon.src = '/static/edit.png';
  editIcon.className = 'icon-edit';
  editIcon.onclick = () => toggleEditMode(tsDiv, key);
  tsDiv.appendChild(editIcon);


  const statusDiv = document.createElement('div');
  statusDiv.className = 'status';

  const actionDiv = document.createElement('div');
  actionDiv.className = 'actions';

  card.appendChild(removeBtn);
  card.appendChild(tsDiv);
  card.appendChild(statusDiv);
  card.appendChild(actionDiv);

  updateCardContent(card, key, event);
  return card;
}

// ─────────────────────────────────────────
// Validate existing card
// ─────────────────────────────────────────
function updateCard(key, event) {
  const card = document.getElementById(`card-${key}`);
  if (card) {
    updateCardContent(card, key, event);
  }
}

// ─────────────────────────────────────────
// Update existing card content
// ─────────────────────────────────────────
function updateCardContent(card, key, event) {
  const statusDiv = card.querySelector('.status');
  const actionDiv = card.querySelector('.actions');
  const safeKey = key.replace(/[:.]/g, "-");

  actionDiv.innerHTML = '';
  if (event.status === 'started') {
    statusDiv.textContent = "Received signal. Data logging started.";
    card.style.backgroundColor = '#780606';
  } else if (event.status === 'processed') {
    statusDiv.textContent = "Data logging finished. Start cleaning process.";
    card.style.backgroundColor = '#2e6930';
  } else if (event.status === 'done') {
    statusDiv.textContent = "Cleaned data saved. Insights are ready.";
    card.style.backgroundColor = '#8a00c2';

    const expandBtn = document.createElement('button');
    expandBtn.className = 'btn-expand';
    expandBtn.textContent = expandedItems[key] ? 'Collapse' : 'Expand';
    expandBtn.onclick = () => toggleExpand(key, expandBtn);

    const expandDiv = document.createElement('div');
    expandDiv.id = `expand-${key}`;
    expandDiv.className = 'expanded-content';
    if (expandedItems[key]) expandDiv.classList.add('show');

    expandDiv.innerHTML = `
      <img src="/plots/heatmap_${safeKey}.png" width="100%">
      <img src="/plots/trend_${safeKey}.png" width="100%">
    `;

    actionDiv.appendChild(expandBtn);
    actionDiv.appendChild(expandDiv);
  }
}

// ─────────────────────────────────────────
// Toggle card expansion
// ─────────────────────────────────────────
function toggleExpand(key, btn) {
  const el = document.getElementById(`expand-${key}`);
  const showing = el.classList.contains('show');
  if (showing) {
    el.classList.remove('show');
    expandedItems[key] = false;
    btn.textContent = 'Expand';
  } else {
    el.classList.add('show');
    expandedItems[key] = true;
    btn.textContent = 'Collapse';
  }
  localStorage.setItem("expandedItems", JSON.stringify(expandedItems));
}

// ─────────────────────────────────────────
// Toggle card edit-view mode
// ─────────────────────────────────────────
function toggleEditMode(container, key) {
  const icon = container.querySelector('.icon-edit');
  if (!container.classList.contains('editing')) {
    const span = container.querySelector('.label-text');
    if (!span) return;
    const input = document.createElement('input');
    input.type = 'text';
    input.value = span.textContent;
    input.className = 'label-input';
    span.replaceWith(input);
    icon.src = '/static/check.png';
    container.classList.add('editing');
  } else {
    const input = container.querySelector('.label-input');
    if (!input) return;
    const newLabel = input.value.trim() || formatTimestamp(key);
    renamedLabels[key] = newLabel;
    localStorage.setItem("renamedLabels", JSON.stringify(renamedLabels));
    const newSpan = document.createElement('span');
    newSpan.className = 'label-text';
    newSpan.textContent = newLabel;
    input.replaceWith(newSpan);
    icon.src = '/static/edit.png';
    container.classList.remove('editing');
  }
}

// ─────────────────────────────────────────
// Remove a card item
// ─────────────────────────────────────────
function removeItem(key) {
  const card = document.getElementById(`card-${key}`);
  if (card) card.remove();
  delete expandedItems[key];
  delete previousEvents[key];
  localStorage.setItem("expandedItems", JSON.stringify(expandedItems));
  fetch(`/events/remove/${key}`, { method: 'DELETE' });
}

// ─────────────────────────────────────────
// Format timestamp as hh:mm dd/mm/yyyy
// ─────────────────────────────────────────
function formatTimestamp(norm_ts) {
  try {
    const parts = norm_ts.split("T");
    if (parts.length !== 2) throw new Error("Invalid format");
    // Extract date and time parts
    const datePart = parts[0]; // e.g., "2025-05-21"
    const timeParts = parts[1].split("-"); // ["hh", "mm", "ss"]
    if (timeParts.length < 3) throw new Error("Incomplete time");
    // Reformat
    const [year, month, day] = datePart.split("-").map(Number);
    let [hour, minute, second] = timeParts.map(Number);
    hour = (hour - 2 + 24) % 24;
    // Create Date in local time (note: month is 0-based)
    const dt = new Date(year, month - 1, day, hour, minute, second);
    // Write string
    const timeStr = dt.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
    const dateStr = dt.toLocaleDateString('en-AU');
    return `${timeStr} ${dateStr}`;
  } catch (err) {
    console.warn("formatTimestamp fallback:", err.message);
    return norm_ts;
  }
}


// ─────────────────────────────────────────
// Sanitize filenames from timestamp
// ─────────────────────────────────────────
function sanitizeFilename(ts) {
  return ts.replace(/:/g, '-').replace(/ /g, 'T').replace(/\//g, '-');
}

// ─────────────────────────────────────────
fetchEvents();
setInterval(fetchEvents, 1000);
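For context, the dashboard polls `/events` once per second and keys each card by a normalised timestamp; a payload shaped roughly like the following would drive the status transitions above (field names inferred from the script, values invented):

```python
# Hypothetical /events response inferred from script.js; not the actual backend schema.
events = {
    "2025-05-21T10-15-00": {"status": "started"},    # red card
    "2025-05-21T09-40-12": {"status": "processed"},  # green card
    "2025-05-21T08-05-33": {"status": "done"},       # purple card with plots
}
```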
static/styles.css
ADDED
|
@@ -0,0 +1,135 @@
| 1 |
+
body {
|
| 2 |
+
font-family: 'Segoe UI', sans-serif;
|
| 3 |
+
background: linear-gradient(to bottom right, #eef1f7, #f9fafe);
|
| 4 |
+
margin: 0;
|
| 5 |
+
padding: 2rem;
|
| 6 |
+
color: #333;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
h1 {
|
| 10 |
+
text-align: center;
|
| 11 |
+
margin-bottom: 2rem;
|
| 12 |
+
font-size: 2rem;
|
| 13 |
+
color: #2c3e50;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
#log-container {
|
| 17 |
+
display: flex;
|
| 18 |
+
flex-direction: column;
|
| 19 |
+
gap: 1.5rem;
|
| 20 |
+
max-width: 960px;
|
| 21 |
+
margin: auto;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/* Card display */
|
| 25 |
+
.card {
|
| 26 |
+
border-radius: 10px;
|
| 27 |
+
padding: 1.2rem 1.5rem;
|
| 28 |
+
color: white;
|
| 29 |
+
position: relative;
|
| 30 |
+
box-shadow: 0 4px 10px rgba(0, 0, 0, 0.08);
|
| 31 |
+
transition: transform 0.3s ease, background-color 0.3s ease;
|
| 32 |
+
overflow: hidden;
|
| 33 |
+
}
|
| 34 |
+
.card:hover {
|
| 35 |
+
transform: translateY(-3px);
|
| 36 |
+
}
|
| 37 |
+
.status {
|
| 38 |
+
font-weight: 600;
|
| 39 |
+
font-size: 1.1rem;
|
| 40 |
+
}
|
| 41 |
+
.timestamp {
|
| 42 |
+
font-size: 0.95rem;
|
| 43 |
+
opacity: 0.9;
|
| 44 |
+
margin-top: 4px;
|
| 45 |
+
display: flex;
|
| 46 |
+
align-items: center;
|
| 47 |
+
gap: 8px;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
.icon-edit {
|
| 51 |
+
width: 18px;
|
| 52 |
+
height: 18px;
|
| 53 |
+
cursor: pointer;
|
| 54 |
+
margin-left: 4px;
|
| 55 |
+
}
|
| 56 |
+
.label-input {
|
| 57 |
+
font-size: 1rem;
|
| 58 |
+
padding: 2px 6px;
|
| 59 |
+
border-radius: 4px;
|
| 60 |
+
border: 1px solid #ccc;
|
| 61 |
+
width: 160px;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
/* All buttons */
|
| 66 |
+
.btn-expand,
|
| 67 |
+
.btn-remove {
|
| 68 |
+
margin-top: 1rem;
|
| 69 |
+
padding: 0.4rem 1.2rem;
|
| 70 |
+
cursor: pointer;
|
| 71 |
+
font-size: 0.9rem;
|
| 72 |
+
border: none;
|
| 73 |
+
border-radius: 4px;
|
| 74 |
+
transition: background-color 0.2s ease;
|
| 75 |
+
}
|
| 76 |
+
.btn-expand {
|
| 77 |
+
background-color: rgba(255, 255, 255, 0.25);
|
| 78 |
+
color: white;
|
| 79 |
+
}
|
| 80 |
+
.btn-expand:hover {
|
| 81 |
+
background-color: rgba(255, 255, 255, 0.4);
|
| 82 |
+
}
|
| 83 |
+
.btn-remove {
|
| 84 |
+
position: absolute;
|
| 85 |
+
top: 10px;
|
| 86 |
+
right: 14px;
|
| 87 |
+
background: rgba(255, 255, 255, 0.15);
|
| 88 |
+
color: white;
|
| 89 |
+
}
|
| 90 |
+
.btn-remove:hover {
|
| 91 |
+
background: rgba(255, 255, 255, 0.3);
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
/* Expanded content */
|
| 95 |
+
.expanded-content {
|
| 96 |
+
margin-top: 1.2rem;
|
| 97 |
+
animation: fadeIn 0.3s ease-in-out;
|
| 98 |
+
max-height: 0; /* collapsed by default; the expanded cap is set in .expanded-content.show */
|
| 99 |
+
overflow-y: auto; /* Allow vertical scroll */
|
| 100 |
+
transition: max-height 0.4s ease-in-out, opacity 0.3s ease;
|
| 101 |
+
opacity: 0;
|
| 102 |
+
padding-right: 5px; /* Optional: give room for scrollbar */
|
| 103 |
+
}
|
| 104 |
+
.expanded-content.show {
|
| 105 |
+
max-height: 1000px;
|
| 106 |
+
opacity: 1;
|
| 107 |
+
}
|
| 108 |
+
.expanded-content img {
|
| 109 |
+
margin-top: 1rem;
|
| 110 |
+
border-radius: 6px;
|
| 111 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/* Colors */
|
| 115 |
+
.card.red {
|
| 116 |
+
background-color: #e74c3c;
|
| 117 |
+
}
|
| 118 |
+
.card.green {
|
| 119 |
+
background-color: #27ae60;
|
| 120 |
+
}
|
| 121 |
+
.card.purple {
|
| 122 |
+
background-color: #8e44ad;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
/* Animation */
|
| 126 |
+
@keyframes fadeIn {
|
| 127 |
+
from {
|
| 128 |
+
opacity: 0;
|
| 129 |
+
transform: translateY(10px);
|
| 130 |
+
}
|
| 131 |
+
to {
|
| 132 |
+
opacity: 1;
|
| 133 |
+
transform: translateY(0);
|
| 134 |
+
}
|
| 135 |
+
}
|
train/README.md
ADDED
|
@@ -0,0 +1,140 @@
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: tabular-classification
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
# RLHF Training System
|
| 9 |
+
|
| 10 |
+
This directory contains the Reinforcement Learning from Human Feedback (RLHF) training pipeline for the driver behavior classification model.
|
| 11 |
+
|
| 12 |
+
## Overview
|
| 13 |
+
|
| 14 |
+
The RLHF system enables continuous improvement of the driver behavior model by:
|
| 15 |
+
1. Loading human-labeled data from Firebase storage (`skyledge/labeled`)
|
| 16 |
+
2. Combining it with existing model predictions for reinforcement learning
|
| 17 |
+
3. Retraining the XGBoost model with the enhanced dataset
|
| 18 |
+
4. Saving new model checkpoints to Hugging Face Hub
|
| 19 |
+
|
| 20 |
+
## Files
|
| 21 |
+
|
| 22 |
+
### `loader.py`
|
| 23 |
+
- **Purpose**: Load labeled data from Firebase storage
|
| 24 |
+
- **Key Features**:
|
| 25 |
+
- Lists available labeled datasets from `skyledge/labeled` path
|
| 26 |
+
- Tracks already processed datasets in `trained.txt`
|
| 27 |
+
- Downloads and loads datasets into pandas DataFrames
|
| 28 |
+
- Prevents retraining on the same data
|
| 29 |
+
|
| 30 |
+
### `saver.py`
|
| 31 |
+
- **Purpose**: Save trained models to Hugging Face Hub and local storage
|
| 32 |
+
- **Key Features**:
|
| 33 |
+
- Saves model components (XGBoost model, label encoder, scaler)
|
| 34 |
+
- Creates model metadata and README files
|
| 35 |
+
- Uploads to Hugging Face Hub with versioning
|
| 36 |
+
- Maintains local model directory structure
|
| 37 |
+
|
| 38 |
+
### `rlhf.py`
|
| 39 |
+
- **Purpose**: Main RLHF training pipeline
|
| 40 |
+
- **Key Features**:
|
| 41 |
+
- Loads new labeled datasets
|
| 42 |
+
- Creates RLHF dataset by combining labeled data with model predictions
|
| 43 |
+
- Trains XGBoost model with enhanced dataset
|
| 44 |
+
- Evaluates model performance
|
| 45 |
+
- Coordinates with loader and saver modules
|
| 46 |
+
|
| 47 |
+
## API Endpoints
|
| 48 |
+
|
| 49 |
+
The RLHF training system is integrated into the main FastAPI application with the following endpoints:
|
| 50 |
+
|
| 51 |
+
### `POST /rlhf/train`
|
| 52 |
+
Trigger RLHF training session.
|
| 53 |
+
|
| 54 |
+
**Request Body:**
|
| 55 |
+
```json
|
| 56 |
+
{
|
| 57 |
+
"max_datasets": 10,
|
| 58 |
+
"force_retrain": false
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
**Response:**
|
| 63 |
+
```json
|
| 64 |
+
{
|
| 65 |
+
"status": "success",
|
| 66 |
+
"model_version": "20241201_143022",
|
| 67 |
+
"datasets_processed": 5,
|
| 68 |
+
"samples_processed": 1250,
|
| 69 |
+
"performance_metrics": {
|
| 70 |
+
"accuracy": 0.892,
|
| 71 |
+
"cv_mean": 0.885,
|
| 72 |
+
"cv_std": 0.012
|
| 73 |
+
},
|
| 74 |
+
"timestamp": "2024-12-01T14:30:22"
|
| 75 |
+
}
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### `GET /rlhf/status`
|
| 79 |
+
Get status of RLHF training system and available labeled data.
|
| 80 |
+
|
| 81 |
+
### `GET /rlhf/trained-datasets`
|
| 82 |
+
Get list of datasets that have already been used for training.
|
| 83 |
+
|
| 84 |
+
## Configuration
|
| 85 |
+
|
| 86 |
+
### Environment Variables
|
| 87 |
+
- `HF_TOKEN`: Hugging Face authentication token
|
| 88 |
+
- `HF_MODEL_REPO`: Hugging Face model repository (default: `BinKhoaLe1812/Driver_Behavior_OBD`)
|
| 89 |
+
- `MODEL_DIR`: Local model directory (default: `/app/models/ul`)
|
| 90 |
+
- `FIREBASE_ADMIN_JSON`: Firebase Admin SDK credentials
|
| 91 |
+
- `FIREBASE_SERVICE_ACCOUNT_JSON`: Firebase service account credentials
|
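
The snippet below is a sketch of how the training modules resolve these variables; the defaults shown are the documented ones, and `HF_TOKEN` has no default:

```python
import os

# Required: model saving aborts if no Hugging Face token is available
hf_token = os.environ["HF_TOKEN"]

# Optional, with documented defaults
hf_repo = os.getenv("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
model_dir = os.getenv("MODEL_DIR", "/app/models/ul")

# Firebase credentials: the loader prefers the Admin SDK JSON and
# falls back to the GCS client when it is not set
firebase_admin_json = os.getenv("FIREBASE_ADMIN_JSON")
firebase_service_json = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON")
```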
| 92 |
+
|
| 93 |
+
### Firebase Storage Structure
|
| 94 |
+
```
|
| 95 |
+
skyledge-36b56.firebasestorage.app/
|
| 96 |
+
├── skyledge/
|
| 97 |
+
│ ├── processed/ # Original processed data
|
| 98 |
+
│ ├── labeled/ # Human-labeled data for RLHF
|
| 99 |
+
│ │ ├── dataset1.csv
|
| 100 |
+
│ │ ├── dataset2.csv
|
| 101 |
+
│ │ └── trained.txt # Tracks processed datasets
|
| 102 |
+
│ └── logs/ # Training logs (future)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Usage
|
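
A minimal sketch of exercising the endpoints above from Python (the base URL and port are assumptions; point them at your running deployment):

```python
import requests

BASE = "http://localhost:7860"  # assumed local FastAPI deployment

# Check the training system and the labeled data waiting to be consumed
print(requests.get(f"{BASE}/rlhf/status").json())

# Trigger an RLHF training session on up to 10 new datasets
resp = requests.post(
    f"{BASE}/rlhf/train",
    json={"max_datasets": 10, "force_retrain": False},
)
print(resp.json())

# List datasets that have already been used for training
print(requests.get(f"{BASE}/rlhf/trained-datasets").json())
```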
| 106 |
+
|
| 107 |
+
## Model Versioning
|
| 108 |
+
|
| 109 |
+
Each training session creates a new model version. Versions follow a simple semantic scheme (`1.0`, `1.1`, …, rolling over to `2.0` after `1.9`); if the existing versions cannot be listed from the Hub, the saver falls back to a timestamp of the form `YYYYMMDD_HHMMSS`.
|
| 110 |
+
|
| 111 |
+
Models are saved to:
|
| 112 |
+
- **Local**: `/app/models/ul/v{version}/`
|
| 113 |
+
- **Hugging Face**: `BinKhoaLe1812/Driver_Behavior_OBD`
|
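
A sketch of the increment rule applied when picking the next version (minor versions roll over to the next major after `.9`; `next_version` here is a condensed stand-in for `ModelSaver._get_next_version`):

```python
def next_version(existing: list[tuple[int, int]]) -> str:
    """Return the next model version given existing (major, minor) pairs."""
    if not existing:
        return "1.0"
    major, minor = max(existing)
    return f"{major}.{minor + 1}" if minor < 9 else f"{major + 1}.0"

assert next_version([]) == "1.0"
assert next_version([(1, 8)]) == "1.9"
assert next_version([(1, 9)]) == "2.0"
```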
| 114 |
+
|
| 115 |
+
## Data Flow
|
| 116 |
+
|
| 117 |
+
1. **Data Collection**: Human-labeled data stored in `skyledge/labeled/`
|
| 118 |
+
2. **Training Trigger**: API endpoint or manual trigger
|
| 119 |
+
3. **Data Loading**: Load new labeled datasets (skip already processed)
|
| 120 |
+
4. **RLHF Dataset**: Combine labeled data with model predictions
|
| 121 |
+
5. **Model Training**: Train XGBoost with enhanced dataset
|
| 122 |
+
6. **Evaluation**: Calculate performance metrics
|
| 123 |
+
7. **Model Saving**: Save to local storage and Hugging Face Hub
|
| 124 |
+
8. **Tracking**: Update `trained.txt` with processed datasets
|
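
For step 8, `trained.txt` is a plain-text ledger stored alongside the labeled data; each line records when a dataset was consumed. A sketch of the format written and checked by `loader.py` (the filenames are illustrative):

```python
from datetime import datetime

# One entry per trained dataset: "<ISO timestamp>:<dataset name>"
entry = f"{datetime.now().isoformat()}:labeled/001_raw-002_2025-09-19-labelled.csv"

# On the next run, a dataset is skipped if its name appears in any existing entry
existing_entries = ["2025-09-20T10:15:00:labeled/000_raw-001_2025-09-18-labelled.csv"]
dataset_name = "labeled/000_raw-001_2025-09-18-labelled.csv"
skip = any(dataset_name in line for line in existing_entries)
```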
| 125 |
+
|
| 126 |
+
## Performance Monitoring
|
| 127 |
+
|
| 128 |
+
The system tracks:
|
| 129 |
+
- Number of datasets processed
|
| 130 |
+
- Total samples processed
|
| 131 |
+
- Model accuracy and cross-validation scores
|
| 132 |
+
- Training timestamps and metadata
|
| 133 |
+
|
| 134 |
+
## Error Handling
|
| 135 |
+
|
| 136 |
+
- Graceful handling of missing datasets
|
| 137 |
+
- Firebase connection failures
|
| 138 |
+
- Model loading/saving errors
|
| 139 |
+
- XGBoost compatibility issues
|
| 140 |
+
- Comprehensive logging throughout the pipeline
|
train/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
# train package
|
| 2 |
+
# RLHF Training System for Driver Behavior Classification
|
| 3 |
+
|
| 4 |
+
from .rlhf import RLHFTrainer
|
| 5 |
+
from .loader import LabeledDataLoader
|
| 6 |
+
from .saver import ModelSaver
|
| 7 |
+
|
| 8 |
+
__all__ = ['RLHFTrainer', 'LabeledDataLoader', 'ModelSaver']
|
train/loader.py
ADDED
|
@@ -0,0 +1,370 @@
|
| 1 |
+
# loader.py
|
| 2 |
+
# Load labeled data from Firebase storage for RLHF training
|
| 3 |
+
import io
import os
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import List, Dict, Optional, Tuple, Any
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
# Import Firebase client from the existing firebase_saver
|
| 12 |
+
import sys
|
| 13 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 14 |
+
from data.firebase_saver import _AdminClient, _GCSClient
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("rlhf-loader")
|
| 17 |
+
logger.setLevel(logging.INFO)
|
| 18 |
+
if not logger.handlers:
|
| 19 |
+
_h = logging.StreamHandler()
|
| 20 |
+
_h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
|
| 21 |
+
logger.addHandler(_h)
|
| 22 |
+
|
| 23 |
+
# Firebase configuration
|
| 24 |
+
FIREBASE_BUCKET = "skyledge-36b56.firebasestorage.app"
|
| 25 |
+
LABELED_PREFIX = "skyledge/labeled"
|
| 26 |
+
RAW_PREFIX = "skyledge/raw"
|
| 27 |
+
PROCESSED_PREFIX = "skyledge/processed"
|
| 28 |
+
TRAINED_FILE = "trained.txt"
|
| 29 |
+
|
| 30 |
+
class LabeledDataLoader:
|
| 31 |
+
"""
|
| 32 |
+
Load labeled data from Firebase storage for RLHF training.
|
| 33 |
+
Tracks already processed datasets to avoid retraining on the same data.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.bucket_name = FIREBASE_BUCKET
|
| 38 |
+
self.prefix = LABELED_PREFIX
|
| 39 |
+
self.trained_file = TRAINED_FILE
|
| 40 |
+
|
| 41 |
+
# Initialize Firebase client
|
| 42 |
+
self.client = None
|
| 43 |
+
self.mode = None
|
| 44 |
+
try:
|
| 45 |
+
if os.getenv("FIREBASE_ADMIN_JSON"):
|
| 46 |
+
self.client = _AdminClient(self.bucket_name)
|
| 47 |
+
self.mode = "admin"
|
| 48 |
+
except Exception as e:
|
| 49 |
+
logger.warning(f"⚠️ Admin SDK init failed: {e}")
|
| 50 |
+
|
| 51 |
+
if self.client is None:
|
| 52 |
+
try:
|
| 53 |
+
self.client = _GCSClient(self.bucket_name)
|
| 54 |
+
self.mode = "gcs"
|
| 55 |
+
except Exception as e:
|
| 56 |
+
logger.error(f"❌ GCS client init failed: {e}")
|
| 57 |
+
raise
|
| 58 |
+
|
| 59 |
+
logger.info(f"📦 LabeledDataLoader ready | mode={self.mode} bucket={self.bucket_name} prefix={self.prefix}")
|
| 60 |
+
|
| 61 |
+
def _get_trained_datasets(self) -> List[str]:
|
| 62 |
+
"""Load list of already trained datasets from trained.txt"""
|
| 63 |
+
try:
|
| 64 |
+
# Check if trained.txt exists in Firebase storage
|
| 65 |
+
trained_path = f"{self.prefix}/{self.trained_file}"
|
| 66 |
+
if self.client.blob_exists(trained_path):
|
| 67 |
+
# Download and read the file
|
| 68 |
+
blob = self.client.bucket.blob(trained_path)
|
| 69 |
+
content = blob.download_as_text()
|
| 70 |
+
trained_datasets = [line.strip() for line in content.split('\n') if line.strip()]
|
| 71 |
+
logger.info(f"📋 Loaded {len(trained_datasets)} already trained datasets")
|
| 72 |
+
return trained_datasets
|
| 73 |
+
else:
|
| 74 |
+
logger.info("📋 No trained.txt found, starting fresh")
|
| 75 |
+
return []
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.warning(f"⚠️ Failed to load trained datasets: {e}")
|
| 78 |
+
return []
|
| 79 |
+
|
| 80 |
+
def _update_trained_datasets(self, new_datasets: List[str]):
|
| 81 |
+
"""Update trained.txt with new dataset names"""
|
| 82 |
+
try:
|
| 83 |
+
# Get existing trained datasets
|
| 84 |
+
existing = self._get_trained_datasets()
|
| 85 |
+
|
| 86 |
+
# Add new datasets with timestamp
|
| 87 |
+
timestamp = datetime.now().isoformat()
|
| 88 |
+
new_entries = [f"{timestamp}:{dataset}" for dataset in new_datasets]
|
| 89 |
+
all_entries = existing + new_entries
|
| 90 |
+
|
| 91 |
+
# Upload updated file
|
| 92 |
+
trained_path = f"{self.prefix}/{self.trained_file}"
|
| 93 |
+
content = '\n'.join(all_entries)
|
| 94 |
+
self.client.upload_from_bytes(
|
| 95 |
+
content.encode('utf-8'),
|
| 96 |
+
trained_path,
|
| 97 |
+
"text/plain"
|
| 98 |
+
)
|
| 99 |
+
logger.info(f"✅ Updated trained.txt with {len(new_datasets)} new datasets")
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"❌ Failed to update trained datasets: {e}")
|
| 102 |
+
|
| 103 |
+
def list_labeled_datasets(self) -> List[Dict[str, str]]:
|
| 104 |
+
"""List all available labeled datasets in Firebase storage"""
|
| 105 |
+
try:
|
| 106 |
+
# List all blobs under the labeled prefix
|
| 107 |
+
blobs = self.client.bucket.list_blobs(prefix=f"{self.prefix}/")
|
| 108 |
+
|
| 109 |
+
datasets = []
|
| 110 |
+
trained_datasets = self._get_trained_datasets()
|
| 111 |
+
|
| 112 |
+
for blob in blobs:
|
| 113 |
+
# Skip the trained.txt file itself
|
| 114 |
+
if blob.name.endswith(f"/{self.trained_file}"):
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
# Extract dataset name (relative to skyledge root)
|
| 118 |
+
dataset_name = blob.name.replace("skyledge/", "")
|
| 119 |
+
|
| 120 |
+
# Skip if already trained
|
| 121 |
+
if any(dataset_name in entry for entry in trained_datasets):
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# Get blob metadata
|
| 125 |
+
blob.reload()
|
| 126 |
+
datasets.append({
|
| 127 |
+
'name': dataset_name,
|
| 128 |
+
'path': blob.name,
|
| 129 |
+
'size': blob.size,
|
| 130 |
+
'created': blob.time_created.isoformat() if blob.time_created else None,
|
| 131 |
+
'updated': blob.updated.isoformat() if blob.updated else None,
|
| 132 |
+
'content_type': blob.content_type
|
| 133 |
+
})
|
| 134 |
+
|
| 135 |
+
logger.info(f"📊 Found {len(datasets)} new labeled datasets")
|
| 136 |
+
return datasets
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.error(f"❌ Failed to list labeled datasets: {e}")
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
def download_dataset(self, dataset_path: str, local_path: str) -> bool:
|
| 143 |
+
"""Download a dataset from Firebase storage to local path"""
|
| 144 |
+
try:
|
| 145 |
+
blob = self.client.bucket.blob(dataset_path)
|
| 146 |
+
blob.download_to_filename(local_path)
|
| 147 |
+
logger.info(f"✅ Downloaded {dataset_path} to {local_path}")
|
| 148 |
+
return True
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"❌ Failed to download {dataset_path}: {e}")
|
| 151 |
+
return False
|
| 152 |
+
|
| 153 |
+
def load_dataset(self, dataset_path: str) -> Optional[pd.DataFrame]:
|
| 154 |
+
"""Load a dataset directly into a pandas DataFrame"""
|
| 155 |
+
try:
|
| 156 |
+
blob = self.client.bucket.blob(dataset_path)
|
| 157 |
+
content = blob.download_as_text()
|
| 158 |
+
|
| 159 |
+
# Try to determine file type and load accordingly
|
| 160 |
+
if dataset_path.endswith('.csv'):
|
| 161 |
+
df = pd.read_csv(io.StringIO(content))
|
| 162 |
+
elif dataset_path.endswith('.json'):
|
| 163 |
+
df = pd.read_json(io.StringIO(content))
|
| 164 |
+
elif dataset_path.endswith('.parquet'):
|
| 165 |
+
# For parquet, we need to download as bytes
|
| 166 |
+
blob_bytes = blob.download_as_bytes()
|
| 167 |
+
df = pd.read_parquet(io.BytesIO(blob_bytes))
|
| 168 |
+
else:
|
| 169 |
+
# Default to CSV
|
| 170 |
+
df = pd.read_csv(io.StringIO(content))
|
| 171 |
+
|
| 172 |
+
logger.info(f"✅ Loaded dataset {dataset_path} with shape {df.shape}")
|
| 173 |
+
return df
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
logger.error(f"❌ Failed to load dataset {dataset_path}: {e}")
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
def get_new_datasets_for_training(self) -> List[Dict[str, str]]:
|
| 180 |
+
"""Get list of new datasets that haven't been used for training yet"""
|
| 181 |
+
return self.list_labeled_datasets()
|
| 182 |
+
|
| 183 |
+
def mark_datasets_as_trained(self, dataset_names: List[str]):
|
| 184 |
+
"""Mark datasets as trained to avoid retraining"""
|
| 185 |
+
self._update_trained_datasets(dataset_names)
|
| 186 |
+
|
| 187 |
+
def _parse_labeled_filename(self, filename: str) -> Dict[str, str]:
|
| 188 |
+
"""
|
| 189 |
+
Parse labeled filename to extract original dataset information.
|
| 190 |
+
Format: {id}_{source}-{original_id}_{date}-labelled.csv
|
| 191 |
+
Example: 001_raw-002_2025-09-19-labelled.csv
|
| 192 |
+
"""
|
| 193 |
+
try:
|
| 194 |
+
# Remove .csv extension
|
| 195 |
+
name = filename.replace('.csv', '')
|
| 196 |
+
|
| 197 |
+
# Split by underscore to get parts
|
| 198 |
+
parts = name.split('_')
|
| 199 |
+
if len(parts) < 3:
|
| 200 |
+
return {"error": f"Invalid filename format: {filename}"}
|
| 201 |
+
|
| 202 |
+
# Extract components
|
| 203 |
+
labeled_id = parts[0] # 001
|
| 204 |
+
source_and_original = parts[1] # raw-002 or processed-002
|
| 205 |
+
date = parts[2].replace('-labelled', '') # 2025-09-19
|
| 206 |
+
|
| 207 |
+
# Parse source and original ID
|
| 208 |
+
if '-' in source_and_original:
|
| 209 |
+
source, original_id = source_and_original.split('-', 1)
|
| 210 |
+
else:
|
| 211 |
+
source = source_and_original
|
| 212 |
+
original_id = "unknown"
|
| 213 |
+
|
| 214 |
+
return {
|
| 215 |
+
"labeled_id": labeled_id,
|
| 216 |
+
"source": source, # raw or processed
|
| 217 |
+
"original_id": original_id,
|
| 218 |
+
"date": date,
|
| 219 |
+
"original_filename": f"{original_id}_{date}-{source}.csv" if source != "unknown" else None
|
| 220 |
+
}
|
| 221 |
+
except Exception as e:
|
| 222 |
+
logger.warning(f"⚠️ Failed to parse filename {filename}: {e}")
|
| 223 |
+
return {"error": str(e)}
|
| 224 |
+
|
| 225 |
+
def _find_original_dataset(self, labeled_info: Dict[str, str]) -> Optional[str]:
|
| 226 |
+
"""Find the original dataset path based on labeled file info"""
|
| 227 |
+
if labeled_info.get("error") or not labeled_info.get("original_filename"):
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
source = labeled_info["source"]
|
| 231 |
+
original_filename = labeled_info["original_filename"]
|
| 232 |
+
|
| 233 |
+
if source == "raw":
|
| 234 |
+
return f"{self.RAW_PREFIX}/{original_filename}"
|
| 235 |
+
elif source == "processed":
|
| 236 |
+
return f"{self.PROCESSED_PREFIX}/{original_filename}"
|
| 237 |
+
else:
|
| 238 |
+
return None
|
| 239 |
+
|
| 240 |
+
def load_labeled_with_original(self, labeled_path: str) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Dict[str, str]]:
|
| 241 |
+
"""
|
| 242 |
+
Load labeled dataset along with its original dataset for RLHF comparison.
|
| 243 |
+
Returns: (labeled_df, original_df, metadata)
|
| 244 |
+
"""
|
| 245 |
+
try:
|
| 246 |
+
# Load labeled dataset
|
| 247 |
+
labeled_df = self.load_dataset(labeled_path)
|
| 248 |
+
if labeled_df is None:
|
| 249 |
+
return None, None, {"error": "Failed to load labeled dataset"}
|
| 250 |
+
|
| 251 |
+
# Parse filename to get original dataset info
|
| 252 |
+
filename = labeled_path.split('/')[-1]
|
| 253 |
+
labeled_info = self._parse_labeled_filename(filename)
|
| 254 |
+
|
| 255 |
+
if labeled_info.get("error"):
|
| 256 |
+
logger.warning(f"⚠️ Could not parse labeled filename: {labeled_info['error']}")
|
| 257 |
+
return labeled_df, None, labeled_info
|
| 258 |
+
|
| 259 |
+
# Find and load original dataset
|
| 260 |
+
original_path = self._find_original_dataset(labeled_info)
|
| 261 |
+
original_df = None
|
| 262 |
+
|
| 263 |
+
if original_path and self.client.blob_exists(original_path):
|
| 264 |
+
original_df = self.load_dataset(original_path)
|
| 265 |
+
if original_df is not None:
|
| 266 |
+
logger.info(f"✅ Loaded original dataset: {original_path}")
|
| 267 |
+
else:
|
| 268 |
+
logger.warning(f"⚠️ Failed to load original dataset: {original_path}")
|
| 269 |
+
else:
|
| 270 |
+
logger.warning(f"⚠️ Original dataset not found: {original_path}")
|
| 271 |
+
|
| 272 |
+
return labeled_df, original_df, labeled_info
|
| 273 |
+
|
| 274 |
+
except Exception as e:
|
| 275 |
+
logger.error(f"❌ Failed to load labeled with original: {e}")
|
| 276 |
+
return None, None, {"error": str(e)}
|
| 277 |
+
|
| 278 |
+
def create_training_batch(self, max_datasets: int = 10) -> Tuple[List[pd.DataFrame], List[str]]:
|
| 279 |
+
"""
|
| 280 |
+
Create a training batch by loading new datasets.
|
| 281 |
+
Returns tuple of (dataframes, dataset_names)
|
| 282 |
+
"""
|
| 283 |
+
datasets = self.get_new_datasets_for_training()
|
| 284 |
+
|
| 285 |
+
if not datasets:
|
| 286 |
+
logger.info("📭 No new datasets available for training")
|
| 287 |
+
return [], []
|
| 288 |
+
|
| 289 |
+
# Limit the number of datasets
|
| 290 |
+
datasets = datasets[:max_datasets]
|
| 291 |
+
|
| 292 |
+
dataframes = []
|
| 293 |
+
dataset_names = []
|
| 294 |
+
|
| 295 |
+
for dataset in datasets:
|
| 296 |
+
df = self.load_dataset(dataset['path'])
|
| 297 |
+
if df is not None:
|
| 298 |
+
dataframes.append(df)
|
| 299 |
+
dataset_names.append(dataset['name'])
|
| 300 |
+
else:
|
| 301 |
+
logger.warning(f"⚠️ Skipping dataset {dataset['name']} due to load failure")
|
| 302 |
+
|
| 303 |
+
if dataframes:
|
| 304 |
+
logger.info(f"📦 Created training batch with {len(dataframes)} datasets")
|
| 305 |
+
# Mark these datasets as trained
|
| 306 |
+
self.mark_datasets_as_trained(dataset_names)
|
| 307 |
+
|
| 308 |
+
return dataframes, dataset_names
|
| 309 |
+
|
| 310 |
+
def create_rlhf_training_batch(self, max_datasets: int = 10) -> Tuple[List[Dict[str, Any]], List[str]]:
|
| 311 |
+
"""
|
| 312 |
+
Create RLHF training batch with both labeled and original datasets.
|
| 313 |
+
Returns tuple of (training_data, dataset_names)
|
| 314 |
+
Each training_data item contains: {'labeled_df', 'original_df', 'metadata'}
|
| 315 |
+
"""
|
| 316 |
+
datasets = self.get_new_datasets_for_training()
|
| 317 |
+
|
| 318 |
+
if not datasets:
|
| 319 |
+
logger.info("📭 No new datasets available for RLHF training")
|
| 320 |
+
return [], []
|
| 321 |
+
|
| 322 |
+
# Limit the number of datasets
|
| 323 |
+
datasets = datasets[:max_datasets]
|
| 324 |
+
|
| 325 |
+
training_data = []
|
| 326 |
+
dataset_names = []
|
| 327 |
+
|
| 328 |
+
for dataset in datasets:
|
| 329 |
+
labeled_df, original_df, metadata = self.load_labeled_with_original(dataset['path'])
|
| 330 |
+
|
| 331 |
+
if labeled_df is not None:
|
| 332 |
+
training_item = {
|
| 333 |
+
'labeled_df': labeled_df,
|
| 334 |
+
'original_df': original_df,
|
| 335 |
+
'metadata': metadata,
|
| 336 |
+
'dataset_name': dataset['name']
|
| 337 |
+
}
|
| 338 |
+
training_data.append(training_item)
|
| 339 |
+
dataset_names.append(dataset['name'])
|
| 340 |
+
logger.info(f"✅ Loaded RLHF dataset: {dataset['name']} (original: {metadata.get('original_filename', 'N/A')})")
|
| 341 |
+
else:
|
| 342 |
+
logger.warning(f"⚠️ Skipping dataset {dataset['name']} due to load failure")
|
| 343 |
+
|
| 344 |
+
if training_data:
|
| 345 |
+
logger.info(f"📦 Created RLHF training batch with {len(training_data)} datasets")
|
| 346 |
+
# Mark these datasets as trained
|
| 347 |
+
self.mark_datasets_as_trained(dataset_names)
|
| 348 |
+
|
| 349 |
+
return training_data, dataset_names
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def main():
|
| 353 |
+
"""Test the loader functionality"""
|
| 354 |
+
loader = LabeledDataLoader()
|
| 355 |
+
|
| 356 |
+
# List available datasets
|
| 357 |
+
datasets = loader.list_labeled_datasets()
|
| 358 |
+
print(f"Available datasets: {len(datasets)}")
|
| 359 |
+
for dataset in datasets:
|
| 360 |
+
print(f" - {dataset['name']} ({dataset['size']} bytes)")
|
| 361 |
+
|
| 362 |
+
# Create a training batch
|
| 363 |
+
dataframes, names = loader.create_training_batch(max_datasets=5)
|
| 364 |
+
print(f"Training batch: {len(dataframes)} datasets")
|
| 365 |
+
for name in names:
|
| 366 |
+
print(f" - {name}")
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
if __name__ == "__main__":
|
| 370 |
+
main()
|
train/rlhf.py
ADDED
|
@@ -0,0 +1,420 @@
|
| 1 |
+
# rlhf.py
|
| 2 |
+
# Reinforcement Learning from Human Feedback training pipeline
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import pickle
|
| 7 |
+
import joblib
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
from sklearn.model_selection import train_test_split, cross_val_score
|
| 15 |
+
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
|
| 16 |
+
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
| 17 |
+
import xgboost as xgb
|
| 18 |
+
|
| 19 |
+
# Import our custom modules
|
| 20 |
+
from .loader import LabeledDataLoader
|
| 21 |
+
from .saver import ModelSaver
|
| 22 |
+
|
| 23 |
+
# Suppress warnings for cleaner output
|
| 24 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
|
| 25 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
|
| 26 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger("rlhf-trainer")
|
| 29 |
+
logger.setLevel(logging.INFO)
|
| 30 |
+
if not logger.handlers:
|
| 31 |
+
_h = logging.StreamHandler()
|
| 32 |
+
_h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
|
| 33 |
+
logger.addHandler(_h)
|
| 34 |
+
|
| 35 |
+
class RLHFTrainer:
|
| 36 |
+
"""
|
| 37 |
+
Reinforcement Learning from Human Feedback trainer for driver behavior classification.
|
| 38 |
+
|
| 39 |
+
This trainer:
|
| 40 |
+
1. Loads human-labeled data from Firebase storage
|
| 41 |
+
2. Combines it with existing model predictions for RLHF
|
| 42 |
+
3. Retrains the XGBoost model with the combined dataset
|
| 43 |
+
4. Evaluates performance and saves the new model
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self):
|
| 47 |
+
self.loader = LabeledDataLoader()
|
| 48 |
+
self.saver = ModelSaver()
|
| 49 |
+
|
| 50 |
+
# Model parameters
|
| 51 |
+
self.model_params = {
|
| 52 |
+
'n_estimators': 100,
|
| 53 |
+
'max_depth': 6,
|
| 54 |
+
'learning_rate': 0.1,
|
| 55 |
+
'subsample': 0.8,
|
| 56 |
+
'colsample_bytree': 0.8,
|
| 57 |
+
'random_state': 42,
|
| 58 |
+
'use_label_encoder': False,
|
| 59 |
+
'eval_metric': 'mlogloss'
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Feature columns to drop (non-predictive)
|
| 63 |
+
self.safe_drop = {
|
| 64 |
+
"timestamp", "driving_style", "ul_drivestyle", "gt_drivestyle",
|
| 65 |
+
"session_id", "imported_at", "record_index"
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
logger.info("🤖 RLHFTrainer initialized")
|
| 69 |
+
|
| 70 |
+
def _prepare_features(self, df: pd.DataFrame, expected_features: Optional[List[str]] = None) -> Tuple[np.ndarray, List[str]]:
|
| 71 |
+
"""Prepare features for training"""
|
| 72 |
+
# Select numeric columns and drop non-feature columns
|
| 73 |
+
feature_cols = [c for c in df.columns
|
| 74 |
+
if c not in self.safe_drop and pd.api.types.is_numeric_dtype(df[c])]
|
| 75 |
+
|
| 76 |
+
X = df[feature_cols].copy()
|
| 77 |
+
|
| 78 |
+
# Ensure required features are present
|
| 79 |
+
if expected_features:
|
| 80 |
+
for col in expected_features:
|
| 81 |
+
if col not in X.columns:
|
| 82 |
+
X[col] = 0.0
|
| 83 |
+
X = X[expected_features] # Align order
|
| 84 |
+
|
| 85 |
+
# Handle missing values
|
| 86 |
+
X = X.fillna(0)
|
| 87 |
+
|
| 88 |
+
return X.values, feature_cols
|
| 89 |
+
|
| 90 |
+
def _prepare_labels(self, df: pd.DataFrame, label_column: str = "driving_style") -> np.ndarray:
|
| 91 |
+
"""Prepare labels for training"""
|
| 92 |
+
if label_column not in df.columns:
|
| 93 |
+
raise ValueError(f"Label column '{label_column}' not found in data")
|
| 94 |
+
|
| 95 |
+
return df[label_column].values
|
| 96 |
+
|
| 97 |
+
def _load_existing_model(self) -> Tuple[Any, Any, Any, List[str]]:
|
| 98 |
+
"""Load existing model components, downloading latest version if needed"""
|
| 99 |
+
try:
|
| 100 |
+
# First, try to download the latest model
|
| 101 |
+
logger.info("🔄 Checking for latest model version...")
|
| 102 |
+
try:
|
| 103 |
+
from utils.download import download_latest_models
|
| 104 |
+
download_latest_models()
|
| 105 |
+
except Exception as e:
|
| 106 |
+
logger.warning(f"⚠️ Failed to download latest models: {e}")
|
| 107 |
+
|
| 108 |
+
model_dir = os.getenv("MODEL_DIR", "/app/models/ul")
|
| 109 |
+
|
| 110 |
+
model_path = os.path.join(model_dir, "xgb_drivestyle_ul.pkl")
|
| 111 |
+
le_path = os.path.join(model_dir, "label_encoder_ul.pkl")
|
| 112 |
+
scaler_path = os.path.join(model_dir, "scaler_ul.pkl")
|
| 113 |
+
|
| 114 |
+
# Load with compatibility fixes
|
| 115 |
+
model = self._load_model_with_compatibility(model_path)
|
| 116 |
+
label_encoder = joblib.load(le_path)
|
| 117 |
+
scaler = joblib.load(scaler_path)
|
| 118 |
+
|
| 119 |
+
# Get expected features
|
| 120 |
+
expected_features = None
|
| 121 |
+
if hasattr(scaler, "feature_names_in_"):
|
| 122 |
+
expected_features = list(scaler.feature_names_in_)
|
| 123 |
+
elif hasattr(model, "feature_names_in_"):
|
| 124 |
+
expected_features = list(model.feature_names_in_)
|
| 125 |
+
|
| 126 |
+
logger.info(f"✅ Loaded existing model with {len(expected_features) if expected_features else 'unknown'} features")
|
| 127 |
+
return model, label_encoder, scaler, expected_features
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.warning(f"⚠️ Failed to load existing model: {e}")
|
| 131 |
+
return None, None, None, None
|
| 132 |
+
|
| 133 |
+
def _load_model_with_compatibility(self, model_path: str) -> Any:
|
| 134 |
+
"""Load model with XGBoost compatibility fixes"""
|
| 135 |
+
try:
|
| 136 |
+
model = joblib.load(model_path)
|
| 137 |
+
|
| 138 |
+
# Fix XGBoost compatibility issues
|
| 139 |
+
if hasattr(model, 'get_booster'): # This is an XGBoost model
|
| 140 |
+
# Remove deprecated attributes
|
| 141 |
+
deprecated_attrs = [
|
| 142 |
+
'use_label_encoder', '_le', '_label_encoder',
|
| 143 |
+
'use_label_encoder_', '_le_', '_label_encoder_'
|
| 144 |
+
]
|
| 145 |
+
for attr in deprecated_attrs:
|
| 146 |
+
if hasattr(model, attr):
|
| 147 |
+
try:
|
| 148 |
+
delattr(model, attr)
|
| 149 |
+
except (AttributeError, TypeError):
|
| 150 |
+
pass
|
| 151 |
+
|
| 152 |
+
# Set use_label_encoder to False
|
| 153 |
+
if hasattr(model, 'set_params'):
|
| 154 |
+
try:
|
| 155 |
+
model.set_params(use_label_encoder=False)
|
| 156 |
+
except Exception:
|
| 157 |
+
pass
|
| 158 |
+
|
| 159 |
+
return model
|
| 160 |
+
|
| 161 |
+
except Exception as e:
|
| 162 |
+
logger.error(f"❌ Failed to load model: {e}")
|
| 163 |
+
raise
|
| 164 |
+
|
| 165 |
+
def _create_rlhf_dataset(self, training_data: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
|
| 166 |
+
"""Create RLHF dataset by combining labeled data with original data and model predictions"""
|
| 167 |
+
try:
|
| 168 |
+
# Load existing model for generating predictions
|
| 169 |
+
existing_model, label_encoder, scaler, expected_features = self._load_existing_model()
|
| 170 |
+
|
| 171 |
+
if existing_model is None:
|
| 172 |
+
logger.warning("⚠️ No existing model found, using only labeled data")
|
| 173 |
+
return self._prepare_rlhf_from_labeled_only(training_data)
|
| 174 |
+
|
| 175 |
+
# Combine all labeled datasets
|
| 176 |
+
labeled_dfs = [item['labeled_df'] for item in training_data if item['labeled_df'] is not None]
|
| 177 |
+
original_dfs = [item['original_df'] for item in training_data if item['original_df'] is not None]
|
| 178 |
+
|
| 179 |
+
combined_labeled_df = pd.concat(labeled_dfs, ignore_index=True)
|
| 180 |
+
|
| 181 |
+
# Prepare features and labels from labeled data
|
| 182 |
+
X_labeled, feature_cols = self._prepare_features(combined_labeled_df, expected_features)
|
| 183 |
+
y_labeled = self._prepare_labels(combined_labeled_df)
|
| 184 |
+
|
| 185 |
+
# Scale features
|
| 186 |
+
X_labeled_scaled = scaler.transform(X_labeled)
|
| 187 |
+
|
| 188 |
+
# Generate model predictions on original data for comparison
|
| 189 |
+
model_predictions = []
|
| 190 |
+
prediction_confidence = []
|
| 191 |
+
|
| 192 |
+
if original_dfs:
|
| 193 |
+
combined_original_df = pd.concat(original_dfs, ignore_index=True)
|
| 194 |
+
X_original, _ = self._prepare_features(combined_original_df, expected_features)
|
| 195 |
+
X_original_scaled = scaler.transform(X_original)
|
| 196 |
+
|
| 197 |
+
# Get model predictions on original data
|
| 198 |
+
original_predictions = existing_model.predict(X_original_scaled)
|
| 199 |
+
model_predictions.extend(original_predictions)
|
| 200 |
+
|
| 201 |
+
# Get prediction probabilities for confidence
|
| 202 |
+
if hasattr(existing_model, 'predict_proba'):
|
| 203 |
+
proba = existing_model.predict_proba(X_original_scaled)
|
| 204 |
+
confidence = np.max(proba, axis=1)
|
| 205 |
+
prediction_confidence.extend(confidence)
|
| 206 |
+
|
| 207 |
+
# Create RLHF dataset with preference learning
|
| 208 |
+
# The labeled data represents the "correct" behavior (human preference)
|
| 209 |
+
# The model predictions on original data represent what the model thought was correct
|
| 210 |
+
|
| 211 |
+
# For RLHF, we want to learn from the difference between model predictions and human labels
|
| 212 |
+
rlhf_metadata = {
|
| 213 |
+
"labeled_samples": len(X_labeled),
|
| 214 |
+
"original_samples": len(model_predictions) if model_predictions else 0,
|
| 215 |
+
"model_confidence": np.mean(prediction_confidence) if prediction_confidence else 0.0,
|
| 216 |
+
"datasets_processed": len(training_data)
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
logger.info(f"📊 Created RLHF dataset: {len(X_labeled)} labeled samples, {len(model_predictions)} original samples")
|
| 220 |
+
logger.info(f"📊 Model confidence on original data: {rlhf_metadata['model_confidence']:.3f}")
|
| 221 |
+
|
| 222 |
+
return X_labeled_scaled, y_labeled, rlhf_metadata
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
logger.error(f"❌ Failed to create RLHF dataset: {e}")
|
| 226 |
+
raise
|
| 227 |
+
|
| 228 |
+
def _prepare_rlhf_from_labeled_only(self, training_data: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
|
| 229 |
+
"""Prepare RLHF dataset from labeled data only (when no existing model)"""
|
| 230 |
+
labeled_dfs = [item['labeled_df'] for item in training_data if item['labeled_df'] is not None]
|
| 231 |
+
combined_df = pd.concat(labeled_dfs, ignore_index=True)
|
| 232 |
+
|
| 233 |
+
# Prepare features
|
| 234 |
+
X, feature_cols = self._prepare_features(combined_df)
|
| 235 |
+
y = self._prepare_labels(combined_df)
|
| 236 |
+
|
| 237 |
+
# Create and fit scaler
|
| 238 |
+
scaler = StandardScaler()
|
| 239 |
+
X_scaled = scaler.fit_transform(X)
|
| 240 |
+
|
| 241 |
+
rlhf_metadata = {
|
| 242 |
+
"labeled_samples": len(X),
|
| 243 |
+
"original_samples": 0,
|
| 244 |
+
"model_confidence": 0.0,
|
| 245 |
+
"datasets_processed": len(training_data)
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
return X_scaled, y, rlhf_metadata
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _train_model(self, X: np.ndarray, y: np.ndarray,
|
| 252 |
+
existing_model: Optional[Any] = None) -> Tuple[Any, Any, Any]:
|
| 253 |
+
"""Train the XGBoost model"""
|
| 254 |
+
try:
|
| 255 |
+
# Create label encoder
|
| 256 |
+
label_encoder = LabelEncoder()
|
| 257 |
+
y_encoded = label_encoder.fit_transform(y)
|
| 258 |
+
|
| 259 |
+
# Create scaler
|
| 260 |
+
scaler = StandardScaler()
|
| 261 |
+
X_scaled = scaler.fit_transform(X)
|
| 262 |
+
|
| 263 |
+
# Split data
|
| 264 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 265 |
+
X_scaled, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Create and train model
|
| 269 |
+
model = xgb.XGBClassifier(**self.model_params)
|
| 270 |
+
|
| 271 |
+
# If we have an existing model, we can use it for warm start or transfer learning
|
| 272 |
+
if existing_model is not None:
|
| 273 |
+
logger.info("🔄 Using existing model for warm start")
|
| 274 |
+
# For XGBoost, we can't directly warm start, but we can use similar parameters
|
| 275 |
+
# and potentially use the existing model's predictions as additional features
|
| 276 |
+
|
| 277 |
+
# Train the model
|
| 278 |
+
model.fit(X_train, y_train,
|
| 279 |
+
eval_set=[(X_test, y_test)],
|
| 280 |
+
early_stopping_rounds=10,
|
| 281 |
+
verbose=False)
|
| 282 |
+
|
| 283 |
+
# Evaluate
|
| 284 |
+
y_pred = model.predict(X_test)
|
| 285 |
+
accuracy = accuracy_score(y_test, y_pred)
|
| 286 |
+
|
| 287 |
+
logger.info(f"✅ Model trained with accuracy: {accuracy:.4f}")
|
| 288 |
+
|
| 289 |
+
return model, label_encoder, scaler
|
| 290 |
+
|
| 291 |
+
except Exception as e:
|
| 292 |
+
logger.error(f"❌ Model training failed: {e}")
|
| 293 |
+
raise
|
| 294 |
+
|
| 295 |
+
def _evaluate_model(self, model: Any, label_encoder: Any, scaler: Any,
|
| 296 |
+
X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
|
| 297 |
+
"""Evaluate model performance"""
|
| 298 |
+
try:
|
| 299 |
+
# Prepare test data
|
| 300 |
+
X_scaled = scaler.transform(X)
|
| 301 |
+
y_encoded = label_encoder.transform(y)
|
| 302 |
+
|
| 303 |
+
# Make predictions
|
| 304 |
+
y_pred = model.predict(X_scaled)
|
| 305 |
+
|
| 306 |
+
# Calculate metrics
|
| 307 |
+
accuracy = accuracy_score(y_encoded, y_pred)
|
| 308 |
+
|
| 309 |
+
# Cross-validation score
|
| 310 |
+
cv_scores = cross_val_score(model, X_scaled, y_encoded, cv=5)
|
| 311 |
+
cv_mean = cv_scores.mean()
|
| 312 |
+
cv_std = cv_scores.std()
|
| 313 |
+
|
| 314 |
+
metrics = {
|
| 315 |
+
"accuracy": accuracy,
|
| 316 |
+
"cv_mean": cv_mean,
|
| 317 |
+
"cv_std": cv_std,
|
| 318 |
+
"cv_scores": cv_scores.tolist()
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
logger.info(f"📊 Model evaluation: accuracy={accuracy:.4f}, cv_mean={cv_mean:.4f}±{cv_std:.4f}")
|
| 322 |
+
return metrics
|
| 323 |
+
|
| 324 |
+
except Exception as e:
|
| 325 |
+
logger.error(f"❌ Model evaluation failed: {e}")
|
| 326 |
+
return {"accuracy": 0.0, "cv_mean": 0.0, "cv_std": 0.0}
|
| 327 |
+
|
| 328 |
+
def train(self, max_datasets: int = 10) -> Dict[str, Any]:
|
| 329 |
+
"""Main training pipeline"""
|
| 330 |
+
try:
|
| 331 |
+
logger.info("🚀 Starting RLHF training pipeline")
|
| 332 |
+
|
| 333 |
+
# Load new labeled datasets with original data for RLHF
|
| 334 |
+
training_data, dataset_names = self.loader.create_rlhf_training_batch(max_datasets=max_datasets)
|
| 335 |
+
|
| 336 |
+
if not training_data:
|
| 337 |
+
logger.warning("⚠️ No new datasets available for RLHF training")
|
| 338 |
+
return {"status": "no_data", "message": "No new datasets available"}
|
| 339 |
+
|
| 340 |
+
logger.info(f"📦 Loaded {len(training_data)} datasets for RLHF training")
|
| 341 |
+
|
| 342 |
+
# Create RLHF dataset
|
| 343 |
+
X, y, rlhf_metadata = self._create_rlhf_dataset(training_data)
|
| 344 |
+
|
| 345 |
+
# Load existing model for comparison
|
| 346 |
+
existing_model, existing_le, existing_scaler, expected_features = self._load_existing_model()
|
| 347 |
+
|
| 348 |
+
# Train new model
|
| 349 |
+
model, label_encoder, scaler = self._train_model(X, y, existing_model)
|
| 350 |
+
|
| 351 |
+
# Evaluate model
|
| 352 |
+
metrics = self._evaluate_model(model, label_encoder, scaler, X, y)
|
| 353 |
+
|
| 354 |
+
# Generate model version using semantic versioning
|
| 355 |
+
model_version = self.saver._get_next_version()
|
| 356 |
+
|
| 357 |
+
# Prepare training data info
|
| 358 |
+
training_data_info = {
|
| 359 |
+
"datasets": dataset_names,
|
| 360 |
+
"total_samples": len(X),
|
| 361 |
+
"training_date": datetime.now().isoformat(),
|
| 362 |
+
"features_count": X.shape[1]
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
# Prepare training log
|
| 366 |
+
training_log = {
|
| 367 |
+
"datasets_used": dataset_names,
|
| 368 |
+
"samples_processed": len(X),
|
| 369 |
+
"model_parameters": self.model_params,
|
| 370 |
+
"performance_metrics": metrics,
|
| 371 |
+
"training_duration": "N/A", # Could be tracked if needed
|
| 372 |
+
"existing_model_used": existing_model is not None
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
# Save model
|
| 376 |
+
save_result = self.saver.save_complete_model(
|
| 377 |
+
model=model,
|
| 378 |
+
label_encoder=label_encoder,
|
| 379 |
+
scaler=scaler,
|
| 380 |
+
model_version=model_version,
|
| 381 |
+
training_data_info=training_data_info,
|
| 382 |
+
performance_metrics=metrics,
|
| 383 |
+
training_log=training_log,
|
| 384 |
+
rlhf_metadata=rlhf_metadata
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
result = {
|
| 388 |
+
"status": "success",
|
| 389 |
+
"model_version": model_version,
|
| 390 |
+
"datasets_processed": len(dataset_names),
|
| 391 |
+
"samples_processed": len(X),
|
| 392 |
+
"performance_metrics": metrics,
|
| 393 |
+
"save_result": save_result,
|
| 394 |
+
"training_log": training_log
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
logger.info(f"✅ RLHF training completed successfully: v{model_version}")
|
| 398 |
+
return result
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
logger.error(f"❌ RLHF training failed: {e}")
|
| 402 |
+
return {
|
| 403 |
+
"status": "error",
|
| 404 |
+
"error": str(e),
|
| 405 |
+
"timestamp": datetime.now().isoformat()
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def main():
|
| 410 |
+
"""Test the RLHF trainer"""
|
| 411 |
+
try:
|
| 412 |
+
trainer = RLHFTrainer()
|
| 413 |
+
result = trainer.train(max_datasets=5)
|
| 414 |
+
print(f"Training result: {result}")
|
| 415 |
+
except Exception as e:
|
| 416 |
+
print(f"Training failed: {e}")
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
if __name__ == "__main__":
|
| 420 |
+
main()
|
train/saver.py
ADDED
|
@@ -0,0 +1,381 @@
|
| 1 |
+
# saver.py
|
| 2 |
+
# Model saving functions for RLHF training
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import pickle
|
| 7 |
+
import joblib
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import HfApi, Repository
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("rlhf-saver")
|
| 17 |
+
logger.setLevel(logging.INFO)
|
| 18 |
+
if not logger.handlers:
|
| 19 |
+
_h = logging.StreamHandler()
|
| 20 |
+
_h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
|
| 21 |
+
logger.addHandler(_h)
|
| 22 |
+
|
| 23 |
+
class ModelSaver:
|
| 24 |
+
"""
|
| 25 |
+
Save trained models to Hugging Face Hub and local storage.
|
| 26 |
+
Handles model artifacts, metadata, and versioning.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self):
|
| 30 |
+
self.hf_token = os.getenv("HF_TOKEN")
|
| 31 |
+
if not self.hf_token:
|
| 32 |
+
raise RuntimeError("HF_TOKEN environment variable not set")
|
| 33 |
+
|
| 34 |
+
self.hf_api = HfApi(token=self.hf_token)
|
| 35 |
+
self.repo_id = os.getenv("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
|
| 36 |
+
|
| 37 |
+
# Local model directory
|
| 38 |
+
self.local_model_dir = Path(os.getenv("MODEL_DIR", "/app/models/ul"))
|
| 39 |
+
self.local_model_dir.mkdir(parents=True, exist_ok=True)
|
| 40 |
+
|
| 41 |
+
logger.info(f"📦 ModelSaver ready | repo={self.repo_id}")
|
| 42 |
+
|
| 43 |
+
def _get_next_version(self) -> str:
|
| 44 |
+
"""Get the next version number (1.0, 1.1, 1.2, ..., 1.9, 2.0, etc.)"""
|
| 45 |
+
try:
|
| 46 |
+
# List existing versions in HF repo
|
| 47 |
+
repo_files = self.hf_api.list_repo_files(
|
| 48 |
+
repo_id=self.repo_id,
|
| 49 |
+
repo_type="model"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Find version directories (v1.0, v1.1, etc.)
|
| 53 |
+
version_dirs = [f for f in repo_files if f.startswith('v') and '/' not in f]
|
| 54 |
+
versions = []
|
| 55 |
+
|
| 56 |
+
for v_dir in version_dirs:
|
| 57 |
+
try:
|
| 58 |
+
version_str = v_dir[1:] # Remove 'v' prefix
|
| 59 |
+
if '.' in version_str:
|
| 60 |
+
major, minor = version_str.split('.')
|
| 61 |
+
versions.append((int(major), int(minor)))
|
| 62 |
+
except (ValueError, IndexError):
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
if not versions:
|
| 66 |
+
return "1.0"
|
| 67 |
+
|
| 68 |
+
# Sort versions and get the latest
|
| 69 |
+
versions.sort()
|
| 70 |
+
latest_major, latest_minor = versions[-1]
|
| 71 |
+
|
| 72 |
+
# Increment version
|
| 73 |
+
if latest_minor < 9:
|
| 74 |
+
return f"{latest_major}.{latest_minor + 1}"
|
| 75 |
+
else:
|
| 76 |
+
return f"{latest_major + 1}.0"
|
| 77 |
+
|
| 78 |
+
except Exception as e:
|
| 79 |
+
logger.warning(f"⚠️ Failed to get next version from HF repo: {e}")
|
| 80 |
+
# Fallback to timestamp-based version
|
| 81 |
+
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 82 |
+
|
| 83 |
+
def _create_model_metadata(self,
|
| 84 |
+
model_type: str,
|
| 85 |
+
training_data_info: Dict[str, Any],
|
| 86 |
+
performance_metrics: Dict[str, float],
|
| 87 |
+
                                   model_version: str,
                                   rlhf_metadata: Dict[str, Any] = None) -> Dict[str, Any]:
        """Create metadata for the trained model"""
        metadata = {
            "model_type": model_type,
            "version": model_version,
            "created_at": datetime.now().isoformat(),
            "training_data": training_data_info,
            "performance_metrics": performance_metrics,
            "framework": "xgboost",
            "task": "driver_behavior_classification",
            "labels": ["aggressive", "normal", "conservative"],  # Based on ul_label.py
            "features": "obd_sensor_data",
            "rlhf_metadata": rlhf_metadata or {}
        }
        return metadata

    def save_model_locally(self,
                           model: Any,
                           label_encoder: Any,
                           scaler: Any,
                           model_version: str,
                           metadata: Dict[str, Any]) -> Dict[str, str]:
        """Save model components locally with versioning"""
        try:
            # Create versioned directory
            version_dir = self.local_model_dir / f"v{model_version}"
            version_dir.mkdir(exist_ok=True)

            # Save model components
            model_path = version_dir / "xgb_drivestyle_ul.pkl"
            le_path = version_dir / "label_encoder_ul.pkl"
            scaler_path = version_dir / "scaler_ul.pkl"
            metadata_path = version_dir / "metadata.json"

            # Save using joblib for better compatibility
            joblib.dump(model, model_path)
            joblib.dump(label_encoder, le_path)
            joblib.dump(scaler, scaler_path)

            # Save metadata
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            # Also save to the main model directory (for current usage)
            joblib.dump(model, self.local_model_dir / "xgb_drivestyle_ul.pkl")
            joblib.dump(label_encoder, self.local_model_dir / "label_encoder_ul.pkl")
            joblib.dump(scaler, self.local_model_dir / "scaler_ul.pkl")

            logger.info(f"✅ Model saved locally to {version_dir}")

            return {
                "model_path": str(model_path),
                "label_encoder_path": str(le_path),
                "scaler_path": str(scaler_path),
                "metadata_path": str(metadata_path)
            }

        except Exception as e:
            logger.error(f"❌ Failed to save model locally: {e}")
            raise

    def save_model_to_hf(self,
                         model: Any,
                         label_encoder: Any,
                         scaler: Any,
                         model_version: str,
                         metadata: Dict[str, Any],
                         training_data_info: Dict[str, Any]) -> str:
        """Save model to Hugging Face Hub"""
        try:
            # Create temporary directory for upload
            temp_dir = Path(f"/tmp/hf_upload_{model_version}")
            temp_dir.mkdir(exist_ok=True)

            # Save model components
            model_path = temp_dir / "xgb_drivestyle_ul.pkl"
            le_path = temp_dir / "label_encoder_ul.pkl"
            scaler_path = temp_dir / "scaler_ul.pkl"
            metadata_path = temp_dir / "metadata.json"
            readme_path = temp_dir / "README.md"

            # Save using joblib
            joblib.dump(model, model_path)
            joblib.dump(label_encoder, le_path)
            joblib.dump(scaler, scaler_path)

            # Save metadata
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            # Create README
            readme_content = self._create_readme(metadata, training_data_info)
            with open(readme_path, 'w') as f:
                f.write(readme_content)

            # Upload to Hugging Face Hub
            self.hf_api.upload_folder(
                folder_path=str(temp_dir),
                repo_id=self.repo_id,
                repo_type="model",
                commit_message=f"RLHF training update v{model_version}",
                ignore_patterns=["*.tmp", "*.log"]
            )

            # Clean up temp directory
            import shutil
            shutil.rmtree(temp_dir)

            logger.info(f"✅ Model uploaded to Hugging Face Hub: {self.repo_id}")
            return f"https://huggingface.co/{self.repo_id}"

        except Exception as e:
            logger.error(f"❌ Failed to save model to HF: {e}")
            raise

    def _create_readme(self, metadata: Dict[str, Any], training_data_info: Dict[str, Any]) -> str:
        """Create README content for the model"""
        readme = f"""---
license: mit
tags:
- driver-behavior
- obd-data
- xgboost
- rlhf
- reinforcement-learning
---

# Driver Behavior Classification Model (RLHF v{metadata['version']})

This model classifies driver behavior based on OBD (On-Board Diagnostics) sensor data using XGBoost.

## Model Information

- **Model Type**: {metadata['model_type']}
- **Version**: {metadata['version']}
- **Created**: {metadata['created_at']}
- **Framework**: {metadata['framework']}
- **Task**: {metadata['task']}

## Performance Metrics

"""

        for metric, value in metadata['performance_metrics'].items():
            readme += f"- **{metric}**: {value:.4f}\n"

        readme += f"""
## Training Data

- **Datasets Used**: {len(training_data_info.get('datasets', []))}
- **Total Samples**: {training_data_info.get('total_samples', 'N/A')}
- **Training Date**: {training_data_info.get('training_date', 'N/A')}

## Labels

The model predicts one of the following driver behavior categories:
"""

        for label in metadata['labels']:
            readme += f"- {label}\n"

        readme += """
## Usage

```python
import joblib
import pandas as pd

# Load the model
model = joblib.load('xgb_drivestyle_ul.pkl')
label_encoder = joblib.load('label_encoder_ul.pkl')
scaler = joblib.load('scaler_ul.pkl')

# Prepare your OBD data
# (Ensure features match the training data format)

# Make predictions
predictions = model.predict(scaled_data)
behavior_labels = label_encoder.inverse_transform(predictions)
```

## Files

- `xgb_drivestyle_ul.pkl`: Main XGBoost model
- `label_encoder_ul.pkl`: Label encoder for behavior categories
- `scaler_ul.pkl`: Feature scaler
- `metadata.json`: Model metadata and performance metrics

## RLHF Training

This model was trained using Reinforcement Learning from Human Feedback (RLHF) to improve performance based on human-labeled data and feedback.
"""

        return readme

    def save_training_log(self,
                          training_log: Dict[str, Any],
                          model_version: str) -> str:
        """Save training log to Firebase storage"""
        try:
            # Import Firebase client
            import sys
            sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
            from data.firebase_saver import FirebaseSaver

            # Create log entry
            log_entry = {
                "version": model_version,
                "timestamp": datetime.now().isoformat(),
                "log": training_log
            }

            # Save to Firebase
            saver = FirebaseSaver()
            # Note: We'll need to modify FirebaseSaver to support different prefixes
            # For now, we'll save to a logs subdirectory
            log_filename = f"training_log_v{model_version}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

            # Create temporary file
            temp_path = f"/tmp/{log_filename}"
            with open(temp_path, 'w') as f:
                json.dump(log_entry, f, indent=2)

            # Upload to Firebase (we'll need to extend FirebaseSaver for this)
            # For now, just log locally
            logger.info(f"📝 Training log saved: {log_entry}")

            return temp_path

        except Exception as e:
            logger.error(f"❌ Failed to save training log: {e}")
            return ""

    def save_complete_model(self,
                            model: Any,
                            label_encoder: Any,
                            scaler: Any,
                            model_version: str,
                            training_data_info: Dict[str, Any],
                            performance_metrics: Dict[str, float],
                            training_log: Dict[str, Any],
                            rlhf_metadata: Dict[str, Any] = None) -> Dict[str, str]:
        """Complete model saving workflow"""
        try:
            # Create metadata
            metadata = self._create_model_metadata(
                model_type="xgboost_classifier",
                training_data_info=training_data_info,
                performance_metrics=performance_metrics,
                model_version=model_version,
                rlhf_metadata=rlhf_metadata
            )

            # Save locally
            local_paths = self.save_model_locally(
                model, label_encoder, scaler, model_version, metadata
            )

            # Save to Hugging Face Hub
            hf_url = self.save_model_to_hf(
                model, label_encoder, scaler, model_version, metadata, training_data_info
            )

            # Save training log
            log_path = self.save_training_log(training_log, model_version)

            result = {
                "local_paths": local_paths,
                "hf_url": hf_url,
                "log_path": log_path,
                "version": model_version,
                "metadata": metadata
            }

            logger.info(f"✅ Complete model save successful: v{model_version}")
            return result

        except Exception as e:
            logger.error(f"❌ Complete model save failed: {e}")
            raise


def main():
    """Test the saver functionality"""
    try:
        saver = ModelSaver()
        print(f"ModelSaver initialized for repo: {saver.repo_id}")
        print(f"Local model directory: {saver.local_model_dir}")
    except Exception as e:
        print(f"Failed to initialize ModelSaver: {e}")


if __name__ == "__main__":
    main()
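Editor's note: as a minimal sketch of how the complete saving workflow above might be driven after an RLHF training run. Everything except `ModelSaver` and the `save_complete_model` keyword arguments is an illustrative assumption (dummy data, made-up metric values); running it would also require the HF credentials the repository expects.

```python
# Hypothetical driver for ModelSaver.save_complete_model; values are illustrative only.
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from xgboost import XGBClassifier
from train.saver import ModelSaver

X = np.random.rand(30, 4)                                   # stand-in OBD features
y = np.array(["aggressive", "normal", "conservative"] * 10)  # stand-in labels

encoder = LabelEncoder().fit(y)
scaler = StandardScaler().fit(X)
clf = XGBClassifier(n_estimators=10).fit(scaler.transform(X), encoder.transform(y))

saver = ModelSaver()  # assumes HF token / repo config is set up as the repo requires
result = saver.save_complete_model(
    model=clf,
    label_encoder=encoder,
    scaler=scaler,
    model_version="1.1",
    training_data_info={"datasets": ["demo"], "total_samples": 30,
                        "training_date": "2024-01-01"},
    performance_metrics={"accuracy": 1.0},
    training_log={"notes": "demo run"},
    rlhf_metadata={"feedback_rounds": 0},
)
print(result["hf_url"], result["version"])
```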
utils/download.py
ADDED
@@ -0,0 +1,154 @@
# download.py
# Download latest models from Hugging Face
import os, shutil, pathlib, sys
import json
from huggingface_hub import hf_hub_download, HfApi

def load_env_file():
    """Load environment variables from .env file if it exists"""
    env_path = pathlib.Path(__file__).parent.parent / ".env"
    if env_path.exists():
        with open(env_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    os.environ[key] = value
        return True
    return False

# Load .env file first before setting any environment variables
load_env_file()

REPO_ID = os.getenv("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
MODEL_DIR = pathlib.Path(os.getenv("MODEL_DIR", "/app/models/ul")).resolve()
FILES = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]

MODEL_DIR.mkdir(parents=True, exist_ok=True)

def get_latest_version():
    """Get the latest model version from Hugging Face repo"""
    try:
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            print("⚠️ HF_TOKEN not set, using default model files")
            return None

        api = HfApi(token=hf_token)
        repo_files = api.list_repo_files(
            repo_id=REPO_ID,
            repo_type="model"
        )

        print(f"🔍 Checking repository files...")
        print(f"📁 Found {len(repo_files)} files in repository")

        # Find version directories (v1.0, v1.1, etc.)
        version_dirs = [f for f in repo_files if f.startswith('v') and '/' not in f]
        print(f"📦 Found version directories: {version_dirs}")

        # Also check for version directories with files inside
        version_dirs_with_files = []
        for f in repo_files:
            if f.startswith('v') and '/' in f:
                version_dir = f.split('/')[0]
                if version_dir not in version_dirs_with_files:
                    version_dirs_with_files.append(version_dir)

        if version_dirs_with_files:
            print(f"📦 Found version directories with files: {version_dirs_with_files}")
            version_dirs.extend(version_dirs_with_files)

        versions = []

        for v_dir in version_dirs:
            try:
                version_str = v_dir[1:]  # Remove 'v' prefix
                if '.' in version_str:
                    major, minor = version_str.split('.')
                    versions.append((int(major), int(minor), v_dir))
                    print(f"✅ Found version: {v_dir} (major={major}, minor={minor})")
            except (ValueError, IndexError):
                print(f"⚠️ Could not parse version: {v_dir}")
                continue

        if not versions:
            print("📦 No versioned models found, checking for root files...")
            # Check if files exist in root
            root_files = [f for f in repo_files if f in FILES]
            if root_files:
                print(f"📁 Found root files: {root_files}")
                return None  # Use root files
            else:
                print("❌ No model files found in repository")
                print("💡 Available files in repository:")
                for f in sorted(repo_files):
                    print(f"   - {f}")
                return None

        # Sort versions and get the latest
        versions.sort()
        latest_version = versions[-1][2]  # Get the directory name
        print(f"📦 Latest model version: {latest_version}")
        return latest_version

    except Exception as e:
        print(f"⚠️ Failed to get latest version: {e}")
        return None

def fetch_latest(fname: str, version_dir: str = None):
    """Download the latest version of a model file"""
    try:
        if version_dir:
            # Download from versioned directory
            versioned_path = f"{version_dir}/{fname}"
            print(f"📥 Downloading {fname} from {versioned_path}...")
            src = hf_hub_download(repo_id=REPO_ID, filename=versioned_path, repo_type="model")
        else:
            # Download from root directory (fallback)
            print(f"📥 Downloading {fname} from root directory...")
            src = hf_hub_download(repo_id=REPO_ID, filename=fname, repo_type="model")

        dst = MODEL_DIR / fname
        shutil.copy2(src, dst)
        print(f"✅ Downloaded {fname} → {dst}")
        return True
    except Exception as e:
        print(f"❌ Failed to fetch {fname}: {e}")
        if version_dir:
            print(f"   Tried path: {version_dir}/{fname}")
        else:
            print(f"   Tried path: {fname}")
        return False

def download_latest_models():
    """Download the latest version of all model files"""
    print("🔄 Checking for latest model version...")
    latest_version = get_latest_version()

    success_count = 0
    for f in FILES:
        if fetch_latest(f, latest_version):
            success_count += 1

    if success_count == len(FILES):
        print(f"✅ Successfully downloaded all {len(FILES)} model files")
        if latest_version:
            print(f"📦 Using version: {latest_version}")
        return True
    else:
        print(f"⚠️ Only {success_count}/{len(FILES)} files downloaded successfully")
        return False

def fetch(fname: str):
    """Legacy function for backward compatibility"""
    return fetch_latest(fname)

def main():
    """Download latest models"""
    success = download_latest_models()
    if not success:
        sys.exit(1)

if __name__ == "__main__":
    main()
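Editor's note: the version lookup above parses `vMAJOR.MINOR` directory names into integer tuples before sorting; a plain string sort would rank `v1.9` above `v1.10`. A small illustration with made-up directory names:

```python
# Why get_latest_version() sorts (major, minor) tuples instead of raw strings.
dirs = ["v1.2", "v1.10", "v1.9"]

as_strings = sorted(dirs)
as_tuples = sorted((int(d[1:].split(".")[0]), int(d[1:].split(".")[1]), d) for d in dirs)

print(as_strings[-1])    # 'v1.9'  — lexicographic order picks the wrong "latest"
print(as_tuples[-1][2])  # 'v1.10' — numeric tuple order picks the real latest
```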
utils/mount_drive.py
ADDED
@@ -0,0 +1,44 @@
import os
import json
import gspread
import logging
from oauth2client.service_account import ServiceAccountCredentials

# Setup logging
logger = logging.getLogger("upload")
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(asctime)s - %(message)s")

# Authenticate with GDrive using secret
logger.info("Authenticating to Google Drive...")
creds_json = os.getenv("GDRIVE_CREDENTIALS_JSON")
if not creds_json:
    logger.error("GDRIVE_CREDENTIALS_JSON not found!")
    exit(1)

try:
    creds_dict = json.loads(creds_json)
    scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
    creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
    client = gspread.authorize(creds)
    logger.info("Authenticated with Google Drive")
except Exception as e:
    logger.error(f"Failed to authenticate: {e}")
    exit(1)

# Folder and files
upload_dir = "./cache/obd_data/cleaned"
if not os.path.exists(upload_dir):
    logger.warning(f"Directory {upload_dir} does not exist.")
    exit(0)

# Upload all .csv files
for file in os.listdir(upload_dir):
    if file.endswith(".csv"):
        try:
            path = os.path.join(upload_dir, file)
            logger.info(f"Uploading {file}...")
            with open(path, "rb") as f:
                client.import_csv(client.create(file).id, f.read())
            logger.info(f"Uploaded {file}")
        except Exception as e:
            logger.error(f"Failed to upload {file}: {e}")
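Editor's note: because the script goes through `gspread`, each uploaded CSV ends up as a new Google Sheet owned by the service account rather than a raw `.csv` file on Drive. A minimal read-back sketch, assuming the authenticated `client` from the script above and a hypothetical file title:

```python
# Hypothetical check of an uploaded file; "obd_cleaned_0001.csv" is an assumed title.
sheet = client.open("obd_cleaned_0001.csv").sheet1
rows = sheet.get_all_values()  # list of rows; the first row is the CSV header
print(f"{len(rows) - 1} data rows recovered")
```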
utils/ul_label.py
ADDED
@@ -0,0 +1,206 @@
# ul_label.py
# Load UL models and predict driving style
import os, logging, pickle
import warnings
import joblib
import numpy as np
import pandas as pd

# Import download functionality
import sys
sys.path.append(os.path.dirname(__file__))
from download import download_latest_models

log = logging.getLogger("ul-labeler")
log.setLevel(logging.INFO)

# Suppress version compatibility warnings in production
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.base")
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost.core")

MODEL_DIR = os.getenv("MODEL_DIR", "/app/models/ul")
LE_PATH = os.path.join(MODEL_DIR, "label_encoder_ul.pkl")
SC_PATH = os.path.join(MODEL_DIR, "scaler_ul.pkl")
XGB_PATH = os.path.join(MODEL_DIR, "xgb_drivestyle_ul.pkl")

SAFE_DROP = {
    "timestamp","driving_style","ul_drivestyle","gt_drivestyle",
    "session_id","imported_at","record_index"
}

def _load_any(path):
    # Suppress version compatibility warnings for production
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
        warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
        warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
        warnings.filterwarnings("ignore", category=FutureWarning, module="xgboost")
        try:
            model = joblib.load(path)
        except Exception:
            with open(path, "rb") as f:
                model = pickle.load(f)

        # Fix XGBoost compatibility issues for older trained models
        if hasattr(model, 'get_booster'):  # This is an XGBoost model
            # Remove deprecated use_label_encoder attribute that causes issues in newer XGBoost versions
            if hasattr(model, '__dict__'):
                # Remove all deprecated attributes that cause issues
                deprecated_attrs = [
                    'use_label_encoder', '_le', '_label_encoder',
                    'use_label_encoder_', '_le_', '_label_encoder_'
                ]
                for attr in deprecated_attrs:
                    model.__dict__.pop(attr, None)

            # Set use_label_encoder to False for newer XGBoost versions
            if hasattr(model, 'set_params'):
                try:
                    model.set_params(use_label_encoder=False)
                except Exception:
                    pass

        return model

class ULLabeler:
    _instance = None

    def __init__(self, auto_download: bool = True):
        # Auto-download latest models if enabled
        if auto_download:
            log.info("🔄 Checking for latest model version...")
            try:
                download_latest_models()
            except Exception as e:
                log.warning(f"⚠️ Failed to download latest models: {e}")

        if not (os.path.exists(LE_PATH) and os.path.exists(SC_PATH) and os.path.exists(XGB_PATH)):
            raise FileNotFoundError("Model files not found. Ensure download.py ran successfully.")
        self.le = _load_any(LE_PATH)
        self.scal = _load_any(SC_PATH)
        self.clf = _load_any(XGB_PATH)

        # Additional XGBoost compatibility fixes
        self._fix_xgb_compatibility()

        # Try to discover expected feature names from scaler or model
        self.expected = None
        if hasattr(self.scal, "feature_names_in_"):
            self.expected = list(self.scal.feature_names_in_)
        elif hasattr(self.clf, "feature_names_in_"):
            self.expected = list(self.clf.feature_names_in_)

        log.info(f"ULLabeler ready | expected_features={len(self.expected) if self.expected else 'unknown'}")

    def _fix_xgb_compatibility(self):
        """Fix XGBoost compatibility issues with older trained models."""
        try:
            # Check if this is an XGBoost classifier
            if hasattr(self.clf, 'get_booster'):
                # Remove deprecated attributes that cause issues in newer XGBoost versions
                deprecated_attrs = [
                    'use_label_encoder', '_le', '_label_encoder',
                    'use_label_encoder_', '_le_', '_label_encoder_'
                ]
                for attr in deprecated_attrs:
                    if hasattr(self.clf, attr):
                        try:
                            delattr(self.clf, attr)
                        except (AttributeError, TypeError):
                            pass

                # Set use_label_encoder to False for newer XGBoost versions
                if hasattr(self.clf, 'set_params'):
                    try:
                        self.clf.set_params(use_label_encoder=False)
                    except Exception:
                        pass

                # Ensure the model is properly configured for prediction
                if hasattr(self.clf, 'n_classes_') and self.clf.n_classes_ is None:
                    # Try to infer number of classes from the label encoder
                    if hasattr(self.le, 'classes_'):
                        self.clf.n_classes_ = len(self.le.classes_)

                # For newer XGBoost versions, ensure the model is properly initialized
                if hasattr(self.clf, '_le') and self.clf._le is None:
                    self.clf._le = None

            log.info("XGBoost compatibility fixes applied successfully")
        except Exception as e:
            log.warning(f"XGBoost compatibility fix failed: {e}")

    @classmethod
    def get(cls, auto_download: bool = True):
        if cls._instance is None:
            cls._instance = ULLabeler(auto_download=auto_download)
        return cls._instance

    def _prepare(self, df: pd.DataFrame):
        # numeric only + drop non-feature columns
        cols = [c for c in df.columns if c not in SAFE_DROP and pd.api.types.is_numeric_dtype(df[c])]
        X = df[cols].copy()

        # ensure required features
        if self.expected:
            for c in self.expected:
                if c not in X.columns:
                    X[c] = 0.0
            X = X[self.expected]  # align order
        X = X.fillna(0)

        # scale
        try:
            Xs = self.scal.transform(X if hasattr(self.scal, "feature_names_in_") else X.values)
        except Exception as e:
            log.warning(f"Scaler transform failed ({e}); using raw features.")
            Xs = X.values
        return Xs

    def predict_df(self, df: pd.DataFrame) -> np.ndarray:
        Xs = self._prepare(df)
        try:
            yhat = self.clf.predict(Xs)
        except (AttributeError, TypeError) as e:
            if 'use_label_encoder' in str(e) or 'label_encoder' in str(e):
                # Last resort: try to fix the model and retry
                log.warning("XGBoost compatibility issue detected, attempting fix...")
                try:
                    # Remove all problematic attributes
                    deprecated_attrs = [
                        'use_label_encoder', '_le', '_label_encoder',
                        'use_label_encoder_', '_le_', '_label_encoder_'
                    ]
                    for attr in deprecated_attrs:
                        if hasattr(self.clf, attr):
                            try:
                                delattr(self.clf, attr)
                            except (AttributeError, TypeError):
                                pass

                    # Set use_label_encoder to False
                    if hasattr(self.clf, 'set_params'):
                        try:
                            self.clf.set_params(use_label_encoder=False)
                        except Exception:
                            pass

                    # Retry prediction
                    yhat = self.clf.predict(Xs)
                except Exception as retry_e:
                    log.error(f"Failed to fix XGBoost compatibility: {retry_e}")
                    raise e
            else:
                raise e

        try:
            return self.le.inverse_transform(yhat)
        except Exception:
            return yhat

    def predict_csv(self, csv_path: str) -> pd.DataFrame:
        df = pd.read_csv(csv_path)
        y = self.predict_df(df)
        out = df.copy()
        out["driving_style"] = y
        return out
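Editor's note: a minimal usage sketch for the labeler, assuming the three model files are already present in `MODEL_DIR` (auto-download disabled) and that `cleaned_trip.csv` is a hypothetical cleaned OBD export; the import path assumes the project root is on `sys.path`.

```python
# Hypothetical usage; the CSV path is an assumption, not a file in this repo.
from utils.ul_label import ULLabeler

labeler = ULLabeler.get(auto_download=False)     # reuse the singleton, skip the HF download
labeled = labeler.predict_csv("cleaned_trip.csv")
print(labeled["driving_style"].value_counts())
```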