LiamKhoaLe committed on
Commit
ff65bb4
·
0 Parent(s):

Initial commit: OBD Logger with RLHF training system


- FastAPI-based OBD-II data processing
- Real-time data ingestion and cleaning
- Firebase and MongoDB integration
- RLHF training pipeline with versioned models
- Docker deployment ready
- Security: No hardcoded tokens

.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+ .env
+ token.json
+ service.json
+ firebase.json
+
+ # Security - prevent token leaks
+ *.token
+ *.key
+ *secret*
+ *credential*
+ *password*
Dockerfile ADDED
@@ -0,0 +1,43 @@
+ FROM python:3.11-slim
+
+ # ── Create and switch to non-root user ──
+ RUN useradd -m -u 1000 user
+ USER user
+
+ # ── Set environment and working directory ──
+ ENV HOME=/home/user
+ WORKDIR $HOME/app
+
+ # ── Upgrade pip and install dependencies ──
+ COPY --chown=user requirements.txt .
+ RUN pip install --upgrade pip
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # ── Install latest versions for UL model inference ──
+ RUN pip install --no-cache-dir huggingface_hub xgboost joblib scikit-learn
+
+ # ── Pre-mount GDrive (no-op if creds not found) ──
+ COPY --chown=user utils/mount_drive.py .
+ RUN python mount_drive.py || true
+
+ # ── Copy application source ──
+ COPY --chown=user . .
+
+ # ── Create required folders ──
+ RUN mkdir -p $HOME/app/logs \
+     $HOME/app/cache \
+     $HOME/app/cache/obd_data \
+     $HOME/app/cache/obd_data/plots \
+     $HOME/app/models/ul
+
+ # ── Environment variables for HuggingFace model ──
+ ENV MODEL_DIR=$HOME/app/models/ul
+ ENV HF_MODEL_REPO=BinKhoaLe1812/Driver_Behavior_OBD
+
+ # ── Models will be downloaded at runtime when app starts ──
+
+ # ── Default port ──
+ EXPOSE 7860
+
+ # ── Start app ──
+ CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
GOOGLE_DRIVE_SETUP.md ADDED
@@ -0,0 +1,336 @@
+ # Google Drive Integration Setup Guide
+
+ This guide explains how to set up Google Drive integration for the OBD Logger application.
+
+ ## Prerequisites
+
+ 1. **Google Cloud Platform Account**: You need a Google Cloud Platform account
+ 2. **Google Drive API**: Enable the Google Drive API in your project
+ 3. **Service Account**: Create a service account with appropriate permissions
+ 4. **Python Dependencies**: Install the required packages
+
+ ## Installation
+
+ ### 1. Install Dependencies
+
+ The required packages are already included in `requirements.txt`:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Required packages:
+ - `google-auth`
+ - `google-auth-httplib2`
+ - `google-auth-oauthlib`
+ - `google-api-python-client`
+
+ ### 2. Environment Variables
+
+ Create a `.env` file in your project root with the following variables:
+
+ ```bash
+ # Google Drive Configuration
+ GDRIVE_CREDENTIALS_JSON={"type":"service_account","project_id":"your-project","private_key_id":"...","private_key":"...","client_email":"...","client_id":"...","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_x509_cert_url":"..."}
+
+ # Optional: Custom Google Drive Folder ID
+ GDRIVE_FOLDER_ID=1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P
+ ```
+
+ ## Google Cloud Platform Setup
+
+ ### 1. Create a New Project
+
+ 1. Go to [Google Cloud Console](https://console.cloud.google.com/)
+ 2. Click "Select a project" → "New Project"
+ 3. Enter a project name (e.g., "OBD-Logger-Drive")
+ 4. Click "Create"
+
+ ### 2. Enable Google Drive API
+
+ 1. In your project, go to "APIs & Services" → "Library"
+ 2. Search for "Google Drive API"
+ 3. Click on "Google Drive API"
+ 4. Click "Enable"
+
+ ### 3. Create Service Account
+
+ 1. Go to "APIs & Services" → "Credentials"
+ 2. Click "Create Credentials" → "Service Account"
+ 3. Fill in the service account details:
+    - **Name**: `obd-logger-drive`
+    - **Description**: `Service account for OBD Logger Google Drive operations`
+ 4. Click "Create and Continue"
+ 5. For roles, select "Editor" (or create a custom role with minimal permissions)
+ 6. Click "Continue" → "Done"
+
+ ### 4. Generate Service Account Key
+
+ 1. In the service accounts list, click on your newly created service account
+ 2. Go to the "Keys" tab
+ 3. Click "Add Key" → "Create New Key"
+ 4. Choose "JSON" format
+ 5. Click "Create" - this will download a JSON file
+ 6. **Important**: Keep this file secure and never commit it to version control
+
+ ### 5. Share Google Drive Folder
+
+ 1. Go to [Google Drive](https://drive.google.com/)
+ 2. Create a new folder or use an existing one
+ 3. Right-click the folder → "Share"
+ 4. Add your service account email (found in the JSON file under `client_email`)
+ 5. Give it "Editor" permissions
+ 6. Copy the folder ID from the URL (the long string after `/folders/`; for example, in `https://drive.google.com/drive/folders/1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P` the ID is the final path segment)
+
+ ## Configuration
+
+ ### 1. Set Up Credentials
+
+ Copy the contents of your downloaded JSON file and set it as the `GDRIVE_CREDENTIALS_JSON` environment variable:
+
+ ```bash
+ export GDRIVE_CREDENTIALS_JSON='{"type":"service_account","project_id":"your-project",...}'
+ ```
+
+ Or add it to your `.env` file.
+
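+ For illustration, a minimal sketch of how these credentials can be turned into a
+ Drive client in Python (the actual `drive_saver.py` implementation is not shown
+ in this commit; the scope below is an assumption):
+
+ ```python
+ import json, os
+ from google.oauth2 import service_account
+ from googleapiclient.discovery import build
+
+ # Parse the service-account JSON straight from the environment variable
+ info = json.loads(os.environ["GDRIVE_CREDENTIALS_JSON"])
+ creds = service_account.Credentials.from_service_account_info(
+     info, scopes=["https://www.googleapis.com/auth/drive"])
+ service = build("drive", "v3", credentials=creds)
+ ```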
+ ### 2. Configure Folder ID
+
+ Set the `GDRIVE_FOLDER_ID` environment variable to your target folder ID:
+
+ ```bash
+ export GDRIVE_FOLDER_ID="your_folder_id_here"
+ ```
+
+ ## Usage
+
+ ### Automatic Saving
+
+ The application automatically uploads cleaned CSV files to Google Drive after processing.
+
+ ### Manual Operations
+
+ #### Initialize Drive Service
+
+ ```python
+ from drive_saver import DriveSaver
+
+ # Create instance
+ drive_saver = DriveSaver()
+
+ # Check if service is available
+ if drive_saver.is_service_available():
+     print("✅ Google Drive service ready")
+ else:
+     print("❌ Google Drive service not available")
+ ```
+
+ #### Upload CSV File
+
+ ```python
+ # Upload to default folder
+ success = drive_saver.upload_csv_to_drive("path/to/your/file.csv")
+
+ # Upload to specific folder
+ success = drive_saver.upload_csv_to_drive("path/to/your/file.csv", "custom_folder_id")
+ ```
+
+ #### Configuration Management
+
+ ```python
+ # Get current folder ID
+ current_folder = drive_saver.get_folder_id()
+
+ # Set new folder ID
+ drive_saver.set_folder_id("new_folder_id")
+ ```
+
+ ### Legacy Functions (Backward Compatibility)
+
+ The module maintains backward compatibility with existing code:
+
+ ```python
+ from drive_saver import get_drive_service, upload_to_folder
+
+ # Legacy usage
+ service = get_drive_service()
+ result = upload_to_folder(service, "file.csv", "folder_id")
+ ```
+
+ ## File Management
+
+ ### Supported File Types
+
+ - **CSV files**: Primary format for OBD data
+ - **Text files**: Other data formats
+ - **Binary files**: Limited support
+
+ ### File Naming
+
+ Files are uploaded with their original names. The system automatically:
+ - Preserves file extensions
+ - Maintains original timestamps
+ - Creates unique names if conflicts exist
+
+ ### Storage Organization
+
+ - **Default folder**: All files go to the configured default folder
+ - **Custom folders**: Specify different folders for different data types
+ - **Session-based**: Files are organized by processing sessions
+
+ ## Error Handling
+
+ ### Common Issues
+
+ 1. **Authentication Errors**
+    - Check service account credentials
+    - Verify API is enabled
+    - Ensure service account has proper permissions
+
+ 2. **Permission Errors**
+    - Verify folder sharing settings
+    - Check service account email is added to folder
+    - Ensure "Editor" or higher permissions
+
+ 3. **Quota Exceeded**
+    - Monitor Google Drive storage usage
+    - Check API quotas in Google Cloud Console
+    - Consider upgrading storage plan
+
+ ### Troubleshooting
+
+ #### Check Service Status
+
+ ```python
+ from drive_saver import DriveSaver
+
+ saver = DriveSaver()
+ print(f"Service available: {saver.is_service_available()}")
+ print(f"Current folder: {saver.get_folder_id()}")
+ ```
+
+ #### Test Connection
+
+ ```python
+ # Try uploading a small test file
+ test_success = drive_saver.upload_csv_to_drive("test.csv")
+ if test_success:
+     print("✅ Connection test successful")
+ else:
+     print("❌ Connection test failed")
+ ```
+
+ ## Security Best Practices
+
+ ### Credential Management
+
+ - **Never commit** service account JSON to version control
+ - **Use environment variables** for sensitive data
+ - **Rotate keys** regularly
+ - **Limit permissions** to the minimum required
+
+ ### Access Control
+
+ - **Restrict folder access** to necessary users only
+ - **Monitor access logs** in Google Drive
+ - **Use organization policies** for additional security
+ - **Consider VPC Service Controls** for production
+
+ ### Network Security
+
+ - **HTTPS only** for all API communications
+ - **Firewall rules** to restrict access if needed
+ - **Audit logs** for suspicious activity
+
+ ## Performance Optimization
+
+ ### Upload Strategies
+
+ - **Batch uploads** for multiple files
+ - **Compression** for large CSV files
+ - **Async processing** for non-blocking operations
+
+ ### Monitoring
+
+ - **Track upload success rates**
+ - **Monitor file sizes and upload times**
+ - **Set up alerts** for failures
+
+ ## Integration with OBD Logger
+
+ ### Automatic Uploads
+
+ The system automatically uploads files after:
+ 1. Data processing completion
+ 2. CSV cleaning and validation
+ 3. Feature engineering
+ 4. Quality checks
+
+ ### File Naming Convention
+
+ Uploaded files follow the pattern:
+ ```
+ cleaned_{timestamp}.csv
+ ```
+
+ Where `{timestamp}` is the normalized timestamp from the processing session.
+
+ ### Error Recovery
+
+ If uploads fail:
+ - Files remain in local storage
+ - Errors are logged for debugging
+ - Processing continues without interruption
+ - Manual retry options are available
+
+ ## Advanced Configuration
+
+ ### Custom Scopes
+
+ Modify the authentication scopes in `drive_saver.py`:
+
+ ```python
+ scopes = [
+     "https://www.googleapis.com/auth/drive",
+     "https://www.googleapis.com/auth/drive.file"  # More restrictive
+ ]
+ ```
+
+ ### Retry Logic
+
+ The system includes automatic retry logic (sketched after this list) for:
+ - Network timeouts
+ - Rate limiting
+ - Temporary service unavailability
+
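+ A minimal sketch of that retry pattern, assuming the `DriveSaver` API shown
+ above (the actual retry code in `drive_saver.py` is not reproduced here):
+
+ ```python
+ import time
+
+ def upload_with_retry(saver, path, attempts=3, backoff_s=2.0):
+     """Retry an upload with exponential backoff on transient failures."""
+     for attempt in range(attempts):
+         if saver.upload_csv_to_drive(path):
+             return True
+         time.sleep(backoff_s * (2 ** attempt))  # wait 2s, 4s, 8s, ...
+     return False
+ ```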
+ ### Logging
+
+ Comprehensive logging includes:
+ - Upload success/failure
+ - File details and metadata
+ - Performance metrics
+ - Error details for debugging
+
+ ## Support and Maintenance
+
+ ### Regular Tasks
+
+ 1. **Monitor storage usage** in Google Drive
+ 2. **Check API quotas** in Google Cloud Console
+ 3. **Review access logs** for security
+ 4. **Update service account keys** as needed
+
+ ### Troubleshooting Resources
+
+ - [Google Drive API Documentation](https://developers.google.com/drive/api)
+ - [Google Cloud Console](https://console.cloud.google.com/)
+ - [Google Drive Help](https://support.google.com/drive/)
+ - Application logs and error messages
+
+ ### Getting Help
+
+ For issues with the OBD Logger integration:
+ 1. Check application logs
+ 2. Verify environment variables
+ 3. Test with simple file uploads
+ 4. Review Google Cloud Console for errors
MONGODB_SETUP.md ADDED
@@ -0,0 +1,133 @@
+ # MongoDB Integration Setup Guide
+
+ This guide explains how to set up MongoDB integration for the OBD Logger application.
+
+ ## Prerequisites
+
+ 1. **MongoDB Atlas Account**: You need a MongoDB Atlas account (free tier available)
+ 2. **Python Dependencies**: Install the required packages
+
+ ## Installation
+
+ ### 1. Install Dependencies
+
+ ```bash
+ pip install pymongo
+ ```
+
+ Or update your `requirements.txt` and run:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Environment Variables
+
+ Create a `.env` file in your project root with the following variables:
+
+ ```bash
+ # Google Drive Configuration
+ GDRIVE_CREDENTIALS_JSON={"type":"service_account","project_id":"your-project",...}
+
+ # MongoDB Atlas Connection String
+ MONGO_URI=mongodb+srv://username:password@cluster.mongodb.net/obd_logger?retryWrites=true&w=majority
+
+ # Optional: Custom Google Drive Folder ID
+ GDRIVE_FOLDER_ID=1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P
+ ```
+
+ ## MongoDB Atlas Setup
+
+ ### 1. Create Cluster
+ 1. Go to [MongoDB Atlas](https://cloud.mongodb.com/)
+ 2. Create a free cluster
+ 3. Choose your preferred cloud provider and region
+
+ ### 2. Database Access
+ 1. Go to "Database Access" in the left sidebar
+ 2. Click "Add New Database User"
+ 3. Choose "Password" authentication
+ 4. Set username and password (save these!)
+ 5. Set privileges to "Read and write to any database"
+
+ ### 3. Network Access
+ 1. Go to "Network Access" in the left sidebar
+ 2. Click "Add IP Address"
+ 3. For development: Click "Allow Access from Anywhere" (0.0.0.0/0)
+ 4. For production: Add your specific IP addresses
+
+ ### 4. Get Connection String
+ 1. Go to "Clusters" in the left sidebar
+ 2. Click "Connect" on your cluster
+ 3. Choose "Connect your application"
+ 4. Copy the connection string
+ 5. Replace `<username>`, `<password>`, and `<dbname>` with your values
+
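+ To confirm the URI works before running the app, a quick hedged check with
+ `pymongo` (the timeout value is an arbitrary choice):
+
+ ```python
+ import os
+ from pymongo import MongoClient
+
+ client = MongoClient(os.environ["MONGO_URI"], serverSelectionTimeoutMS=5000)
+ client.admin.command("ping")  # raises if the cluster is unreachable
+ print("MongoDB reachable; databases:", client.list_database_names())
+ ```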
+ ## Usage
+
+ ### Automatic Saving
+ The application automatically saves cleaned data to both Google Drive and MongoDB after processing.
+
+ ### Manual Operations
+
+ #### Check MongoDB Status
+ ```bash
+ GET /mongo/status
+ ```
+
+ #### Get Session Summary
+ ```bash
+ GET /mongo/sessions
+ ```
+
+ #### Query Data
+ ```bash
+ GET /mongo/query?session_id=session_20231201_120000&driving_style=aggressive&limit=100
+ ```
+
+ #### Save CSV Directly to MongoDB
+ ```bash
+ POST /mongo/save-csv
+ # Upload CSV file with optional session_id parameter
+ ```
+
+ ## Data Structure
+
+ Each document in MongoDB contains (see the example below):
+ - All OBD sensor data from the original CSV
+ - `session_id`: Unique identifier for the data session
+ - `imported_at`: Timestamp when the data was imported
+ - `record_index`: Original row index from the CSV
+ - `timestamp`: OBD data timestamp (converted to datetime)
+ - `driving_style`: Driving style classification
+
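+ An illustrative document (field values are invented for this example; sensor
+ columns follow the CSV headers):
+
+ ```json
+ {
+   "session_id": "session_20231201_120000",
+   "imported_at": "2023-12-01T12:00:05",
+   "record_index": 0,
+   "timestamp": "2023-12-01T11:58:42",
+   "driving_style": "aggressive",
+   "SPEED": 54.0,
+   "RPM": 2850.0,
+   "THROTTLE_POS": 42.5
+ }
+ ```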
+ ## Performance Features
+
+ - **Indexes**: Automatic creation of indexes on `timestamp`, `driving_style`, and `session_id` (equivalent `pymongo` calls are sketched below)
+ - **Connection Pooling**: Efficient connection management
+ - **Batch Operations**: Bulk insert for better performance
+ - **Error Handling**: Graceful fallback if MongoDB is unavailable
+
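+ For reference, the equivalent index creation in `pymongo` (the collection name
+ here is a placeholder; the application sets these indexes up automatically):
+
+ ```python
+ from pymongo import ASCENDING
+
+ records = client["obd_logger"]["records"]  # hypothetical collection name
+ records.create_index([("timestamp", ASCENDING)])
+ records.create_index([("driving_style", ASCENDING)])
+ records.create_index([("session_id", ASCENDING)])
+ ```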
+ ## Troubleshooting
+
+ ### Connection Issues
+ 1. Check your MongoDB URI format
+ 2. Verify network access settings in Atlas
+ 3. Check username/password credentials
+ 4. Ensure the cluster is running
+
+ ### Data Import Issues
+ 1. Check the CSV file format
+ 2. Verify data types in your CSV
+ 3. Check application logs for specific error messages
+
+ ### Performance Issues
+ 1. Monitor database indexes
+ 2. Check connection pool settings
+ 3. Consider data partitioning for large datasets
+
+ ## Security Notes
+
+ - Never commit your `.env` file to version control
+ - Use strong passwords for database users
+ - Restrict network access to necessary IP addresses only
+ - Consider using VPC peering for production deployments
OBD/obd_analyzer.py ADDED
@@ -0,0 +1,215 @@
+ import pandas as pd
+ import numpy as np
+ import argparse
+ import os
+
+
+ DRIVING_STYLE_PASSIVE = "Passive"
+ DRIVING_STYLE_MODERATE = "Moderate"
+ DRIVING_STYLE_AGGRESSIVE = "Aggressive"
+ DRIVING_STYLE_UNKNOWN = "UNKNOWN_STYLE"
+
+ ROAD_TYPE_LOCAL = "Local"
+ ROAD_TYPE_MAIN = "Main"
+ ROAD_TYPE_HIGHWAY = "Highway"
+ ROAD_TYPE_UNKNOWN = "UNKNOWN_ROAD"
+
+ TRAFFIC_CONDITION_LIGHT = "Light"
+ TRAFFIC_CONDITION_MODERATE = "Moderate"
+ TRAFFIC_CONDITION_HEAVY = "Heavy"
+ TRAFFIC_CONDITION_UNKNOWN = "UNKNOWN_TRAFFIC"
+
+
+ KPH_TO_MPS = 1 / 3.6
+ G_ACCELERATION = 9.80665
+ MIN_MOVING_SPEED_KPH = 2  # vehicle must be moving
+
+ AGGRESSIVE_RPM_ENTRY_THRESHOLD = 2700
+ AGGRESSIVE_THROTTLE_ENTRY_THRESHOLD = 40
+ AGGRESSIVE_RPM_HOLD_THRESHOLD = 2300
+ HARSH_BRAKING_THRESHOLD_G = -0.25
+
+ # Rate-of-change (RoC) thresholds
+ AGGRESSIVE_RPM_ROC_THRESHOLD = 500
+ AGGRESSIVE_THROTTLE_ROC_THRESHOLD = 45
+ POSITIVE_ACCEL_FOR_ROC_CHECK_G = 0.1
+
+ MODERATE_RPM_THRESHOLD = 2100
+ MODERATE_THROTTLE_THRESHOLD = 25
+
+ MIN_DATA_POINTS_FOR_ROC = 2
+
+ def load_and_preprocess_data(csv_filepath):
+     """Loads OBD data from CSV and preprocesses it."""
+     if not os.path.exists(csv_filepath):
+         print(f"Error: File not found at {csv_filepath}")
+         return None
+
+     try:
+         df = pd.read_csv(csv_filepath)
+     except Exception as e:
+         print(f"Error loading CSV {csv_filepath}: {e}")
+         return None
+
+     print(f"Successfully loaded {csv_filepath} with {len(df)} rows.")
+
+     if 'timestamp' not in df.columns:
+         print("Error: 'timestamp' column is missing from the CSV.")
+         return None
+
+     df['timestamp'] = pd.to_datetime(df['timestamp'])
+     df = df.sort_values(by='timestamp').reset_index(drop=True)
+
+     df['delta_time_s'] = df['timestamp'].diff().dt.total_seconds()
+     if not df.empty:
+         df.loc[0, 'delta_time_s'] = 0
+     else:
+         # Empty DataFrame: nothing to preprocess
+         return df
+
+     numeric_cols = ['SPEED', 'RPM', 'THROTTLE_POS']
+     for col in numeric_cols:
+         if col in df.columns:
+             df[col] = pd.to_numeric(df[col], errors='coerce')
+         else:
+             print(f"Warning: Column {col} not found. It will be filled with NaN.")
+             df[col] = np.nan
+
+     df[numeric_cols] = df[numeric_cols].ffill().fillna(0)
+
+     if 'SPEED' in df.columns:
+         df['SPEED_mps'] = df['SPEED'] * KPH_TO_MPS
+     else:
+         df['SPEED_mps'] = 0
+
+     if len(df) >= MIN_DATA_POINTS_FOR_ROC:
+         df['acceleration_mps2'] = df['SPEED_mps'].diff() / df['delta_time_s']
+         df['acceleration_mps2'] = df['acceleration_mps2'].replace([np.inf, -np.inf], 0).fillna(0)
+         if not df.empty: df.loc[0, 'acceleration_mps2'] = 0
+         df['acceleration_g'] = df['acceleration_mps2'] / G_ACCELERATION
+         if not df.empty: df.loc[0, 'acceleration_g'] = 0
+         df['acceleration_g'] = df['acceleration_g'].fillna(0)
+
+         if 'RPM' in df.columns:
+             df['RPM_roc'] = df['RPM'].diff() / df['delta_time_s']
+             df['RPM_roc'] = df['RPM_roc'].replace([np.inf, -np.inf], 0).fillna(0)
+             if not df.empty: df.loc[0, 'RPM_roc'] = 0
+         else:
+             df['RPM_roc'] = 0
+
+         if 'THROTTLE_POS' in df.columns:
+             df['THROTTLE_roc'] = df['THROTTLE_POS'].diff() / df['delta_time_s']
+             df['THROTTLE_roc'] = df['THROTTLE_roc'].replace([np.inf, -np.inf], 0).fillna(0)
+             if not df.empty: df.loc[0, 'THROTTLE_roc'] = 0
+         else:
+             df['THROTTLE_roc'] = 0
+     else:
+         # Not enough data for RoC calculations, fill with 0
+         df['acceleration_mps2'] = 0
+         df['acceleration_g'] = 0
+         df['RPM_roc'] = 0
+         df['THROTTLE_roc'] = 0
+         print("Warning: Not enough data points for full RoC calculations. Output might be limited.")
+
+     print("Preprocessing complete.")
+     return df
+
+ def classify_driving_style_stateful(df):
+     required_cols = ['RPM', 'THROTTLE_POS', 'SPEED', 'acceleration_g', 'RPM_roc', 'THROTTLE_roc']
+     if df.empty or not all(col in df.columns for col in required_cols):
+         print(f"Warning: Missing one or more required columns for stateful classification ({', '.join(required_cols)}).")
+         return pd.Series([DRIVING_STYLE_UNKNOWN] * len(df), index=df.index, dtype=str)
+
+     driving_styles = [DRIVING_STYLE_UNKNOWN] * len(df)
+     current_style = DRIVING_STYLE_PASSIVE
+
+     for i in range(len(df)):
+         rpm = df.loc[i, 'RPM']
+         throttle = df.loc[i, 'THROTTLE_POS']
+         speed_kph = df.loc[i, 'SPEED']
+         accel_g = df.loc[i, 'acceleration_g']
+         rpm_roc = df.loc[i, 'RPM_roc']
+         throttle_roc = df.loc[i, 'THROTTLE_roc']
+
+         row_style = DRIVING_STYLE_PASSIVE
+         is_moving = speed_kph > MIN_MOVING_SPEED_KPH
+
+         is_hard_braking_trigger = accel_g < HARSH_BRAKING_THRESHOLD_G and is_moving
+
+         is_high_abs_rpm_throttle_trigger = (rpm > AGGRESSIVE_RPM_ENTRY_THRESHOLD and
+                                             throttle > AGGRESSIVE_THROTTLE_ENTRY_THRESHOLD and
+                                             is_moving)
+
+         is_actively_accelerating = accel_g > POSITIVE_ACCEL_FOR_ROC_CHECK_G
+
+         is_high_roc_trigger = (is_moving and
+                                is_actively_accelerating and
+                                (rpm_roc > AGGRESSIVE_RPM_ROC_THRESHOLD or
+                                 throttle_roc > AGGRESSIVE_THROTTLE_ROC_THRESHOLD))
+
+         is_currently_aggressive_event = is_hard_braking_trigger or is_high_abs_rpm_throttle_trigger or is_high_roc_trigger
+
+         if current_style == DRIVING_STYLE_AGGRESSIVE:
+             if is_currently_aggressive_event:
+                 row_style = DRIVING_STYLE_AGGRESSIVE
+             elif rpm > AGGRESSIVE_RPM_HOLD_THRESHOLD and is_moving:
+                 row_style = DRIVING_STYLE_AGGRESSIVE
+             else:
+                 if (rpm > MODERATE_RPM_THRESHOLD or throttle > MODERATE_THROTTLE_THRESHOLD) and is_moving:
+                     row_style = DRIVING_STYLE_MODERATE
+                 else:
+                     row_style = DRIVING_STYLE_PASSIVE
+         else:
+             if is_currently_aggressive_event:
+                 row_style = DRIVING_STYLE_AGGRESSIVE
+             else:
+                 if (rpm > MODERATE_RPM_THRESHOLD or throttle > MODERATE_THROTTLE_THRESHOLD) and is_moving:
+                     row_style = DRIVING_STYLE_MODERATE
+                 else:
+                     row_style = DRIVING_STYLE_PASSIVE
+
+         driving_styles[i] = row_style
+         current_style = row_style
+
+     print("Stateful driving style classification complete.")
+     return pd.Series(driving_styles, index=df.index)
+
+ def main():
+     parser = argparse.ArgumentParser(description="Analyze OBD CSV log data for driving behavior (stateful).")
+     parser.add_argument("csv_filepath", help="Path to the OBD log CSV file.")
+     parser.add_argument("--output_csv", help="Path to save the analyzed data CSV file.", default=None)
+     args = parser.parse_args()
+
+     df = load_and_preprocess_data(args.csv_filepath)
+
+     if df is None or df.empty:
+         print("No data to process after loading or preprocessing.")
+         return
+
+     df['driving_style_analyzed'] = classify_driving_style_stateful(df)
+
+     print("\n--- Analysis Summary ---")
+     print("Driving Style Distribution (Analyzed):")
+     counts = df['driving_style_analyzed'].value_counts(dropna=False)
+     percentages = df['driving_style_analyzed'].value_counts(normalize=True, dropna=False) * 100
+     summary_df = pd.DataFrame({'Count': counts, 'Percentage': percentages})
+     print(summary_df)
+
+     if args.output_csv:
+         try:
+             output_path = args.output_csv
+             output_dir = os.path.dirname(output_path)
+             if output_dir and not os.path.exists(output_dir):
+                 os.makedirs(output_dir)
+             df.to_csv(output_path, index=False)
+             print(f"\nAnalyzed data saved to {output_path}")
+         except Exception as e:
+             print(f"Error saving output CSV to {args.output_csv}: {e}")
+     else:
+         print("\n--- First 20 Rows of Analyzed Data (showing key fields) ---")
+         display_cols = ['timestamp', 'SPEED', 'RPM', 'THROTTLE_POS', 'acceleration_g', 'driving_style_analyzed']
+         display_cols = [col for col in display_cols if col in df.columns]
+         if display_cols: print(df[display_cols].head(20))
+         else: print("Key display columns not found in DataFrame.")
+
+ if __name__ == "__main__":
+     main()
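+
+ # Example standalone invocation (hypothetical paths, following the logger's
+ # directory layout):
+ #   python obd_analyzer.py logs/OriginalCSV/obd_data_log_20250101_120000.csv \
+ #       --output_csv logs/DuplicateCSV/obd_data_log_20250101_120000_final_analyzed.csv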
OBD/obd_logger.py ADDED
@@ -0,0 +1,374 @@
+ import obd
+ import time
+ import datetime
+ import csv
+ import os
+ from collections import deque
+ import numpy as np
+ import shutil
+ import subprocess
+
+ DRIVING_STYLE_PASSIVE = "Passive"
+ DRIVING_STYLE_MODERATE = "Moderate"
+ DRIVING_STYLE_AGGRESSIVE = "Aggressive"
+ DRIVING_STYLE_UNKNOWN = "UNKNOWN_STYLE"
+
+ ROAD_TYPE_LOCAL = "Local"
+ ROAD_TYPE_MAIN = "Main"
+ ROAD_TYPE_HIGHWAY = "Highway"
+ ROAD_TYPE_UNKNOWN = "UNKNOWN_ROAD"
+
+ TRAFFIC_CONDITION_LIGHT = "Light"
+ TRAFFIC_CONDITION_MODERATE = "Moderate"
+ TRAFFIC_CONDITION_HEAVY = "Heavy"
+ TRAFFIC_CONDITION_UNKNOWN = "UNKNOWN_TRAFFIC"
+
+ # Rolling Average Configuration
+ ROLLING_WINDOW_SIZE = 20  # ≈6 seconds at the 0.3 s base log interval
+ MIN_SAMPLES_FOR_CLASSIFICATION = 10
+
+ # Rate-of-change (RoC) thresholds — these still need tuning
+ SHORT_ROC_WINDOW_SIZE = 3
+ MIN_SAMPLES_FOR_ROC_CHECK = SHORT_ROC_WINDOW_SIZE
+ ROC_THROTTLE_AGGRESSIVE_THRESHOLD = 25.0
+ ROC_RPM_AGGRESSIVE_THRESHOLD = 700.0
+ ROC_SPEED_AGGRESSIVE_THRESHOLD = 8.0
+ MIN_RPM_FOR_AGGRESSIVE_TRIGGER = 1000.0
+ AGGRESSIVE_EVENT_COOLDOWN_SAMPLES = 15
+
+ HIGH_FREQUENCY_PIDS = [
+     obd.commands.RPM,
+     obd.commands.THROTTLE_POS,
+     obd.commands.SPEED,
+ ]
+
+ LOW_FREQUENCY_PIDS_POOL = [
+     obd.commands.FUEL_PRESSURE,
+     obd.commands.ENGINE_LOAD,
+     obd.commands.COOLANT_TEMP,
+     obd.commands.INTAKE_TEMP,
+     obd.commands.TIMING_ADVANCE,
+     obd.commands.MAF,
+     obd.commands.INTAKE_PRESSURE,
+     obd.commands.SHORT_FUEL_TRIM_1,
+     obd.commands.LONG_FUEL_TRIM_1,
+     obd.commands.SHORT_FUEL_TRIM_2,
+     obd.commands.LONG_FUEL_TRIM_2,
+     obd.commands.COMMANDED_EQUIV_RATIO,
+     obd.commands.O2_B1S2,
+     obd.commands.O2_B2S2,
+     obd.commands.O2_S1_WR_VOLTAGE,
+     obd.commands.COMMANDED_EGR,
+ ]
+
+ ALL_PIDS_TO_LOG = HIGH_FREQUENCY_PIDS + LOW_FREQUENCY_PIDS_POOL
+
+ CSV_FILENAME_BASE = "obd_data_log"
+ # Structured log directories: two levels up from OBD/ to the base dir, then into logs/
+ LOGS_BASE_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "logs")
+ ORIGINAL_CSV_DIR = os.path.join(LOGS_BASE_DIR, "OriginalCSV")
+ DUPLICATE_CSV_DIR = os.path.join(LOGS_BASE_DIR, "DuplicateCSV")
+
+ WIFI_ADAPTER_HOST = "192.168.0.10"
+ WIFI_ADAPTER_PORT = 35000
+
+ WIFI_PROTOCOL = "6"
+ USE_WIFI_SETTINGS = False  # when False, use a socat PTY to mimic a serial connection
+
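+ # One way to create that PTY bridge with socat (illustrative; the allocated
+ # device name, e.g. /dev/ttys086, varies per run and must match the port
+ # opened below):
+ #   socat -d -d pty,raw,echo=0 tcp:192.168.0.10:35000
+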
+ def get_pid_value(connection, pid_command):
+     """Queries a PID and returns its value, or None if not available or error."""
+     try:
+         response = connection.query(pid_command, force=True)
+         if response.is_null() or response.value is None:
+             return None
+         if hasattr(response.value, 'magnitude'):
+             return response.value.magnitude
+         return response.value
+     except Exception as e:
+         print(f"Error querying {pid_command.name}: {e}")
+         return None
+
+ def perform_logging_session():
+     connection = None
+     print("Starting OBD-II Data Logger...")
+     print("Classifications (Style, Road, Traffic) will be determined automatically.")
+
+     initial_driving_style = ""
+     initial_road_type = ""
+     initial_traffic_condition = ""
+
+     BASE_LOG_INTERVAL = 0.3  # seconds, for high-frequency data
+     LOW_FREQUENCY_GROUP_POLL_INTERVAL = 90.0  # Interval in seconds to poll one group of LF PIDs
+     NUM_LOW_FREQUENCY_GROUPS = 3
+
+     # Prepare Low-Frequency PID groups
+     low_frequency_pid_groups = []
+     if LOW_FREQUENCY_PIDS_POOL:
+         chunk_size = (len(LOW_FREQUENCY_PIDS_POOL) + NUM_LOW_FREQUENCY_GROUPS - 1) // NUM_LOW_FREQUENCY_GROUPS
+         for i in range(0, len(LOW_FREQUENCY_PIDS_POOL), chunk_size):
+             low_frequency_pid_groups.append(LOW_FREQUENCY_PIDS_POOL[i:i + chunk_size])
+
+     if not low_frequency_pid_groups:  # Handle case with no LF PIDs
+         low_frequency_pid_groups.append([])
+         NUM_LOW_FREQUENCY_GROUPS = 1
+
+     last_low_frequency_group_poll_time = time.monotonic()
+     current_low_frequency_group_index = 0
+
+     current_pid_values = {pid.name: '' for pid in ALL_PIDS_TO_LOG}
+
+     # Create log directories
+     for dir_path in [ORIGINAL_CSV_DIR, DUPLICATE_CSV_DIR]:
+         try:
+             os.makedirs(dir_path, exist_ok=True)
+             print(f"Ensured directory exists: {dir_path}")
+         except OSError as e:
+             print(f"Error creating directory {dir_path}: {e}. Attempting to use current directory.")
+             # Fallback logic may be needed if creation fails critically
+             if dir_path == ORIGINAL_CSV_DIR:  # Critical for saving original log
+                 print("Cannot create original log directory. Exiting.")
+                 return None
+
+     current_session_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     csv_file_name_only = f"{CSV_FILENAME_BASE}_{current_session_timestamp}.csv"
+     original_csv_filepath = os.path.join(ORIGINAL_CSV_DIR, csv_file_name_only)
+
+     try:
+         if USE_WIFI_SETTINGS:
+             print(f"Attempting to connect to WiFi adapter at {WIFI_ADAPTER_HOST}:{WIFI_ADAPTER_PORT} using protocol {WIFI_PROTOCOL}...")
+             connection = obd.OBD(protocol=WIFI_PROTOCOL,
+                                  host=WIFI_ADAPTER_HOST,
+                                  port=WIFI_ADAPTER_PORT,
+                                  fast=False,
+                                  timeout=30)
+         else:
+             print("Attempting to connect via socat PTY /dev/ttys086...")
+             connection = obd.OBD("/dev/ttys086", fast=True, timeout=30)  # socat-created PTY; device name varies per run
+
+         if not connection.is_connected():
+             print("Failed to connect to OBD-II adapter.")
+             print(f"Connection status: {connection.status()}")
+             return None
+
+         print(f"Successfully connected to OBD-II adapter: {connection.port_name()}")
+         print(f"Adapter status: {connection.status()}")
+         print("Supported PIDs (sample):")
+         supported_commands = connection.supported_commands
+         for i, cmd in enumerate(supported_commands):
+             print(f" - {cmd.name}")
+         if not supported_commands:
+             print("No commands")
+
+         # Take an initial full PID sample so rows are fully populated from the beginning
+         print("\nPerforming initial full PID sample...")
+         initial_log_entry = {
+             'timestamp': datetime.datetime.now().isoformat(),
+             'driving_style': initial_driving_style,
+             'road_type': initial_road_type,
+             'traffic_condition': initial_traffic_condition
+         }
+
+         print("Polling initial High-Frequency PIDs...")
+         for pid_command in HIGH_FREQUENCY_PIDS:
+             value = get_pid_value(connection, pid_command)
+             current_pid_values[pid_command.name] = value if value is not None else ''
+             initial_log_entry[pid_command.name] = current_pid_values[pid_command.name]
+
+         print("Polling initial Low-Frequency PIDs (all groups)...")
+         if low_frequency_pid_groups and low_frequency_pid_groups[0]:  # Check if there are any LF PIDs
+             for group in low_frequency_pid_groups:
+                 for pid_command in group:
+                     value = get_pid_value(connection, pid_command)
+                     current_pid_values[pid_command.name] = value if value is not None else ''
+                     initial_log_entry[pid_command.name] = current_pid_values[pid_command.name]
+         else:
+             print("No Low-Frequency PIDs to poll for initial sample.")
+
+         for pid_obj in ALL_PIDS_TO_LOG:
+             if pid_obj.name not in initial_log_entry:
+                 initial_log_entry[pid_obj.name] = ''  # Default to empty if somehow missed
+
+     except Exception as e:
+         print(f"An error occurred during connection or initial PID sample: {e}")
+         if connection and connection.is_connected():
+             connection.close()
+         return None
+
+     file_exists = os.path.isfile(original_csv_filepath)
+     try:
+         with open(original_csv_filepath, 'a', newline='') as csvfile:
+             # Add new columns for analyzer output; they will be empty initially from the logger
+             header_names = ['timestamp',
+                             'driving_style', 'road_type', 'traffic_condition',  # Original placeholder columns
+                             'driving_style_analyzed', 'road_type_analyzed', 'traffic_condition_analyzed'  # For analyzer
+                             ] + [pid.name for pid in ALL_PIDS_TO_LOG]
+
+             # Remove duplicates if any PID name is already in the first part
+             processed_headers = []
+             for item in header_names:
+                 if item not in processed_headers:
+                     processed_headers.append(item)
+             header_names = processed_headers
+
+             writer = csv.DictWriter(csvfile, fieldnames=header_names)
+
+             if not file_exists or os.path.getsize(original_csv_filepath) == 0:
+                 writer.writeheader()
+                 print(f"Created new CSV file: {original_csv_filepath} with headers: {header_names}")
+
+             if initial_log_entry:
+                 # Add placeholder columns for analyzer to the initial entry
+                 initial_log_entry['driving_style_analyzed'] = ''
+                 initial_log_entry['road_type_analyzed'] = ''
+                 initial_log_entry['traffic_condition_analyzed'] = ''
+                 writer.writerow(initial_log_entry)
+                 csvfile.flush()
+                 print(f"Logged initial full sample. Style: {initial_driving_style}, Road: {initial_road_type}, Traffic: {initial_traffic_condition}.")
+
+             last_low_frequency_group_poll_time = time.monotonic()
+             current_low_frequency_group_index = 0
+
+             print(f"\nLogging high-frequency data every {BASE_LOG_INTERVAL} second(s).")
+             print(f"Polling one group of low-frequency PIDs every {LOW_FREQUENCY_GROUP_POLL_INTERVAL} second(s).")
+             print(f"Low-frequency PIDs divided into {len(low_frequency_pid_groups)} groups.")
+
+             log_count = 0
+             while True:
+                 loop_start_time = time.monotonic()
+                 current_datetime = datetime.datetime.now()
+                 timestamp_iso = current_datetime.isoformat()
+
+                 hf_reads = 0
+                 for pid_command in HIGH_FREQUENCY_PIDS:
+                     value = get_pid_value(connection, pid_command)
+                     current_pid_values[pid_command.name] = value if value is not None else ''
+                     if value is not None:
+                         hf_reads += 1
+
+                 lf_reads_this_cycle = 0
+                 lf_group_size = 0
+                 lf_group_polled_this_cycle = "None"
+                 if low_frequency_pid_groups and (time.monotonic() - last_low_frequency_group_poll_time) >= LOW_FREQUENCY_GROUP_POLL_INTERVAL:
+                     group_to_poll = low_frequency_pid_groups[current_low_frequency_group_index]
+                     lf_group_size = len(group_to_poll)
+                     lf_group_polled_this_cycle = f"Group {current_low_frequency_group_index + 1}/{len(low_frequency_pid_groups)}"
+
+                     for pid_command in group_to_poll:
+                         value = get_pid_value(connection, pid_command)
+                         current_pid_values[pid_command.name] = value if value is not None else ''
+                         if value is not None:
+                             lf_reads_this_cycle += 1
+                         else:
+                             print(f"Warning: Could not read LF PID {pid_command.name}")
+
+                     last_low_frequency_group_poll_time = time.monotonic()
+                     current_low_frequency_group_index = (current_low_frequency_group_index + 1) % len(low_frequency_pid_groups)
+
+                 final_log_entry = {
+                     'timestamp': timestamp_iso,
+                     'driving_style': initial_driving_style,
+                     'road_type': initial_road_type,
+                     'traffic_condition': initial_traffic_condition,
+                     'driving_style_analyzed': '',
+                     'road_type_analyzed': '',
+                     'traffic_condition_analyzed': ''
+                 }
+                 # Add all PID values for this cycle from current_pid_values
+                 for pid_obj in ALL_PIDS_TO_LOG:
+                     final_log_entry[pid_obj.name] = current_pid_values.get(pid_obj.name, '')
+
+                 writer.writerow(final_log_entry)
+                 csvfile.flush()
+
+                 log_count += 1
+                 if log_count % 10 == 0:
+                     status_msg = f"Logged entry {log_count} - HF PIDs Read: {hf_reads}/{len(HIGH_FREQUENCY_PIDS)}"
+                     if lf_group_polled_this_cycle != "None":
+                         status_msg += f" - LF PIDs ({lf_group_polled_this_cycle}) Read: {lf_reads_this_cycle}/{lf_group_size}"
+                     print(status_msg)
+
+                 elapsed_time_in_loop = time.monotonic() - loop_start_time
+                 sleep_duration = max(0, BASE_LOG_INTERVAL - elapsed_time_in_loop)
+                 time.sleep(sleep_duration)
+
+     except KeyboardInterrupt:
+         print("\nStopping data logging due to user interruption (Ctrl+C).")
+     except Exception as e:
+         print(f"An error occurred during logging: {e}")
+     finally:
+         if connection and connection.is_connected():
+             print("Closing OBD-II connection.")
+             connection.close()
+         print(f"Data logging stopped. Original CSV file '{original_csv_filepath}' saved.")
+
+     return original_csv_filepath
+
+ def duplicate_csv(original_filepath):
+     if not original_filepath or not os.path.exists(original_filepath):
+         print(f"Error: Original CSV not found for duplication: {original_filepath}")
+         return None
+
+     # Ensure DUPLICATE_CSV_DIR exists (it should have been created by perform_logging_session)
+     os.makedirs(DUPLICATE_CSV_DIR, exist_ok=True)
+
+     # Get just the filename from the original path
+     original_filename = os.path.basename(original_filepath)
+     base, ext = os.path.splitext(original_filename)
+
+     # Construct new filename for the duplicate
+     duplicate_filename = f"{base}_to_analyze{ext}"  # Suffix to distinguish
+     duplicate_filepath = os.path.join(DUPLICATE_CSV_DIR, duplicate_filename)
+
+     try:
+         shutil.copy2(original_filepath, duplicate_filepath)
+         print(f"Successfully duplicated CSV to: {duplicate_filepath}")
+         return duplicate_filepath
+     except Exception as e:
+         print(f"Error duplicating CSV {original_filepath} to {duplicate_filepath}: {e}")
+         return None
+
+ def run_analyzer_on_csv(csv_to_analyze_path):
+     if not csv_to_analyze_path or not os.path.exists(csv_to_analyze_path):
+         print(f"Error: Analyzer input CSV not found: {csv_to_analyze_path}")
+         return
+
+     # Analyzer script is in the same directory as this logger script
+     analyzer_script_path = os.path.join(os.path.dirname(__file__), "obd_analyzer.py")
+
+     if not os.path.exists(analyzer_script_path):
+         print(f"CRITICAL Error: Analyzer script not found at {analyzer_script_path}")
+         return
+
+     analyzed_file_basename = os.path.basename(csv_to_analyze_path).replace("_to_analyze.csv", "_final_analyzed.csv")
+     final_output_path = os.path.join(DUPLICATE_CSV_DIR, analyzed_file_basename)
+
+     command = [
+         "python",
+         analyzer_script_path,
+         csv_to_analyze_path,
+         "--output_csv",
+         final_output_path
+     ]
+
+     print(f"Running analyzer: {' '.join(command)}")
+     try:
+         process = subprocess.run(command, check=True, capture_output=True, text=True, cwd=os.path.dirname(__file__))
+         print("Analyzer Output:\n", process.stdout)
+         if process.stderr: print("Analyzer Errors:\n", process.stderr)
+         print(f"Analyzer finished. Output saved to {final_output_path}")
+     except subprocess.CalledProcessError as e:
+         print(f"Error running analyzer: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}")
+     except FileNotFoundError:
+         print(f"Error: 'python' or analyzer script not found ({analyzer_script_path}).")
+
+ if __name__ == "__main__":
+     original_log_file = perform_logging_session()
+
+     if original_log_file and os.path.exists(original_log_file):
+         duplicated_log_file = duplicate_csv(original_log_file)
+
+         if duplicated_log_file:
+             run_analyzer_on_csv(duplicated_log_file)
+             print(f"Process complete. Original log: {original_log_file}, Analyzed log copy: {duplicated_log_file}")
+     else:
+         print("OBD logging did not produce a valid CSV file. Skipping analysis.")
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ title: OBD Logger
+ emoji: 🚗
+ colorFrom: gray
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ short_description: OBD-logging FastAPI server with data processing pipelines
+ ---
+
+ # OBD Logger
+
+ A comprehensive OBD-II data logging and processing system built with FastAPI, featuring advanced data cleaning, Google Drive integration, MongoDB storage capabilities, and **Reinforcement Learning from Human Feedback (RLHF)** for driver behavior classification.
+
+ ## Features
+
+ - **Real-time OBD-II Data Ingestion**: Stream and process OBD sensor data in real time
+ - **Advanced Data Cleaning**: Intelligent gap detection, KNN imputation, and outlier handling
+ - **Multi-Storage Architecture**:
+   - Google Drive integration for CSV storage
+   - Firebase for structured data storage and querying
+   - MongoDB Atlas for structured data storage and querying
+ - **Driver Behavior Classification**: XGBoost-based ML model for driving style prediction
+ - **RLHF Training System**: Continuous model improvement through human feedback
+ - **Data Visualization**: Automatic generation of correlation heatmaps and trend plots
+ - **RESTful API**: Comprehensive endpoints for data management and retrieval
+ - **Web Dashboard**: User-friendly interface for monitoring and control
+ - **Model Versioning**: Semantic versioning (1.0, 1.1, 1.2, etc.) with Hugging Face integration
+
+ ## Architecture
+
+ The application is structured into modular components:
+
+ - **`app.py`**: Main FastAPI application with data processing pipeline and RLHF endpoints
+ - **`data/`**: Storage and persistence modules
+   - **`drive_saver.py`**: Google Drive operations and file management
+   - **`mongo_saver.py`**: MongoDB operations and data persistence
+   - **`firebase_saver.py`**: Firebase operations and data persistence
+ - **`train/`**: RLHF training system
+   - **`loader.py`**: Load labeled data from Firebase storage with original dataset tracking
+   - **`saver.py`**: Save trained models to Hugging Face Hub with semantic versioning
+   - **`rlhf.py`**: Main RLHF training pipeline for continuous model improvement
+ - **`OBD/`**: OBD-specific modules for data analysis and logging
+ - **`utils/`**: Utility modules for model management and data processing
+
+ ## Quick Start
+
+ 1. **Install Dependencies**:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. **Set Environment Variables**:
+ - `GDRIVE_CREDENTIALS_JSON`: Google Service Account credentials
+ - `FIREBASE_SERVICE_ACCOUNT_JSON`: Firebase service account credentials
+ - `FIREBASE_ADMIN_JSON`: Firebase Admin SDK credentials
+ - `HF_TOKEN`: Hugging Face authentication token
+ - `HF_MODEL_REPO`: Hugging Face model repository (default: `BinKhoaLe1812/Driver_Behavior_OBD`)
+ - `MODEL_DIR`: Local model directory (default: `/app/models/ul`)
+
+ 3. **Run the Application**:
+ ```bash
+ uvicorn app:app --reload
+ ```
+
+ 4. **Access the Dashboard**:
+ - Web UI: `http://localhost:8000/ui`
+ - API Docs: `http://localhost:8000/docs`
+
+ ## Data Processing Pipeline
+
+ 1. **Ingestion**: Real-time streaming or bulk CSV upload
+ 2. **Cleaning**: Automatic gap detection and KNN imputation (see the sketch after this list)
+ 3. **Feature Engineering**: Derived metrics and sensor combinations
+ 4. **Storage**: Simultaneous save to Google Drive, Firebase, and MongoDB
+ 5. **Driver Behavior Classification**: XGBoost model prediction on processed data
+ 6. **RLHF Training**: Continuous model improvement through human feedback
+ 7. **Visualization**: Correlation analysis and trend plots
+
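+ The cleaning stage pairs gap detection with KNN-based imputation. A rough sketch
+ of the idea (column names are illustrative; the full pipeline lives in `app.py`,
+ which imports `KNNImputer` from scikit-learn):
+
+ ```python
+ import pandas as pd
+ from sklearn.impute import KNNImputer
+
+ def impute_gaps(df: pd.DataFrame, cols=("SPEED", "RPM", "THROTTLE_POS")) -> pd.DataFrame:
+     """Fill missing sensor readings from each row's nearest neighbours."""
+     imputer = KNNImputer(n_neighbors=5)
+     df[list(cols)] = imputer.fit_transform(df[list(cols)])
+     return df
+ ```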
+ ## API Endpoints
+
+ ### Data Ingestion
+ - `POST /ingest`: Stream OBD data
+ - `POST /upload-csv/`: Bulk CSV upload
+
+ ### Data Retrieval
+ - `GET /download/{filename}`: Download cleaned CSV
+ - `GET /events`: Get processing status
+
+ ### MongoDB Operations
+ - `GET /mongo/status`: Check MongoDB connection
+ - `GET /mongo/sessions`: Get data session summaries
+ - `GET /mongo/query`: Query data with filters
+ - `POST /mongo/save-csv`: Direct CSV to MongoDB
+
+ ### RLHF Training System
+ - `POST /rlhf/train`: Trigger RLHF training session
+ - `GET /rlhf/status`: Get RLHF system status and available labeled data
+ - `GET /rlhf/trained-datasets`: List datasets already used for training
+
+ ### Firebase Storage
+ - Structured data storage with automatic versioning
+ - **`skyledge/raw/`**: Original OBD data files
+ - **`skyledge/processed/`**: Cleaned and processed data
+ - **`skyledge/labeled/`**: Human-labeled data for RLHF training
+ - **`skyledge/labeled/trained.txt`**: Tracks processed datasets to avoid retraining
+
+ ### Hugging Face Hub
+ - **Model Repository**: `BinKhoaLe1812/Driver_Behavior_OBD`
+ - **Semantic Versioning**: v1.0, v1.1, v1.2, ..., v2.0, etc.
+ - **Model Components**: XGBoost model, label encoder, scaler
+ - **Metadata**: Training logs, performance metrics, dataset information
+
+ ## RLHF Training System
+
+ ### Overview
+ The Reinforcement Learning from Human Feedback (RLHF) system enables continuous improvement of the driver behavior classification model through human-labeled data.
+
+ ### Key Features
+ - **Original Dataset Tracking**: Automatically links labeled data to original datasets
+ - **Preference Learning**: Learns from differences between model predictions and human labels
+ - **Semantic Versioning**: Automatic model versioning (1.0 → 1.1 → 1.2 → 2.0)
+ - **Hugging Face Integration**: Saves models to HF Hub with metadata
+ - **Training Tracking**: Prevents retraining on the same datasets
+
+ ### Usage Examples
+
+ #### Trigger RLHF Training
+ ```bash
+ curl -X POST "http://localhost:8000/rlhf/train" \
+   -H "Content-Type: application/json" \
+   -d '{
+     "max_datasets": 5,
+     "force_retrain": false
+   }'
+ ```
+
+ #### Check Training Status
+ ```bash
+ curl -X GET "http://localhost:8000/rlhf/status"
+ ```
+
+ #### List Trained Datasets
+ ```bash
+ curl -X GET "http://localhost:8000/rlhf/trained-datasets"
+ ```
+
+ ### Data Flow
+ 1. **Human Labeling**: Data labeled and stored in `skyledge/labeled/`
+ 2. **Filename Convention**: `001_raw-002_2025-09-19-labelled.csv`
+ 3. **Original Dataset**: Automatically loads `skyledge/raw/002_2025-09-19-raw.csv` (a parsing sketch follows this list)
+ 4. **RLHF Training**: Compares model predictions vs human labels
+ 5. **Model Update**: Trains new model with preference learning
+ 6. **Versioning**: Saves as v1.0, v1.1, etc. to Hugging Face Hub
+
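+ A hypothetical helper showing how a labeled filename maps back to its raw
+ dataset under this convention (the regex and helper name are illustrative, not
+ the actual `train/loader.py` code):
+
+ ```python
+ import re
+
+ def original_dataset_path(labeled_name: str) -> str:
+     """'001_raw-002_2025-09-19-labelled.csv' -> 'skyledge/raw/002_2025-09-19-raw.csv'"""
+     m = re.match(r"\d+_raw-(.+)-labelled\.csv$", labeled_name)
+     if not m:
+         raise ValueError(f"Unexpected labeled filename: {labeled_name}")
+     return f"skyledge/raw/{m.group(1)}-raw.csv"
+ ```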
+ ## Documentation
+
+ - **MongoDB Setup**: See `MONGODB_SETUP.md` for detailed configuration
+ - **Google Drive Setup**: See `GOOGLE_DRIVE_SETUP.md` for configuration
+ - **RLHF Training**: See `train/README.md` for detailed RLHF documentation
+ - **API Reference**: Interactive docs at `/docs` endpoint
+ - **Code Structure**: Modular design for easy maintenance
+
+ ## Development
+
+ The codebase follows clean architecture principles:
+ - **Separation of concerns**: Between storage, processing, API, and ML layers
+ - **Comprehensive error handling**: Graceful fallbacks for service unavailability
+ - **Type hints and documentation**: Full type annotations and docstrings
+ - **Modular design**: Easy to extend and maintain
+ - **RLHF Integration**: Seamless integration of machine learning with data processing
+ - **Version control**: Semantic versioning for model artifacts
+ - **Testing**: Comprehensive test coverage for all components
+
+ ## Model Management
+
+ ### Driver Behavior Classification
+ - **Model Type**: XGBoost Classifier
+ - **Labels**: Aggressive, Normal, Conservative
+ - **Features**: OBD sensor data (speed, RPM, throttle, etc.)
+ - **Training**: RLHF with human feedback integration
+
+ ### Model Artifacts
+ - **XGBoost Model**: `xgb_drivestyle_ul.pkl`
+ - **Label Encoder**: `label_encoder_ul.pkl`
+ - **Feature Scaler**: `scaler_ul.pkl`
+ - **Metadata**: Training logs and performance metrics
+
+ ### Versioning Strategy
+ - **Semantic Versioning**: 1.0 → 1.1 → 1.2 → 2.0 (sketched below)
+ - **Automatic Detection**: Checks existing versions in HF repo
+ - **Fallback**: Timestamp-based versioning if HF unavailable
+ - **Local Backup**: Saves to local `/app/models/ul/v{version}/`
+
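+ A hedged sketch of the version bump (the real logic in `train/saver.py` is not
+ shown here; rolling a minor version of 9 over to a new major is purely an
+ assumption of this example):
+
+ ```python
+ def next_version(existing: list[str]) -> str:
+     """['1.0', '1.1'] -> '1.2'; empty list -> '1.0'."""
+     if not existing:
+         return "1.0"
+     major, minor = max(tuple(map(int, v.split("."))) for v in existing)
+     return f"{major}.{minor + 1}" if minor < 9 else f"{major + 1}.0"
+ ```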
+ ## License
+
+ Apache 2.0 License
app.py ADDED
@@ -0,0 +1,802 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Access: https://binkhoale1812-obd-logger.hf.space/ui
2
+
3
+
4
+ # ───────────── Installation ─────────────
5
+ # Router
6
+ from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
7
+ from fastapi.responses import FileResponse, HTMLResponse
8
+ from fastapi.staticfiles import StaticFiles
9
+ from fastapi.templating import Jinja2Templates
10
+ from fastapi.requests import Request
11
+ from pydantic import BaseModel
12
+ # ML/DL
13
+ import pandas as pd
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+ from sklearn.preprocessing import MinMaxScaler
18
+ from sklearn.impute import KNNImputer
19
+ # Utils
20
+ import os, datetime, json, logging, re
21
+ from datetime import timedelta
22
+ import pathlib
23
+
24
+ # Drive
25
+ from data.drive_saver import DriveSaver, get_drive_service, upload_to_folder
26
+
27
+ # Database
28
+ from data.mongo_saver import MongoSaver, save_csv_to_mongo, save_dataframe_to_mongo, MONGODB_AVAILABLE
29
+ from data.firebase_saver import FirebaseSaver, save_csv_increment, save_dataframe_increment
30
+
31
+ # UL Model
32
+ from utils.ul_label import ULLabeler
33
+
34
+ # RLHF Training
35
+ from train import RLHFTrainer
36
+
37
+ # ───────────── Logging Setup ─────────────
38
+ logger = logging.getLogger("obd-logger")
39
+ logger.setLevel(logging.INFO)
40
+ fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
41
+ handler = logging.StreamHandler()
42
+ handler.setFormatter(fmt)
43
+ logger.addHandler(handler)
44
+
45
+
46
+ # ───────────── FastAPI Init ─────────────
47
+ app = FastAPI(title="OBD-II Logging & Processing API")
48
+
49
+
50
+ # ───────────── Directory Paths ─────────────
51
+ APP_ROOT = pathlib.Path(__file__).parent.resolve() # Absolute base dir
52
+ BASE_DIR = os.path.join(APP_ROOT, "cache", "obd_data")
53
+ CLEANED_DIR = os.path.join(BASE_DIR, "cleaned")
54
+ PLOT_DIR = os.path.join(BASE_DIR, "plots")
55
+ RAW_CSV = os.path.join(BASE_DIR, "raw_logs.csv")
56
+ os.makedirs(BASE_DIR, exist_ok=True)
57
+ os.makedirs(CLEANED_DIR, exist_ok=True)
58
+ os.makedirs(PLOT_DIR, exist_ok=True)
59
+
60
+ DRIVE_STYLE = [] # latest UL predictions (string labels) — overwritten each run
61
+
62
+ # Init temp empty file
63
+ if not os.path.exists(RAW_CSV):
64
+ pd.DataFrame(columns=["timestamp", "driving_style"]).to_csv(RAW_CSV, index=False)
65
+
66
+ PIPELINE_EVENTS = {}
67
+
68
+
69
+ # ───────────── Drive & Database Services ─────────────
70
+ # Initialize services
71
+ drive_saver = DriveSaver()
72
+ mongo_saver = MongoSaver()
73
+ firebase_saver = FirebaseSaver()
74
+
75
+ # ───────────── Model Download on Startup ─────────────
76
+ @app.on_event("startup")
77
+ async def startup_event():
78
+ """Download models on app startup"""
79
+ try:
80
+ logger.info("🚀 Starting model download...")
81
+ from utils.download import download_latest_models
82
+
83
+ # Load .env file if it exists
84
+ env_path = pathlib.Path(".env")
85
+ if env_path.exists():
86
+ logger.info("📄 Loading .env file...")
87
+ with open(env_path, 'r') as f:
88
+ for line in f:
89
+ line = line.strip()
90
+ if line and not line.startswith('#') and '=' in line:
91
+ key, value = line.split('=', 1)
92
+ os.environ[key] = value
93
+
94
+ # Download models
95
+ success = download_latest_models()
96
+ if success:
97
+ logger.info("✅ Models downloaded successfully on startup")
98
+ else:
99
+ logger.warning("⚠️ Model download failed on startup - some features may not work")
100
+
101
+ except Exception as e:
102
+ logger.error(f"❌ Startup model download failed: {e}")
103
+ logger.warning("⚠️ Continuing without models - some features may not work")
104
+
105
+ # ───────────── Render Dashboard UI ──────────────
106
+ app.mount("/static", StaticFiles(directory="static"), name="static")
107
+ app.mount("/plots", StaticFiles(directory=str(PLOT_DIR)), name="plots")
108
+ templates = Jinja2Templates(directory="static")
109
+ # Endpoint
110
+ @app.get("/ui", response_class=HTMLResponse)
111
+ def dashboard(request: Request):
112
+ return templates.TemplateResponse("index.html", {"request": request})
113
+
114
+
115
+ # ───────────── Streamed Entry Ingest ─────────────
116
+ class OBDEntry(BaseModel):
117
+ timestamp: str
118
+ driving_style: str
119
+ data: dict
120
+ status: str | None = None # Optional control signal ("start"/"end" of streaming)
121
+
122
+ # Normalize timestamps to a single filesystem-safe format
123
+ def normalize_timestamp(ts):
124
+ return ts.replace(":", "-").replace(".", "-").replace(" ", "T").replace("/", "-")
125
+
126
+ # Real time endpoint
127
+ @app.post("/ingest")
128
+ def ingest(entry: OBDEntry, background_tasks: BackgroundTasks):
129
+ norm_ts = normalize_timestamp(entry.timestamp)
130
+ logger.info(f"Ingest received: {norm_ts} | Status: {entry.status}")
131
+ # Start logging
132
+ if entry.status == "start":
133
+ PIPELINE_EVENTS[norm_ts] = {"status": "started", "time": norm_ts}
134
+ return {"status": "started"}
135
+ # End logging, start processing
136
+ if entry.status == "end":
137
+ background_tasks.add_task(process_data, norm_ts)
138
+ return {"status": "processed"}
139
+ # Normal row append
140
+ try:
141
+ df = pd.read_csv(RAW_CSV)
142
+ row = {"timestamp": norm_ts, "driving_style": entry.driving_style}
143
+ row.update(entry.data)
144
+ df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
145
+ df.to_csv(RAW_CSV, index=False)
146
+ return {"status": "row appended"}
147
+ except Exception as e:
148
+ logger.error(f"Streaming ingest failed: {e}")
149
+ raise HTTPException(status_code=500, detail="Ingest error")
150
+
151
+
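A client-side sketch of the streamed-ingest control flow (start → rows → end). The payload shape follows `OBDEntry` above; the base URL and the `requests` dependency are assumptions:

```python
import requests

BASE = "http://localhost:8000"  # assumed local deployment

requests.post(f"{BASE}/ingest", json={"timestamp": "2025-05-15T10:00:00",
                                      "driving_style": "", "data": {}, "status": "start"})
requests.post(f"{BASE}/ingest", json={"timestamp": "2025-05-15T10:00:01",
                                      "driving_style": "normal", "data": {"RPM": 2100, "SPEED": 62}})
requests.post(f"{BASE}/ingest", json={"timestamp": "2025-05-15T10:00:02",
                                      "driving_style": "", "data": {}, "status": "end"})
```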
152
+ # ───────────── Bulk CSV Upload ───────────────────
153
+ @app.post("/upload-csv/")
154
+ async def upload_csv(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
155
+ ts = datetime.datetime.now().isoformat()
156
+ norm_ts = normalize_timestamp(ts)
157
+ path = os.path.join(BASE_DIR, file.filename)
158
+ PIPELINE_EVENTS[norm_ts] = {"status": "started", "time": norm_ts}
159
+ with open(path, "wb") as f:
160
+ f.write(await file.read())
161
+ logger.info(f"CSV uploaded: {path}")
162
+ background_tasks.add_task(process_uploaded_csv, path, norm_ts)
163
+ return {"status": "processing started", "file": file.filename}
164
+
165
+
166
+ # ───────────── Data Processing ──────────────────
167
+ # Bulk CSV
168
+ def process_uploaded_csv(path, norm_ts):
169
+ try:
170
+ df = pd.read_csv(path, parse_dates=["timestamp"])
171
+ PIPELINE_EVENTS[norm_ts] = {
172
+ "status": "processed",
173
+ "time": norm_ts
174
+ }
175
+ _process_and_save(df, norm_ts)
176
+ except Exception as e:
177
+ logger.error(f"CSV processing failed: {e}")
178
+
179
+ # Process streaming
180
+ def process_data(norm_ts):
181
+ try:
182
+ df = pd.read_csv(RAW_CSV, parse_dates=["timestamp"])
183
+ PIPELINE_EVENTS[norm_ts] = {
184
+ "status": "processed",
185
+ "time": norm_ts
186
+ }
187
+ _process_and_save(df, norm_ts)
188
+ except Exception as e:
189
+ logger.error(f"Streamed data processing failed: {e}")
190
+
191
+
192
+ # All processing pipeline
193
+ def _process_and_save(df, norm_ts):
194
+ """
195
+ Gap-aware, multi-sensor backfill for OBD-II streams with unknown cadence.
196
+ - Infers sampling interval from data (robust).
197
+ - Inserts placeholder rows for gaps using the inferred interval.
198
+ - Flags only corrupted values (NaN/inf/sentinels); does NOT trim 'extreme but plausible' outliers.
199
+ - Backfills ALL numeric sensors with KNNImputer (+ time as a feature).
200
+ - Keeps your plotting, Drive upload, and PIPELINE_EVENTS wiring intact.
201
+ """
202
+ logger.info("🔧 Cleaning started (auto-interval, KNN for all sensors)")
203
+
204
+ # ----------------------- helpers (scoped locally) -----------------------
205
+ protected_cols = {"timestamp", "driving_style"}
206
+ SENTINELS = {-22, -40, 255}
207
+
208
+ def _to_dt(_df: pd.DataFrame) -> pd.DataFrame:
209
+ _df = _df.copy()
210
+ _df["timestamp"] = pd.to_datetime(_df["timestamp"], errors="coerce", utc=True)
211
+ _df = _df.dropna(subset=["timestamp"]).sort_values("timestamp").reset_index(drop=True)
212
+ # drop exact duplicate timestamps (keep first)
213
+ _df = _df[~_df["timestamp"].duplicated(keep="first")].reset_index(drop=True)
214
+ return _df
215
+
216
+ def _drop_dead_weight(_df: pd.DataFrame) -> pd.DataFrame:
217
+ _df = _df.copy()
218
+ # drop all-NaN or constant columns (except protected)
219
+ drop_cols = [c for c in _df.columns
220
+ if c not in protected_cols and (_df[c].nunique(dropna=True) <= 1 or _df[c].isna().all())]
221
+ if drop_cols:
222
+ _df.drop(columns=drop_cols, inplace=True, errors="ignore")
223
+ # drop duplicate columns
224
+ _df = _df.loc[:, ~_df.T.duplicated()]
225
+ # drop duplicate rows
226
+ _df.drop_duplicates(inplace=True)
227
+ return _df
228
+
229
+ def _normalize_corruption(_df: pd.DataFrame) -> pd.DataFrame:
230
+ _df = _df.copy()
231
+ # normalize obvious corruptions: NaN/inf/sentinels → NaN
232
+ _df.replace(list(SENTINELS), np.nan, inplace=True)
233
+ num_cols = _df.select_dtypes(include=[np.number]).columns
234
+ for c in num_cols:
235
+ s = _df[c]
236
+ s = s.astype(float)
237
+ s[~np.isfinite(s)] = np.nan
238
+ _df[c] = s
239
+ return _df
240
+
241
+ def _light_row_col_filters(_df: pd.DataFrame) -> pd.DataFrame:
242
+ _df = _df.copy()
243
+ # keep rows with <=80% NaN (excluding timestamp)
244
+ if "timestamp" in _df.columns and _df.shape[1] > 1:
245
+ keep = _df.drop(columns=["timestamp"]).isna().mean(axis=1) <= 0.8
246
+ _df = _df[keep]
247
+ # prune columns with >80% NaN (except protected)
248
+ na_frac = _df.isna().mean(numeric_only=False)
249
+ high_na = [c for c in na_frac.index if na_frac[c] > 0.8 and c not in protected_cols]
250
+ if high_na:
251
+ _df.drop(columns=high_na, inplace=True, errors="ignore")
252
+ # keep rows that have >1 observed value across non-timestamp columns
253
+ if "timestamp" in _df.columns and _df.shape[1] > 1:
254
+ valid = _df.drop(columns=["timestamp"]).notna().sum(axis=1) > 1
255
+ _df = _df[valid]
256
+ return _df
257
+
258
+ def _infer_base_interval_seconds(ts: pd.Series) -> float:
259
+ """
260
+ Robustly infer base cadence from timestamp diffs.
261
+ Strategy:
262
+ - take positive diffs
263
+ - winsorize to 5–95% to reduce impact of long gaps
264
+ - compute a 'rounded mode' on 10ms grid; fall back to median if needed
265
+ """
266
+ if ts.size < 2:
267
+ return 1.0 # fallback
268
+ diffs = ts.sort_values().diff().dropna().dt.total_seconds()
269
+ diffs = diffs[diffs > 0]
270
+ if diffs.empty:
271
+ return 1.0
272
+ q05, q95 = diffs.quantile([0.05, 0.95])
273
+ core = diffs[(diffs >= q05) & (diffs <= q95)]
274
+ if core.empty:
275
+ core = diffs
276
+ # round to 10ms and take the most frequent bin
277
+ rounded = (core / 0.01).round() * 0.01
278
+ mode = rounded.mode()
279
+ if not mode.empty:
280
+ est = float(mode.iloc[0])
281
+ else:
282
+ est = float(core.median())
283
+ # guardrails
284
+ if est <= 0:
285
+ est = float(core.median())
286
+ logger.info(f"⏱️ Inferred base interval ≈ {est:.3f}s")
287
+ return est
288
+
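For intuition, a condensed standalone run of the mode-on-10ms-grid step with toy timestamps (the real helper also winsorizes the diffs to 5–95% first):

```python
import pandas as pd

ts = pd.Series(pd.to_datetime([
    "2025-05-15 10:00:00.0", "2025-05-15 10:00:01.0",
    "2025-05-15 10:00:02.0", "2025-05-15 10:00:07.5",  # 5.5 s dropout
    "2025-05-15 10:00:08.5",
]))
diffs = ts.diff().dropna().dt.total_seconds()  # [1.0, 1.0, 5.5, 1.0]
rounded = (diffs / 0.01).round() * 0.01        # snap to 10 ms grid
print(rounded.mode().iloc[0])                  # 1.0; the long gap does not skew the estimate
```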
289
+ def _insert_time_gaps(_df: pd.DataFrame, base_sec: float) -> pd.DataFrame:
290
+ """
291
+ Insert placeholder rows at multiples of inferred base_sec when gaps exceed ~1.5× base.
292
+ All numeric columns are NaN in inserted rows; non-numeric are forward-filled (except protected).
293
+ """
294
+ if _df.empty:
295
+ return _df
296
+ _df = _df.copy()
297
+ _df = _to_dt(_df)
298
+ expected = timedelta(seconds=base_sec)
299
+ # tolerance ~ half interval to avoid jittery inserts
300
+ tol = timedelta(seconds=0.5 * base_sec)
301
+ # Normalize data
302
+ num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
303
+ non_num_cols = [c for c in _df.columns if c not in num_cols]
304
+ # Missing detection on interval expectation
305
+ rows = [_df.iloc[0].copy()]
306
+ for i in range(1, len(_df)):
307
+ prev = _df.iloc[i - 1]
308
+ curr = _df.iloc[i]
309
+ dt = curr["timestamp"] - prev["timestamp"]
310
+ if dt > expected * 1.5 + tol:
311
+ n_missing = int(round(dt / expected)) - 1
312
+ if n_missing > 0:
313
+ for j in range(1, n_missing + 1):
314
+ gap = prev.copy()
315
+ gap["timestamp"] = prev["timestamp"] + j * expected
316
+ # numeric sensors left as NaN to be imputed
317
+ for c in num_cols:
318
+ if c not in protected_cols:
319
+ gap[c] = np.nan
320
+ # for non-numeric, keep last known (except protected)
321
+ for c in non_num_cols:
322
+ if c not in protected_cols:
323
+ gap[c] = prev[c]
324
+ rows.append(gap)
325
+ rows.append(curr.copy())
326
+ # Sorting
327
+ out = pd.DataFrame(rows).sort_values("timestamp").reset_index(drop=True)
328
+ return out
329
+
330
+ def _knn_impute_all(_df: pd.DataFrame) -> pd.DataFrame:
331
+ """
332
+ Backfill ALL numeric sensors jointly with KNN, using time (ts_sec) as an additional feature.
333
+ """
334
+ _df = _df.copy()
335
+ _df["ts_sec"] = (_df["timestamp"] - _df["timestamp"].min()).dt.total_seconds()
336
+ # Normalize data
337
+ num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
338
+ # ensure ts_sec included
339
+ if "ts_sec" not in num_cols:
340
+ num_cols = num_cols + ["ts_sec"]
341
+ # Build imputation frame and remember order
342
+ X = _df[num_cols].copy()
343
+ non_missing_rows = X.dropna().shape[0]
344
+ k = min(5, max(1, non_missing_rows))
345
+ logger.info(f"🤝 KNNImputer n_neighbors={k} on {len(num_cols)} features")
346
+ # Impute and backfill data using KNN
347
+ imputer = KNNImputer(n_neighbors=k)
348
+ X_imp = imputer.fit_transform(X)
349
+ X_imp = pd.DataFrame(X_imp, columns=num_cols, index=_df.index)
350
+ # Write back (excluding ts_sec)
351
+ for c in num_cols:
352
+ if c == "ts_sec":
353
+ continue
354
+ _df[c] = X_imp[c]
355
+
356
+ _df.drop(columns=["ts_sec"], inplace=True)
357
+ return _df
358
+
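A toy standalone run of the same idea, KNN filling a gap row with time as a feature (a simplified stand-in for the pipeline's frame):

```python
import numpy as np
import pandas as pd
from sklearn.impute import KNNImputer

frame = pd.DataFrame({
    "ts_sec": [0.0, 1.0, 2.0, 3.0],
    "RPM":    [2000.0, np.nan, 2200.0, 2300.0],  # NaN row inserted by the gap filler
})
filled = KNNImputer(n_neighbors=2).fit_transform(frame)
print(filled[1, 1])  # 2100.0: mean of the two nearest rows in (ts_sec, RPM) space
```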
359
+ # Copy data from selective sensor types for Feature Engineering
360
+ def _feature_engineering(_df: pd.DataFrame) -> pd.DataFrame:
361
+ _df = _df.copy()
362
+ if {"ENGINE_LOAD", "ABSOLUTE_LOAD"}.issubset(_df.columns):
363
+ _df["AVG_ENGINE_LOAD"] = _df[["ENGINE_LOAD", "ABSOLUTE_LOAD"]].mean(axis=1)
364
+ if {"INTAKE_TEMP", "OIL_TEMP", "COOLANT_TEMP"}.issubset(_df.columns):
365
+ _df["TEMP_MEAN"] = _df[["INTAKE_TEMP", "OIL_TEMP", "COOLANT_TEMP"]].mean(axis=1)
366
+ if {"MAF", "RPM"}.issubset(_df.columns):
367
+ _df["AIRFLOW_PER_RPM"] = _df["MAF"] / _df["RPM"].replace(0, np.nan)
368
+ return _df
369
+
370
+ # Apply MinMaxScaler to fit data frame
371
+ def _scale_numeric(_df: pd.DataFrame) -> pd.DataFrame:
372
+ _df = _df.copy()
373
+ num_cols = _df.select_dtypes(include=[np.number]).columns.tolist()
374
+ for c in list(protected_cols):
375
+ if c in num_cols:
376
+ num_cols.remove(c)
377
+ if num_cols:
378
+ scaler = MinMaxScaler()
379
+ _df[num_cols] = scaler.fit_transform(_df[num_cols])
380
+ return _df
381
+
382
+ # Correlation heatmap plotter
383
+ def _plot_corr(_df: pd.DataFrame, _id: str):
384
+ try:
385
+ num = _df.select_dtypes(include=[np.number])
386
+ if num.shape[1] < 2:
387
+ return
388
+ plt.figure(figsize=(12, 10))
389
+ sns.heatmap(num.corr(), annot=True, fmt=".2f", cmap="coolwarm")
390
+ plt.title("Correlation Between Numeric OBD-II Variables")
391
+ plt.tight_layout()
392
+ plt.savefig(os.path.join(PLOT_DIR, f"heatmap_{_id}.png"))
393
+ plt.close()
394
+ except Exception as e:
395
+ logger.error(f"Heatmap generation failed: {e}")
396
+
397
+ # Sensor trend plotter
398
+ def _plot_trend(_df: pd.DataFrame, _id: str):
399
+ try:
400
+ plt.figure(figsize=(15, 6))
401
+ for col in ['RPM', 'ENGINE_LOAD', 'ABSOLUTE_LOAD', 'COOLANT_TEMP',
402
+ 'INTAKE_TEMP', 'OIL_TEMP', 'INTAKE_PRESSURE', 'BAROMETRIC_PRESSURE',
403
+ 'CONTROL_MODULE_VOLTAGE']:
404
+ if col in _df.columns:
405
+ plt.plot(_df.index, _df[col], label=col)
406
+ plt.title("Sensor Trends (Index-Based, No Time Gaps)")
407
+ plt.xlabel("Sample Index")
408
+ plt.ylabel("Sensor Value")
409
+ plt.legend()
410
+ plt.grid(True)
411
+ plt.tight_layout()
412
+ plt.savefig(os.path.join(PLOT_DIR, f"trend_{_id}.png"))
413
+ plt.close()
414
+ except Exception as e:
415
+ logger.error(f"Trend plot failed: {e}")
416
+
417
+ # ----------------------- pipeline -----------------------
418
+ df = df.copy()
419
+ # 0) Basic tidy
420
+ df = _drop_dead_weight(df)
421
+ df = _to_dt(df)
422
+ # 1) Corruption-only normalization (no outlier trimming)
423
+ df = _normalize_corruption(df)
424
+ # 2) Light row/column filtering for extreme sparsity
425
+ df = _light_row_col_filters(df)
426
+ # 3) Auto infer base interval & insert gap rows
427
+ base_sec = _infer_base_interval_seconds(df["timestamp"])
428
+ df = _insert_time_gaps(df, base_sec)
429
+ # 4) KNN backfill all numeric sensors (time-aware)
430
+ df = _knn_impute_all(df)
431
+ # 5) Feature engineering AFTER imputation
432
+ df = _feature_engineering(df)
433
+ # 6) Final sort / index
434
+ df.sort_values("timestamp", inplace=True)
435
+ df.reset_index(drop=True, inplace=True)
436
+ # 7) Scaling after impute (kept from original)
437
+ if not df.select_dtypes(include=["number"]).empty:
438
+ df = _scale_numeric(df)
439
+ # 8) Save
440
+ out_path = os.path.join(CLEANED_DIR, f"cleaned_{norm_ts}.csv")
441
+ df.to_csv(out_path, index=False)
442
+ logger.info(f"✅ Cleaned saved: {out_path}")
443
+ # 9) UL drivestyle predictions
444
+ df_for_persist = df
445
+ labeled_path = None
446
+ try:
447
+ ul = ULLabeler.get()
448
+ preds = ul.predict_df(df)
449
+ # update global DRIVE_STYLE (overwrite if already exists)
450
+ global DRIVE_STYLE
451
+ DRIVE_STYLE = [str(p) for p in preds]
452
+ # write labeled CSV (driving_style column)
453
+ df_labeled = df.copy()
454
+ df_labeled["driving_style"] = DRIVE_STYLE
455
+ labeled_path = os.path.join(CLEANED_DIR, f"cleaned_{norm_ts}_labeled.csv")
456
+ df_labeled.to_csv(labeled_path, index=False)
457
+ df_for_persist = df_labeled
458
+ # Log labeling outcome (DRIVE_STYLE was updated above)
459
+ logger.info(f"✅ UL labels generated ({len(DRIVE_STYLE)}) → {labeled_path}")
460
+ except Exception as e:
461
+ logger.error(f"❌ UL labeling failed: {e}")
462
+ # 10) Plots
463
+ _plot_corr(df, norm_ts)
464
+ _plot_trend(df, norm_ts)
465
+ # 11) Update event
466
+ try:
467
+ PIPELINE_EVENTS[norm_ts]["status"] = "done"
468
+ except Exception:
469
+ pass
470
+ # 12) Upload to Drive
471
+ try:
472
+ if drive_saver.is_service_available():
473
+ if labeled_path and os.path.exists(labeled_path):
474
+ drive_saver.upload_csv_to_drive(labeled_path)
475
+ logger.info("✅ Uploaded labeled to Google Drive")
476
+ else:
477
+ drive_saver.upload_csv_to_drive(out_path)
478
+ logger.info("✅ Uploaded default to Google Drive")
479
+ else:
480
+ logger.warning("⚠️ Google Drive service not available")
481
+ except Exception as e:
482
+ logger.error(f"❌ Drive upload error: {e}")
483
+ # 13) Save to MongoDB
484
+ try:
485
+ if mongo_saver.is_connected():
486
+ # Save the cleaned DataFrame directly to MongoDB
487
+ session_id = f"session_{norm_ts}"
488
+ if mongo_saver.save_dataframe_to_mongo(df_for_persist, session_id):
489
+ logger.info("✅ Saved to MongoDB")
490
+ else:
491
+ logger.warning("⚠️ MongoDB save failed")
492
+ else:
493
+ logger.warning("⚠️ MongoDB not connected")
494
+ except Exception as e:
495
+ logger.error(f"❌ MongoDB save error: {e}")
496
+ # 14) Save to Firebase Storage (incremented NNN_YYYY-MM-DD_processed.csv at fixed path)
497
+ try:
498
+ if firebase_saver and firebase_saver.is_available():
499
+ # Choose the final artifact to persist
500
+ if labeled_path and os.path.exists(labeled_path):
501
+ target_path = labeled_path
502
+ else:
503
+ target_path = out_path
504
+ # Optional: use the acquisition date if norm_ts starts with YYYY-MM-DD, else let saver use AUS/Melbourne "today"
505
+ date_str = None
506
+ try:
507
+ date_str = str(norm_ts)[:10] if norm_ts and len(str(norm_ts)) >= 10 else None
508
+ except Exception:
509
+ date_str = None
510
+ # Upload with auto-incremented name: NNN_YYYY-MM-DD_processed.csv under skyledge/processed
511
+ gs_url = firebase_saver.upload_file_with_increment(target_path, date_str=date_str)
512
+ # Save to Firebase Storage (incremented NNN_YYYY-MM-DD_processed.csv at fixed path)
513
+ if gs_url:
514
+ logger.info(f"✅ Saved to Firebase Storage: {gs_url}")
515
+ else:
516
+ logger.warning("⚠️ Firebase Storage upload returned empty URL")
517
+ else:
518
+ logger.warning("⚠️ Firebase Storage not available")
519
+ except Exception as e:
520
+ logger.error(f"❌ Firebase Storage save error: {e}")
521
+
522
+
523
+
524
+ # ───────────── Health Check ──────────────────────
525
+ @app.get("/health")
526
+ def health():
527
+ return {"status": "ok"}
528
+
529
+ @app.get("/models/status")
530
+ def models_status():
531
+ """Check if models are loaded and available"""
532
+ try:
533
+ model_dir = pathlib.Path(os.getenv("MODEL_DIR", "/app/models/ul"))
534
+ required_files = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]
535
+
536
+ available_files = []
537
+ missing_files = []
538
+
539
+ for file in required_files:
540
+ file_path = model_dir / file
541
+ if file_path.exists():
542
+ available_files.append(file)
543
+ else:
544
+ missing_files.append(file)
545
+
546
+ status = "ready" if len(available_files) == len(required_files) else "loading"
547
+
548
+ return {
549
+ "status": status,
550
+ "model_directory": str(model_dir),
551
+ "available_files": available_files,
552
+ "missing_files": missing_files,
553
+ "total_files": len(required_files),
554
+ "loaded_files": len(available_files)
555
+ }
556
+ except Exception as e:
557
+ return {
558
+ "status": "error",
559
+ "error": str(e),
560
+ "timestamp": datetime.now().isoformat()
561
+ }
562
+
563
+
564
+ # ─────── Send status to frontend ─────────────────
565
+ @app.get("/events")
566
+ def get_events():
567
+ return PIPELINE_EVENTS
568
+
569
+
570
+ # ────── Delete event from dashboard ──────────────
571
+ @app.delete("/events/remove/{timestamp}")
572
+ def remove_event(timestamp: str):
573
+ if timestamp not in PIPELINE_EVENTS:
574
+ raise HTTPException(status_code=404, detail="Event not found")
575
+ del PIPELINE_EVENTS[timestamp]
+ return {"status": "deleted"}
576
+
577
+
578
+ # ───────────── Download Cleaned ──────────────────
579
+ @app.get("/download/{filename}")
580
+ def download_file(filename: str):
581
+ path = os.path.join(CLEANED_DIR, filename)
582
+ if not os.path.exists(path):
583
+ raise HTTPException(status_code=404, detail="Not found")
584
+ return FileResponse(path, media_type='text/csv', filename=filename)
585
+
586
+
587
+ # ───────────── MongoDB Operations ──────────────────
588
+ @app.get("/mongo/status")
589
+ def mongo_status():
590
+ """Check MongoDB connection status"""
591
+ return {
592
+ "connected": mongo_saver.is_connected(),
593
+ "available": MONGODB_AVAILABLE if 'MONGODB_AVAILABLE' in globals() else False
594
+ }
595
+
596
+
597
+ @app.get("/mongo/sessions")
598
+ def get_mongo_sessions():
599
+ """Get summary of all MongoDB sessions"""
600
+ if not mongo_saver.is_connected():
601
+ raise HTTPException(status_code=503, detail="MongoDB not connected")
602
+
603
+ sessions = mongo_saver.get_session_summary()
604
+ return {"sessions": sessions}
605
+
606
+
607
+ @app.get("/mongo/query")
608
+ def query_mongo_data(
609
+ session_id: str = None,
610
+ driving_style: str = None,
611
+ start_time: str = None,
612
+ end_time: str = None,
613
+ limit: int = 1000
614
+ ):
615
+ """Query data from MongoDB with filters"""
616
+ if not mongo_saver.is_connected():
617
+ raise HTTPException(status_code=503, detail="MongoDB not connected")
618
+
619
+ # Parse datetime strings if provided
620
+ start_dt = None
621
+ end_dt = None
622
+
623
+ if start_time:
624
+ try:
625
+ start_dt = pd.to_datetime(start_time)
626
+ except Exception:
627
+ raise HTTPException(status_code=400, detail="Invalid start_time format")
628
+
629
+ if end_time:
630
+ try:
631
+ end_dt = pd.to_datetime(end_time)
632
+ except Exception:
633
+ raise HTTPException(status_code=400, detail="Invalid end_time format")
634
+
635
+ results = mongo_saver.query_data(
636
+ session_id=session_id,
637
+ driving_style=driving_style,
638
+ start_time=start_dt,
639
+ end_time=end_dt,
640
+ limit=limit
641
+ )
642
+
643
+ return {"results": results, "count": len(results)}
644
+
645
+
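Querying from a client might look like this (the base URL and the `requests` dependency are assumptions):

```python
import requests

resp = requests.get("http://localhost:8000/mongo/query", params={
    "driving_style": "aggressive",
    "start_time": "2025-05-15T00:00:00",
    "limit": 100,
})
print(resp.json()["count"])
```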
646
+ @app.post("/mongo/save-csv")
647
+ async def save_csv_to_mongo_endpoint(
648
+ file: UploadFile = File(...),
649
+ session_id: str = None
650
+ ):
651
+ """Save uploaded CSV directly to MongoDB"""
652
+ if not mongo_saver.is_connected():
653
+ raise HTTPException(status_code=503, detail="MongoDB not connected")
654
+
655
+ try:
656
+ # Save uploaded file temporarily
657
+ temp_path = os.path.join(BASE_DIR, f"temp_{file.filename}")
658
+ with open(temp_path, "wb") as f:
659
+ f.write(await file.read())
660
+
661
+ # Save to MongoDB
662
+ success = mongo_saver.save_csv_to_mongo(temp_path, session_id)
663
+
664
+ # Clean up temp file
665
+ if os.path.exists(temp_path):
666
+ os.remove(temp_path)
667
+
668
+ if success:
669
+ return {"status": "success", "message": "CSV saved to MongoDB"}
670
+ else:
671
+ raise HTTPException(status_code=500, detail="Failed to save to MongoDB")
672
+
673
+ except Exception as e:
674
+ logger.error(f"CSV to MongoDB save failed: {e}")
675
+ raise HTTPException(status_code=500, detail=f"Save failed: {str(e)}")
676
+
677
+
678
+ # ───────────── RLHF Training Endpoints ─────────────
679
+
680
+ class RLHFTrainingRequest(BaseModel):
681
+ max_datasets: int = 10
682
+ force_retrain: bool = False
683
+
684
+ class RLHFTrainingResponse(BaseModel):
685
+ status: str
686
+ model_version: str | None = None
687
+ datasets_processed: int = 0
688
+ samples_processed: int = 0
689
+ performance_metrics: dict | None = None
690
+ error: str | None = None
691
+ timestamp: str | None = None
692
+
693
+ @app.post("/rlhf/train", response_model=RLHFTrainingResponse)
694
+ async def trigger_rlhf_training(
695
+ request: RLHFTrainingRequest,
696
+ background_tasks: BackgroundTasks
697
+ ):
698
+ """
699
+ Trigger RLHF (Reinforcement Learning from Human Feedback) training session.
700
+
701
+ This endpoint:
702
+ 1. Loads human-labeled data from Firebase storage (skyledge/labeled)
703
+ 2. Combines it with existing model predictions for RLHF
704
+ 3. Retrains the XGBoost model with the combined dataset
705
+ 4. Saves the new model to Hugging Face Hub
706
+ """
707
+ try:
708
+ logger.info(f"🚀 RLHF training requested with max_datasets={request.max_datasets}")
709
+
710
+ # Initialize trainer
711
+ trainer = RLHFTrainer()
712
+
713
+ # Run training
714
+ result = trainer.train(max_datasets=request.max_datasets)
715
+
716
+ if result["status"] == "success":
717
+ logger.info(f"✅ RLHF training completed: v{result['model_version']}")
718
+ return RLHFTrainingResponse(
719
+ status="success",
720
+ model_version=result["model_version"],
721
+ datasets_processed=result["datasets_processed"],
722
+ samples_processed=result["samples_processed"],
723
+ performance_metrics=result["performance_metrics"],
724
+ timestamp=datetime.datetime.now().isoformat()
725
+ )
726
+ elif result["status"] == "no_data":
727
+ logger.info("ℹ️ No new data available for RLHF training")
728
+ return RLHFTrainingResponse(
729
+ status="no_data",
730
+ timestamp=datetime.datetime.now().isoformat()
731
+ )
732
+ else:
733
+ logger.error(f"❌ RLHF training failed: {result.get('error', 'Unknown error')}")
734
+ return RLHFTrainingResponse(
735
+ status="error",
736
+ error=result.get("error", "Unknown error"),
737
+ timestamp=datetime.datetime.now().isoformat()
738
+ )
739
+
740
+ except Exception as e:
741
+ logger.error(f"❌ RLHF training endpoint failed: {e}")
742
+ raise HTTPException(
743
+ status_code=500,
744
+ detail=f"RLHF training failed: {str(e)}"
745
+ )
746
+
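Triggering a run from a client (URL assumed; the payload mirrors `RLHFTrainingRequest`):

```python
import requests

resp = requests.post(
    "http://localhost:8000/rlhf/train",
    json={"max_datasets": 5, "force_retrain": False},
)
print(resp.json()["status"])  # "success", "no_data", or "error"
```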
747
+ @app.get("/rlhf/status")
748
+ async def get_rlhf_status():
749
+ """
750
+ Get status of RLHF training system and available labeled data.
751
+ """
752
+ try:
753
+ from train import LabeledDataLoader
754
+
755
+ loader = LabeledDataLoader()
756
+ datasets = loader.list_labeled_datasets()
757
+
758
+ return {
759
+ "status": "available",
760
+ "labeled_datasets_count": len(datasets),
761
+ "datasets": [
762
+ {
763
+ "name": d["name"],
764
+ "size": d["size"],
765
+ "created": d["created"]
766
+ } for d in datasets[:10] # Limit to first 10 for response size
767
+ ],
768
+ "firebase_bucket": "skyledge-36b56.firebasestorage.app",
769
+ "labeled_path": "skyledge/labeled",
770
+ "timestamp": datetime.now().isoformat()
771
+ }
772
+
773
+ except Exception as e:
774
+ logger.error(f"❌ RLHF status check failed: {e}")
775
+ raise HTTPException(
776
+ status_code=500,
777
+ detail=f"Status check failed: {str(e)}"
778
+ )
779
+
780
+ @app.get("/rlhf/trained-datasets")
781
+ async def get_trained_datasets():
782
+ """
783
+ Get list of datasets that have already been used for training.
784
+ """
785
+ try:
786
+ from train import LabeledDataLoader
787
+
788
+ loader = LabeledDataLoader()
789
+ trained_datasets = loader._get_trained_datasets()
790
+
791
+ return {
792
+ "trained_datasets_count": len(trained_datasets),
793
+ "trained_datasets": trained_datasets,
794
+ "timestamp": datetime.now().isoformat()
795
+ }
796
+
797
+ except Exception as e:
798
+ logger.error(f"❌ Failed to get trained datasets: {e}")
799
+ raise HTTPException(
800
+ status_code=500,
801
+ detail=f"Failed to get trained datasets: {str(e)}"
802
+ )
data.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "timestamp": "2025-05-15T10:00:00",
3
+ "driving_style": "aggressive",
4
+ "data": {
5
+ "RPM": 3200,
6
+ "THROTTLE_POS": 75,
7
+ "SPEED": 110,
8
+ "FUEL_PRESSURE": 290,
9
+ "ENGINE_LOAD": 45,
10
+ "COOLANT_TEMP": 85,
11
+ "INTAKE_TEMP": 30,
12
+ "TIMING_ADVANCE": 10,
13
+ "MAF": 12.5,
14
+ "INTAKE_PRESSURE": 28,
15
+ "SHORT_FUEL_TRIM_1": 3.1,
16
+ "LONG_FUEL_TRIM_1": 6.2,
17
+ "SHORT_FUEL_TRIM_2": 2.5,
18
+ "LONG_FUEL_TRIM_2": 5.0,
19
+ "COMMANDED_EQUIV_RATIO": 1.0,
20
+ "O2_B1S2": 0.74,
21
+ "O2_B2S2": 0.68,
22
+ "O2_S1_WR_VOLTAGE": 0.85,
23
+ "COMMANDED_EGR": 10
24
+ }
25
+ }
26
+
data/drive_saver.py ADDED
@@ -0,0 +1,110 @@
1
+ # Google Drive Operations for OBD Logger
2
+ # Handles authentication and file uploads to Google Drive
3
+
4
+ import os
5
+ import json
6
+ import logging
7
+ from google.oauth2 import service_account
8
+ from googleapiclient.discovery import build
9
+ from googleapiclient.http import MediaFileUpload
10
+
11
+ # ───────────── Logging Setup ─────────────
12
+ logger = logging.getLogger("drive-saver")
13
+ logger.setLevel(logging.INFO)
14
+ fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
15
+ handler = logging.StreamHandler()
16
+ handler.setFormatter(fmt)
17
+ logger.addHandler(handler)
18
+
19
+
20
+ class DriveSaver:
21
+ """Handles Google Drive operations for saving OBD data"""
22
+
23
+ def __init__(self):
24
+ self.service = None
25
+ self.folder_id = "1r-wefqKbK9k9BeYDW1hXRbx4B-0Fvj5P" # Default folder ID
26
+ self._initialize_service()
27
+
28
+ def _initialize_service(self):
29
+ """Initialize Google Drive service with credentials"""
30
+ try:
31
+ raw = os.getenv("GDRIVE_CREDENTIALS_JSON")
+ if not raw:
+ raise RuntimeError("GDRIVE_CREDENTIALS_JSON not set")
+ creds_dict = json.loads(raw)
32
+ creds = service_account.Credentials.from_service_account_info(
33
+ creds_dict,
34
+ scopes=["https://www.googleapis.com/auth/drive"]
35
+ )
36
+ self.service = build("drive", "v3", credentials=creds)
37
+ logger.info("✅ Google Drive service initialized successfully")
38
+ except Exception as e:
39
+ logger.error(f"❌ Drive initialization failed: {e}")
40
+ self.service = None
41
+
42
+ def upload_csv_to_drive(self, file_path: str, folder_id: str = None) -> bool:
43
+ """
44
+ Upload a CSV file to Google Drive
45
+
46
+ Args:
47
+ file_path (str): Path to the CSV file to upload
48
+ folder_id (str, optional): Target folder ID. Uses default if not specified.
49
+
50
+ Returns:
51
+ bool: True if upload successful, False otherwise
52
+ """
53
+ if not self.service:
54
+ logger.error("❌ Drive service not initialized")
55
+ return False
56
+
57
+ target_folder = folder_id or self.folder_id
58
+
59
+ try:
60
+ file_name = os.path.basename(file_path)
61
+ media = MediaFileUpload(file_path, mimetype='text/csv')
62
+ metadata = {"name": file_name, "parents": [target_folder]}
63
+
64
+ result = self.service.files().create(
65
+ body=metadata,
66
+ media_body=media,
67
+ fields="id"
68
+ ).execute()
69
+
70
+ logger.info(f"✅ File uploaded to Drive successfully: {file_name} (ID: {result.get('id')})")
71
+ return True
72
+
73
+ except Exception as e:
74
+ logger.error(f"❌ Drive upload failed: {e}")
75
+ return False
76
+
77
+ def is_service_available(self) -> bool:
78
+ """Check if Drive service is available"""
79
+ return self.service is not None
80
+
81
+ def get_folder_id(self) -> str:
82
+ """Get the default folder ID"""
83
+ return self.folder_id
84
+
85
+ def set_folder_id(self, folder_id: str):
86
+ """Set a new default folder ID"""
87
+ self.folder_id = folder_id
88
+ logger.info(f"📁 Default folder ID updated to: {folder_id}")
89
+
90
+
91
+ # Convenience function for backward compatibility
92
+ def get_drive_service():
93
+ """Legacy function - returns DriveSaver instance"""
94
+ return DriveSaver()
95
+
96
+
97
+ def upload_to_folder(service, file_path, folder_id):
98
+ """Legacy function - uploads file to specified folder"""
99
+ if isinstance(service, DriveSaver):
100
+ return service.upload_csv_to_drive(file_path, folder_id)
101
+ else:
102
+ # Handle legacy service object
103
+ try:
104
+ file_name = os.path.basename(file_path)
105
+ media = MediaFileUpload(file_path, mimetype='text/csv')
106
+ metadata = {"name": file_name, "parents": [folder_id]}
107
+ return service.files().create(body=metadata, media_body=media, fields="id").execute()
108
+ except Exception as e:
109
+ logger.error(f"❌ Legacy upload failed: {e}")
110
+ return None
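A minimal usage sketch for this module (the CSV path is hypothetical; `GDRIVE_CREDENTIALS_JSON` must hold the service-account JSON):

```python
from data.drive_saver import DriveSaver

saver = DriveSaver()
if saver.is_service_available():
    # Hypothetical artifact produced by the cleaning pipeline
    saver.upload_csv_to_drive("cache/obd_data/cleaned/cleaned_2025-05-15T10-00-00.csv")
```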
data/firebase_saver.py ADDED
@@ -0,0 +1,315 @@
1
+ # firebase_saver.py
2
+ import os
3
+ import io
4
+ import re
5
+ import json
6
+ import logging
7
+ from datetime import datetime
8
+ from typing import Optional, Tuple, List
9
+
10
+ import pandas as pd
11
+
12
+ logger = logging.getLogger("firebase-saver")
13
+ logger.setLevel(logging.INFO)
14
+ if not logger.handlers:
15
+ _h = logging.StreamHandler()
16
+ _h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
17
+ logger.addHandler(_h)
18
+
19
+ # ---------- Constants (fixed as requested) ----------
20
+ FIXED_BUCKET = "skyledge-36b56.firebasestorage.app"
21
+ FIXED_PREFIX = "skyledge/processed" # no trailing slash
22
+
23
+ # Pattern: NNN_YYYY-MM-DD_processed.csv
24
+ FILENAME_RE = re.compile(r"^(?P<num>\d{3})_(?P<date>\d{4}-\d{2}-\d{2})_processed\.csv$")
25
+
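A quick sanity check of the pattern (illustrative):

```python
import re

FILENAME_RE = re.compile(r"^(?P<num>\d{3})_(?P<date>\d{4}-\d{2}-\d{2})_processed\.csv$")

m = FILENAME_RE.match("007_2025-05-15_processed.csv")
print(m.group("num"), m.group("date"))  # 007 2025-05-15
assert FILENAME_RE.match("7_2025-05-15_processed.csv") is None  # number must be zero-padded to 3
```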
26
+
27
+ def _parse_gs_uri(uri: Optional[str]):
28
+ if not uri or not uri.startswith("gs://"):
29
+ return None, None
30
+ path = uri[len("gs://"):]
31
+ parts = path.split("/", 1)
32
+ bucket = parts[0]
33
+ prefix = parts[1] if len(parts) > 1 else ""
34
+ return bucket, prefix
35
+
36
+
37
+ def _maybe_default_firebase_bucket(name: Optional[str]) -> Optional[str]:
38
+ # If user passed a project ID (no dot), convert to <project>.appspot.com
39
+ if name and "." not in name:
40
+ return f"{name}.appspot.com"
41
+ return name
42
+
43
+
44
+ # -------------------- Low-level clients --------------------
45
+
46
+ class _AdminClient:
47
+ """Firebase Admin SDK storage client."""
48
+ def __init__(self, bucket: str):
49
+ import firebase_admin
50
+ from firebase_admin import credentials, storage as fb_storage
51
+
52
+ raw = os.getenv("FIREBASE_ADMIN_JSON")
53
+ if not raw:
54
+ raise RuntimeError("FIREBASE_ADMIN_JSON not set")
55
+ info = json.loads(raw)
56
+ client_email = info.get("client_email")
57
+ cred = credentials.Certificate(info)
58
+
59
+ if not firebase_admin._apps:
60
+ firebase_admin.initialize_app(cred, {"storageBucket": bucket})
61
+
62
+ # fb_storage.bucket returns a google.cloud.storage.bucket.Bucket
63
+ self.bucket = fb_storage.bucket(bucket)
64
+ self._bucket_name = bucket
65
+ logger.info(f"✅ Firebase Admin initialized | bucket={bucket} as {client_email}")
66
+
67
+ # Uploads
68
+ def upload_from_filename(self, local_path: str, dest_path: str, content_type: str):
69
+ blob = self.bucket.blob(dest_path)
70
+ blob.cache_control = "no-store"
71
+ blob.upload_from_filename(local_path, content_type=content_type)
72
+
73
+ def upload_from_bytes(self, data: bytes, dest_path: str, content_type: str):
74
+ blob = self.bucket.blob(dest_path)
75
+ blob.cache_control = "no-store"
76
+ blob.upload_from_string(data, content_type=content_type)
77
+
78
+ # Listing (needs storage.objects.list permission)
79
+ def list_names(self, prefix: str) -> List[str]:
80
+ # Bucket.list_blobs works via the underlying GCS client
81
+ blobs = self.bucket.list_blobs(prefix=prefix)
82
+ return [b.name for b in blobs]
83
+
84
+ # Existence check (for collision-safe retry)
85
+ def blob_exists(self, path: str) -> bool:
86
+ blob = self.bucket.blob(path)
87
+ return blob.exists()
88
+
89
+
90
+ class _GCSClient:
91
+ """google-cloud-storage client."""
92
+ def __init__(self, bucket: str):
93
+ from google.cloud import storage
94
+ from google.oauth2 import service_account
95
+
96
+ raw = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON")
97
+ if not raw:
98
+ raise RuntimeError("FIREBASE_SERVICE_ACCOUNT_JSON not set")
99
+ info = json.loads(raw)
100
+ client_email = info.get("client_email")
101
+ creds = service_account.Credentials.from_service_account_info(info)
102
+ project_id = info.get("project_id")
103
+
104
+ self.client = storage.Client(credentials=creds, project=project_id)
105
+ self.bucket = self.client.bucket(bucket)
106
+ self._bucket_name = bucket
107
+ logger.info(f"✅ GCS client initialized | bucket={bucket} as {client_email}")
108
+
109
+ def upload_from_filename(self, local_path: str, dest_path: str, content_type: str):
110
+ blob = self.bucket.blob(dest_path)
111
+ blob.cache_control = "no-store"
112
+ blob.upload_from_filename(local_path, content_type=content_type)
113
+
114
+ def upload_from_bytes(self, data: bytes, dest_path: str, content_type: str):
115
+ blob = self.bucket.blob(dest_path)
116
+ blob.cache_control = "no-store"
117
+ blob.upload_from_string(data, content_type=content_type)
118
+
119
+ def list_names(self, prefix: str) -> List[str]:
120
+ blobs = self.client.list_blobs(self._bucket_name, prefix=prefix)
121
+ return [b.name for b in blobs]
122
+
123
+ def blob_exists(self, path: str) -> bool:
124
+ blob = self.bucket.blob(path)
125
+ return blob.exists(self.client)
126
+
127
+
128
+ # -------------------- Saver (high level) --------------------
129
+
130
+ class FirebaseSaver:
131
+ """
132
+ Fixed target:
133
+ Bucket: skyledge-36b56.firebasestorage.app
134
+ Prefix: skyledge/processed
135
+ Filename convention: NNN_YYYY-MM-DD_processed.csv (NNN is 001-based, zero-padded).
136
+ Auto-increments by listing current objects and picking max+1.
137
+ """
138
+
139
+ def __init__(self):
140
+ # Force fixed location regardless of env (as requested)
141
+ bucket_name = FIXED_BUCKET
142
+ self.prefix = FIXED_PREFIX
143
+
144
+ # Try Admin SDK first; fallback to GCS client
145
+ self.client = None
146
+ self.mode = None
147
+ try:
148
+ if os.getenv("FIREBASE_ADMIN_JSON"):
149
+ self.client = _AdminClient(bucket_name)
150
+ self.mode = "admin"
151
+ except Exception as e:
152
+ logger.warning(f"⚠️ Admin SDK init failed: {e}")
153
+
154
+ if self.client is None:
155
+ try:
156
+ self.client = _GCSClient(bucket_name)
157
+ self.mode = "gcs"
158
+ except Exception as e:
159
+ logger.error(f"❌ GCS client init failed: {e}")
160
+ raise
161
+
162
+ logger.info(f"📦 FirebaseSaver ready | mode={self.mode} bucket={bucket_name} prefix={self.prefix}")
163
+
164
+ def is_available(self) -> bool:
165
+ return self.client is not None
166
+
167
+ # ---------- Incremental naming helpers ----------
168
+
169
+ def _list_existing_filenames(self) -> List[str]:
170
+ """List object names under the fixed prefix, return just basenames under that folder."""
171
+ names = self.client.list_names(prefix=self.prefix + "/")
172
+ # keep only items immediately under prefix (not subfolders) & matching our filename pattern
173
+ base_names = []
174
+ for full in names:
175
+ # full looks like 'skyledge/processed/NNN_YYYY-MM-DD_processed.csv'
176
+ if not full.startswith(self.prefix + "/"):
177
+ continue
178
+ base = full[len(self.prefix) + 1:] # strip 'prefix/'
179
+ if "/" in base:
180
+ # skip nested items (none expected)
181
+ continue
182
+ if FILENAME_RE.match(base):
183
+ base_names.append(base)
184
+ return base_names
185
+
186
+ def _max_existing_id(self) -> int:
187
+ """Return max NNN found under prefix, or 0 if none."""
188
+ try:
189
+ base_names = self._list_existing_filenames()
190
+ except Exception as e:
191
+ logger.warning(f"⚠️ Unable to list existing objects; defaulting max_id=0: {e}")
192
+ return 0
193
+
194
+ max_id = 0
195
+ for name in base_names:
196
+ m = FILENAME_RE.match(name)
197
+ if not m:
198
+ continue
199
+ try:
200
+ num = int(m.group("num"))
201
+ if num > max_id:
202
+ max_id = num
203
+ except ValueError:
204
+ continue
205
+ return max_id
206
+
207
+ @staticmethod
208
+ def _format_id(n: int) -> str:
209
+ return f"{n:03d}"
210
+
211
+ @staticmethod
212
+ def _today_au() -> str:
213
+ # Use Australia/Melbourne local date; if zoneinfo unavailable, fall back to UTC date.
214
+ try:
215
+ from zoneinfo import ZoneInfo
216
+ dt = datetime.now(ZoneInfo("Australia/Melbourne"))
217
+ except Exception:
218
+ dt = datetime.utcnow()
219
+ return dt.strftime("%Y-%m-%d")
220
+
221
+ def _build_filename(self, n_int: int, date_str: Optional[str] = None) -> str:
222
+ date_val = (date_str or self._today_au())
223
+ return f"{self._format_id(n_int)}_{date_val}_processed.csv"
224
+
225
+ def _dest_path(self, filename: str) -> str:
226
+ return f"{self.prefix}/{filename}"
227
+
228
+ def _next_available_name(self, date_str: Optional[str] = None, max_retries: int = 5) -> Tuple[str, str]:
229
+ """
230
+ Compute the next file name by listing existing ones and incrementing.
231
+ Includes a collision check (exists) and retries if necessary.
232
+ Returns: (filename, full_gcs_path)
233
+ """
234
+ start = self._max_existing_id() + 1
235
+ n = start
236
+ for _ in range(max_retries):
237
+ candidate = self._build_filename(n, date_str=date_str)
238
+ dest_path = self._dest_path(candidate)
239
+ # collision check
240
+ if not self.client.blob_exists(dest_path):
241
+ return candidate, dest_path
242
+ n += 1
243
+
244
+ # As a final fallback, return the last tried (very unlikely to collide repeatedly)
245
+ candidate = self._build_filename(n, date_str=date_str)
246
+ return candidate, self._dest_path(candidate)
247
+
248
+ # ---------- Public save methods (incremental) ----------
249
+
250
+ def upload_file_with_increment(
251
+ self,
252
+ local_path: str,
253
+ date_str: Optional[str] = None,
254
+ content_type: str = "text/csv",
255
+ ) -> str:
256
+ """
257
+ Upload a local file using the next incremental name.
258
+ Returns the gs:// URL of the uploaded object (string) or "" on failure.
259
+ """
260
+ if not self.is_available():
261
+ logger.warning("⚠️ Firebase saver unavailable")
262
+ return ""
263
+ try:
264
+ filename, dest_path = self._next_available_name(date_str=date_str)
265
+ self.client.upload_from_filename(local_path, dest_path, content_type)
266
+ logger.info(f"✅ Uploaded file to gs://{FIXED_BUCKET}/{dest_path}")
267
+ return f"gs://{FIXED_BUCKET}/{dest_path}"
268
+ except Exception as e:
269
+ logger.error(f"❌ Firebase upload failed: {e}")
270
+ return ""
271
+
272
+ def upload_dataframe_with_increment(
273
+ self,
274
+ df: pd.DataFrame,
275
+ date_str: Optional[str] = None,
276
+ content_type: str = "text/csv",
277
+ ) -> str:
278
+ """
279
+ Upload a DataFrame (as CSV) using the next incremental name.
280
+ Returns the gs:// URL of the uploaded object (string) or "" on failure.
281
+ """
282
+ if not self.is_available():
283
+ logger.warning("⚠️ Firebase saver unavailable")
284
+ return ""
285
+ try:
286
+ buf = io.StringIO()
287
+ df.to_csv(buf, index=False)
288
+ data = buf.getvalue().encode("utf-8")
289
+
290
+ filename, dest_path = self._next_available_name(date_str=date_str)
291
+ self.client.upload_from_bytes(data, dest_path, content_type)
292
+ logger.info(f"✅ Uploaded DataFrame to gs://{FIXED_BUCKET}/{dest_path}")
293
+ return f"gs://{FIXED_BUCKET}/{dest_path}"
294
+ except Exception as e:
295
+ logger.error(f"❌ Firebase DF upload failed: {e}")
296
+ return ""
297
+
298
+
299
+ # ---------- Convenience free functions ----------
300
+
301
+ def save_csv_increment(csv_path: str, date_str: Optional[str] = None) -> str:
302
+ """
303
+ Upload local CSV with auto-incremented name 'NNN_YYYY-MM-DD_processed.csv'.
304
+ Returns gs:// URL or "".
305
+ """
306
+ saver = FirebaseSaver()
307
+ return saver.upload_file_with_increment(csv_path, date_str=date_str)
308
+
309
+ def save_dataframe_increment(df: pd.DataFrame, date_str: Optional[str] = None) -> str:
310
+ """
311
+ Upload DataFrame with auto-incremented name 'NNN_YYYY-MM-DD_processed.csv'.
312
+ Returns gs:// URL or "".
313
+ """
314
+ saver = FirebaseSaver()
315
+ return saver.upload_dataframe_with_increment(df, date_str=date_str)
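Example use of the helpers above (credentials via `FIREBASE_ADMIN_JSON` or `FIREBASE_SERVICE_ACCOUNT_JSON` are assumed to be configured):

```python
import pandas as pd
from data.firebase_saver import save_dataframe_increment

df = pd.DataFrame({"timestamp": ["2025-05-15T10:00:00"], "RPM": [2100]})
url = save_dataframe_increment(df, date_str="2025-05-15")
print(url)  # e.g. gs://skyledge-36b56.firebasestorage.app/skyledge/processed/004_2025-05-15_processed.csv
```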
data/mongo_saver.py ADDED
@@ -0,0 +1,362 @@
1
+ # MongoDB Operations for OBD Logger
2
+ # Handles data restructuring and saving to MongoDB Atlas
3
+
4
+ import os
5
+ import json
6
+ import logging
7
+ from datetime import datetime
8
+ from typing import Dict, List, Any, Optional
9
+ import pandas as pd
10
+ import numpy as np
11
+
12
+ # MongoDB dependencies
13
+ try:
14
+ from pymongo import MongoClient
15
+ from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
16
+ MONGODB_AVAILABLE = True
17
+ except ImportError:
18
+ MONGODB_AVAILABLE = False
19
+ print("⚠️ PyMongo not available. Install with: pip install pymongo")
20
+
21
+ # ───────────── Logging Setup ─────────────
22
+ logger = logging.getLogger("mongo-saver")
23
+ logger.setLevel(logging.INFO)
24
+ fmt = logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s")
25
+ handler = logging.StreamHandler()
26
+ handler.setFormatter(fmt)
27
+ logger.addHandler(handler)
28
+
29
+
30
+ class MongoSaver:
31
+ """Handles MongoDB operations for saving OBD data"""
32
+
33
+ def __init__(self, mongo_uri: str = None):
34
+ self.client = None
35
+ self.db = None
36
+ self.collection = None
37
+ self.mongo_uri = mongo_uri or os.getenv("MONGO_URI")
38
+ self._initialize_connection()
39
+
40
+ def _initialize_connection(self):
41
+ """Initialize MongoDB connection"""
42
+ if not MONGODB_AVAILABLE:
43
+ logger.error("❌ PyMongo not available. Cannot connect to MongoDB")
44
+ return
45
+
46
+ if not self.mongo_uri:
47
+ logger.error("❌ MongoDB URI not provided. Set MONGO_URI environment variable")
48
+ return
49
+
50
+ try:
51
+ # Connect with timeout and retry settings
52
+ self.client = MongoClient(
53
+ self.mongo_uri,
54
+ serverSelectionTimeoutMS=5000, # 5 second timeout
55
+ connectTimeoutMS=10000, # 10 second connection timeout
56
+ socketTimeoutMS=10000 # 10 second socket timeout
57
+ )
58
+
59
+ # Test connection
60
+ self.client.admin.command('ping')
61
+
62
+ # Set up database and collection
63
+ self.db = self.client.obd_logger
64
+ self.collection = self.db.obd_data
65
+
66
+ # Create indexes for better performance
67
+ self._create_indexes()
68
+
69
+ logger.info("✅ MongoDB connection established successfully")
70
+
71
+ except (ConnectionFailure, ServerSelectionTimeoutError) as e:
72
+ logger.error(f"❌ MongoDB connection failed: {e}")
73
+ self.client = None
74
+ self.db = None
75
+ self.collection = None
76
+ except Exception as e:
77
+ logger.error(f"❌ MongoDB initialization error: {e}")
78
+ self.client = None
79
+ self.db = None
80
+ self.collection = None
81
+
82
+ def _create_indexes(self):
83
+ """Create database indexes for better query performance"""
84
+ try:
85
+ # Index on timestamp for time-based queries
86
+ self.collection.create_index("timestamp")
87
+
88
+ # Index on driving_style for filtering
89
+ self.collection.create_index("driving_style")
90
+
91
+ # Compound index for common queries
92
+ self.collection.create_index([("timestamp", -1), ("driving_style", 1)])
93
+
94
+ # Index on session_id for session-based queries
95
+ self.collection.create_index("session_id")
96
+
97
+ logger.info("✅ Database indexes created successfully")
98
+
99
+ except Exception as e:
100
+ logger.warning(f"⚠️ Index creation failed: {e}")
101
+
102
+ def is_connected(self) -> bool:
103
+ """Check if MongoDB connection is active"""
104
+ if not self.client:
105
+ return False
106
+
107
+ try:
108
+ # Ping the database
109
+ self.client.admin.command('ping')
110
+ return True
111
+ except Exception:
112
+ return False
113
+
114
+ def save_csv_to_mongo(self, csv_file_path: str, session_id: str = None) -> bool:
115
+ """
116
+ Read CSV file and save data to MongoDB
117
+
118
+ Args:
119
+ csv_file_path (str): Path to the CSV file
120
+ session_id (str, optional): Unique identifier for this data session
121
+
122
+ Returns:
123
+ bool: True if save successful, False otherwise
124
+ """
125
+ if not self.is_connected():
126
+ logger.error("❌ MongoDB not connected")
127
+ return False
128
+
129
+ try:
130
+ # Read CSV file
131
+ df = pd.read_csv(csv_file_path)
132
+
133
+ if df.empty:
134
+ logger.warning("⚠️ CSV file is empty")
135
+ return False
136
+
137
+ # Generate session ID if not provided
138
+ if not session_id:
139
+ session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
140
+
141
+ # Convert DataFrame to MongoDB documents
142
+ documents = self._dataframe_to_documents(df, session_id)
143
+
144
+ # Insert documents into MongoDB
145
+ if documents:
146
+ result = self.collection.insert_many(documents)
147
+ logger.info(f"✅ Saved {len(result.inserted_ids)} records to MongoDB (Session: {session_id})")
148
+ return True
149
+ else:
150
+ logger.warning("⚠️ No valid documents to save")
151
+ return False
152
+
153
+ except Exception as e:
154
+ logger.error(f"❌ Failed to save CSV to MongoDB: {e}")
155
+ return False
156
+
157
+ def save_dataframe_to_mongo(self, df: pd.DataFrame, session_id: str = None) -> bool:
158
+ """
159
+ Save pandas DataFrame directly to MongoDB
160
+
161
+ Args:
162
+ df (pd.DataFrame): DataFrame to save
163
+ session_id (str, optional): Unique identifier for this data session
164
+
165
+ Returns:
166
+ bool: True if save successful, False otherwise
167
+ """
168
+ if not self.is_connected():
169
+ logger.error("❌ MongoDB not connected")
170
+ return False
171
+
172
+ try:
173
+ if df.empty:
174
+ logger.warning("⚠️ DataFrame is empty")
175
+ return False
176
+
177
+ # Generate session ID if not provided
178
+ if not session_id:
179
+ session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
180
+
181
+ # Convert DataFrame to MongoDB documents
182
+ documents = self._dataframe_to_documents(df, session_id)
183
+
184
+ # Insert documents into MongoDB
185
+ if documents:
186
+ result = self.collection.insert_many(documents)
187
+ logger.info(f"✅ Saved {len(result.inserted_ids)} records to MongoDB (Session: {session_id})")
188
+ return True
189
+ else:
190
+ logger.warning("⚠️ No valid documents to save")
191
+ return False
192
+
193
+ except Exception as e:
194
+ logger.error(f"❌ Failed to save DataFrame to MongoDB: {e}")
195
+ return False
196
+
197
+ def _dataframe_to_documents(self, df: pd.DataFrame, session_id: str) -> List[Dict[str, Any]]:
198
+ """
199
+ Convert pandas DataFrame to MongoDB documents
200
+
201
+ Args:
202
+ df (pd.DataFrame): Input DataFrame
203
+ session_id (str): Session identifier
204
+
205
+ Returns:
206
+ List[Dict[str, Any]]: List of MongoDB documents
207
+ """
208
+ documents = []
209
+
210
+ for index, row in df.iterrows():
211
+ try:
212
+ # Convert row to dictionary
213
+ doc = row.to_dict()
214
+
215
+ # Add metadata
216
+ doc['session_id'] = session_id
217
+ doc['imported_at'] = datetime.utcnow()
218
+ doc['record_index'] = index
219
+
220
+ # Handle timestamp conversion
221
+ if 'timestamp' in doc and pd.notna(doc['timestamp']):
222
+ try:
223
+ # Try to parse timestamp
224
+ if isinstance(doc['timestamp'], str):
225
+ doc['timestamp'] = pd.to_datetime(doc['timestamp'])
226
+ # Convert to datetime object
227
+ doc['timestamp'] = doc['timestamp'].to_pydatetime()
228
+ except Exception:
229
+ # Keep as string if parsing fails
230
+ pass
231
+
232
+ # Convert numeric types and handle NaN values
233
+ for key, value in doc.items():
234
+ if pd.isna(value):
235
+ doc[key] = None
236
+ elif isinstance(value, (np.integer, np.floating)):
237
+ doc[key] = value.item()
238
+ elif isinstance(value, np.bool_):
239
+ doc[key] = bool(value)
240
+
241
+ documents.append(doc)
242
+
243
+ except Exception as e:
244
+ logger.warning(f"⚠️ Failed to process row {index}: {e}")
245
+ continue
246
+
247
+ return documents
248
+
249
+ def query_data(self,
250
+ session_id: str = None,
251
+ driving_style: str = None,
252
+ start_time: datetime = None,
253
+ end_time: datetime = None,
254
+ limit: int = 1000) -> List[Dict[str, Any]]:
255
+ """
256
+ Query data from MongoDB
257
+
258
+ Args:
259
+ session_id (str, optional): Filter by session ID
260
+ driving_style (str, optional): Filter by driving style
261
+ start_time (datetime, optional): Start time filter
262
+ end_time (datetime, optional): End time filter
263
+ limit (int): Maximum number of records to return
264
+
265
+ Returns:
266
+ List[Dict[str, Any]]: Query results
267
+ """
268
+ if not self.is_connected():
269
+ logger.error("❌ MongoDB not connected")
270
+ return []
271
+
272
+ try:
273
+ # Build query filter
274
+ query_filter = {}
275
+
276
+ if session_id:
277
+ query_filter['session_id'] = session_id
278
+
279
+ if driving_style:
280
+ query_filter['driving_style'] = driving_style
281
+
282
+ if start_time or end_time:
283
+ time_filter = {}
284
+ if start_time:
285
+ time_filter['$gte'] = start_time
286
+ if end_time:
287
+ time_filter['$lte'] = end_time
288
+ query_filter['timestamp'] = time_filter
289
+
290
+ # Execute query
291
+ # Exclude Mongo's ObjectId so results stay JSON-serializable for the API layer
+ cursor = self.collection.find(query_filter, {"_id": 0}).limit(limit)
292
+ results = list(cursor)
293
+
294
+ logger.info(f"✅ Query returned {len(results)} records")
295
+ return results
296
+
297
+ except Exception as e:
298
+ logger.error(f"❌ Query failed: {e}")
299
+ return []
300
+
301
+ def get_session_summary(self) -> List[Dict[str, Any]]:
302
+ """
303
+ Get summary of all data sessions
304
+
305
+ Returns:
306
+ List[Dict[str, Any]]: Session summaries
307
+ """
308
+ if not self.is_connected():
309
+ logger.error("❌ MongoDB not connected")
310
+ return []
311
+
312
+ try:
313
+ pipeline = [
314
+ {
315
+ '$group': {
316
+ '_id': '$session_id',
317
+ 'count': {'$sum': 1},
318
+ 'driving_styles': {'$addToSet': '$driving_style'},
319
+ 'first_record': {'$min': '$timestamp'},
320
+ 'last_record': {'$max': '$timestamp'},
321
+ 'imported_at': {'$first': '$imported_at'}
322
+ }
323
+ },
324
+ {
325
+ '$sort': {'imported_at': -1}
326
+ }
327
+ ]
328
+
329
+ results = list(self.collection.aggregate(pipeline))
330
+ logger.info(f"✅ Retrieved summary for {len(results)} sessions")
331
+ return results
332
+
333
+ except Exception as e:
334
+ logger.error(f"❌ Session summary failed: {e}")
335
+ return []
336
+
337
+ def close_connection(self):
338
+ """Close MongoDB connection"""
339
+ if self.client:
340
+ self.client.close()
341
+ logger.info("✅ MongoDB connection closed")
342
+
343
+ def __enter__(self):
344
+ """Context manager entry"""
345
+ return self
346
+
347
+ def __exit__(self, exc_type, exc_val, exc_tb):
348
+ """Context manager exit"""
349
+ self.close_connection()
350
+
351
+
352
+ # Convenience functions
353
+ def save_csv_to_mongo(csv_file_path: str, session_id: str = None) -> bool:
354
+ """Convenience function to save CSV to MongoDB"""
355
+ with MongoSaver() as saver:
356
+ return saver.save_csv_to_mongo(csv_file_path, session_id)
357
+
358
+
359
+ def save_dataframe_to_mongo(df: pd.DataFrame, session_id: str = None) -> bool:
360
+ """Convenience function to save DataFrame to MongoDB"""
361
+ with MongoSaver() as saver:
362
+ return saver.save_dataframe_to_mongo(df, session_id)
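A usage sketch for the convenience helpers, assuming `MONGO_URI` points at a reachable Atlas cluster:

```python
import pandas as pd
from data.mongo_saver import save_dataframe_to_mongo

df = pd.DataFrame({"timestamp": ["2025-05-15T10:00:00"],
                   "driving_style": ["normal"], "RPM": [2100]})
ok = save_dataframe_to_mongo(df, session_id="session_demo")
print("saved" if ok else "failed")
```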
organization.py ADDED
@@ -0,0 +1,76 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to reorganize existing models in HF repo to versioned structure.
4
+ This will move the current 3 .pkl files from root to v1.0 folder.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import tempfile
10
+ import json
11
+ from pathlib import Path
12
+
13
+ # Load environment variables from .env file
14
+ def load_env():
15
+ """Load environment variables from .env file"""
16
+ env_path = Path(__file__).parent / '.env'
17
+ if env_path.exists():
18
+ with open(env_path, 'r') as f:
19
+ for line in f:
20
+ line = line.strip()
21
+ if line and not line.startswith('#') and '=' in line:
22
+ key, value = line.split('=', 1)
23
+ os.environ[key] = value
24
+ print(f"✅ Loaded environment variables from {env_path}")
25
+ else:
26
+ print("⚠️ No .env file found")
27
+
28
+ # Load environment variables
29
+ load_env()
30
+
31
+ # Add train directory to path
32
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'train'))
33
+
34
+ def main():
35
+ """Main function to reorganize models"""
36
+ print("🔄 Reorganizing models in Hugging Face repository...")
37
+ print("=" * 60)
38
+
39
+ # Check if HF_TOKEN is set
40
+ if not os.getenv("HF_TOKEN"):
41
+ print("❌ Error: HF_TOKEN environment variable not set")
42
+ print("Please set your Hugging Face token:")
43
+ print("export HF_TOKEN=your_token_here")
44
+ return 1
45
+
46
+ # Check if we're in the right directory
47
+ if not os.path.exists("train/rlhf.py"):
48
+ print("❌ Error: Please run this script from the OBD_Logger root directory")
49
+ return 1
50
+
51
+ try:
52
+ # Import and run the reorganization
53
+ from train.move_models_to_v1 import move_models_to_v1
54
+
55
+ print("📥 Starting model reorganization...")
56
+ move_models_to_v1()
57
+
58
+ print("\n✅ Model reorganization completed successfully!")
59
+ print("📁 Your models are now organized in the v1.0 folder")
60
+ print("🔄 Future RLHF training will create v1.1, v1.2, etc.")
61
+ print("\nNext steps:")
62
+ print("1. Verify the models are in the v1.0 folder on Hugging Face")
63
+ print("2. Test the RLHF training with: curl -X POST 'http://localhost:8000/rlhf/train'")
64
+
65
+ return 0
66
+
67
+ except Exception as e:
68
+ print(f"❌ Reorganization failed: {e}")
69
+ print("\nTroubleshooting:")
70
+ print("1. Make sure HF_TOKEN is set correctly")
71
+ print("2. Check that you have write access to the repository")
72
+ print("3. Verify the repository name is correct")
73
+ return 1
74
+
75
+ if __name__ == "__main__":
76
+ sys.exit(main())
organze.py ADDED
@@ -0,0 +1,184 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple script to reorganize existing models in HF repo to versioned structure.
4
+ This will move the current 3 .pkl files from root to v1.0 folder.
5
+ """
6
+
7
+ import os
8
+ import tempfile
9
+ import json
10
+ from pathlib import Path
11
+ from huggingface_hub import HfApi, hf_hub_download, upload_file
12
+
13
+ def load_env():
14
+ """Load environment variables from .env file"""
15
+ env_path = Path(__file__).parent / '.env'
16
+ if env_path.exists():
17
+ with open(env_path, 'r') as f:
18
+ for line in f:
19
+ line = line.strip()
20
+ if line and not line.startswith('#') and '=' in line:
21
+ key, value = line.split('=', 1)
22
+ os.environ[key] = value
23
+ print(f"✅ Loaded environment variables from {env_path}")
24
+ else:
25
+ print("⚠️ No .env file found")
26
+
27
+ def main():
28
+ """Main function to reorganize models"""
29
+ print("🔄 Reorganizing models in Hugging Face repository...")
30
+ print("=" * 60)
31
+
32
+ # Load environment variables
33
+ load_env()
34
+
35
+ # Check if HF_TOKEN is set
36
+ hf_token = os.getenv("HF_TOKEN")
37
+ if not hf_token:
38
+ print("❌ Error: HF_TOKEN not found in environment")
39
+ return 1
40
+
41
+ print(f"✅ HF_TOKEN loaded: {hf_token[:10]}...")
42
+
43
+ # Configuration
44
+ repo_id = "BinKhoaLe1812/Driver_Behavior_OBD"
45
+ model_files = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]
46
+
47
+ print(f"📦 Target repository: {repo_id}")
48
+ print(f"📁 Model files to move: {model_files}")
49
+
50
+ # Initialize HF API
51
+ hf_api = HfApi(token=hf_token)
52
+
53
+ try:
54
+ # Create temporary directory
55
+ with tempfile.TemporaryDirectory() as temp_dir:
56
+ temp_path = Path(temp_dir)
57
+ print(f"📁 Using temporary directory: {temp_path}")
58
+
59
+ # Download existing model files
60
+ downloaded_files = []
61
+ for file in model_files:
62
+ try:
63
+ print(f"📥 Downloading {file}...")
64
+ local_path = hf_hub_download(
65
+ repo_id=repo_id,
66
+ filename=file,
67
+ repo_type="model",
68
+ token=hf_token
69
+ )
70
+ downloaded_files.append((file, local_path))
71
+ print(f"✅ Downloaded {file}")
72
+ except Exception as e:
73
+ print(f"⚠️ Could not download {file}: {e}")
74
+
75
+ if not downloaded_files:
76
+ print("⚠️ No model files found to move")
77
+ return 1
78
+
79
+ # Create v1.0 directory structure
80
+ v1_dir = temp_path / "v1.0"
81
+ v1_dir.mkdir(exist_ok=True)
82
+ print(f"📁 Created v1.0 directory: {v1_dir}")
83
+
84
+ # Copy files to v1.0 directory
85
+ import shutil
86
+ for filename, local_path in downloaded_files:
87
+ dest_path = v1_dir / filename
88
+ shutil.copy2(local_path, dest_path)
89
+ print(f"📦 Prepared {filename} for v1.0/")
90
+
91
+ # Create metadata.json for v1.0
92
+ metadata = {
93
+ "version": "1.0",
94
+ "model_type": "xgboost_classifier",
95
+ "created_at": "2024-12-01T00:00:00",
96
+ "description": "Initial model version - moved from root directory",
97
+ "framework": "xgboost",
98
+ "task": "driver_behavior_classification",
99
+ "labels": ["aggressive", "normal", "conservative"],
100
+ "features": "obd_sensor_data",
101
+ "files": [f[0] for f in downloaded_files]
102
+ }
103
+
104
+ metadata_path = v1_dir / "metadata.json"
105
+ with open(metadata_path, 'w') as f:
106
+ json.dump(metadata, f, indent=2)
107
+ print("📝 Created metadata.json for v1.0")
108
+
109
+ # Create README.md for v1.0
110
+ readme_content = """---
111
+ license: mit
112
+ tags:
113
+ - driver-behavior
114
+ - obd-data
115
+ - xgboost
116
+ - version-1.0
117
+ ---
118
+
119
+ # Driver Behavior Classification Model v1.0
120
+
121
+ Initial version of the driver behavior classification model.
122
+
123
+ ## Files
124
+
125
+ - `xgb_drivestyle_ul.pkl`: Main XGBoost model
126
+ - `label_encoder_ul.pkl`: Label encoder for behavior categories
127
+ - `scaler_ul.pkl`: Feature scaler
128
+ - `metadata.json`: Model metadata
129
+
130
+ ## Usage
131
+
132
+ ```python
133
+ import joblib
134
+
135
+ # Load the model
136
+ model = joblib.load('xgb_drivestyle_ul.pkl')
137
+ label_encoder = joblib.load('label_encoder_ul.pkl')
138
+ scaler = joblib.load('scaler_ul.pkl')
139
+
140
+ # Make predictions
141
+ predictions = model.predict(scaled_data)
142
+ behavior_labels = label_encoder.inverse_transform(predictions)
143
+ ```
144
+ """
145
+
146
+ readme_path = v1_dir / "README.md"
147
+ with open(readme_path, 'w') as f:
148
+ f.write(readme_content)
149
+ print("📖 Created README.md for v1.0")
150
+
151
+ # Upload files to v1.0 directory in HF repo
152
+ print("🚀 Uploading files to Hugging Face Hub...")
153
+ for file_path in v1_dir.iterdir():
154
+ if file_path.is_file():
155
+ hf_filename = f"v1.0/{file_path.name}"
156
+ print(f"📤 Uploading {file_path.name} to {hf_filename}...")
157
+ upload_file(
158
+ path_or_fileobj=str(file_path),
159
+ path_in_repo=hf_filename,
160
+ repo_id=repo_id,
161
+ repo_type="model",
162
+ token=hf_token,
163
+ commit_message=f"Add {file_path.name} to v1.0 directory"
164
+ )
165
+ print(f"✅ Uploaded {file_path.name} to v1.0/")
166
+
167
+ print("\n✅ Successfully moved models to v1.0 structure!")
168
+ print(f"📁 Models now located at: {repo_id}/v1.0/")
169
+ print("\nNext steps:")
170
+ print("1. Verify the models are in the v1.0 folder on Hugging Face")
171
+ print("2. Test the RLHF training with: curl -X POST 'http://localhost:8000/rlhf/train'")
172
+
173
+ return 0
174
+
175
+ except Exception as e:
176
+ print(f"❌ Reorganization failed: {e}")
177
+ print("\nTroubleshooting:")
178
+ print("1. Make sure HF_TOKEN is set correctly")
179
+ print("2. Check that you have write access to the repository")
180
+ print("3. Verify the repository name is correct")
181
+ return 1
182
+
183
+ if __name__ == "__main__":
184
+ exit(main())
requirements.txt ADDED
@@ -0,0 +1,37 @@
1
+ # Server
2
+ fastapi
3
+ uvicorn[standard]
4
+ python-multipart
5
+ jinja2
6
+
7
+ # Data
8
+ pandas
9
+ numpy
10
+ scikit-learn
11
+
12
+ # ML Models
13
+ xgboost
14
+ joblib
15
+
16
+ # Drive
17
+ gspread
18
+ oauth2client
19
+ google-auth
20
+ google-auth-httplib2
21
+ google-auth-oauthlib
22
+ google-api-python-client
23
+
24
+ # Database
25
+ pymongo
26
+ google-cloud-storage
27
+ firebase-admin
28
+
29
+ # Visualize
30
+ matplotlib
31
+ seaborn
32
+
33
+ # HuggingFace
34
+ huggingface_hub==0.25.2
35
+
36
+ # Additional dependencies for RLHF training
37
+ pyarrow # For parquet file support
static/check.png ADDED
static/edit.png ADDED
static/icon.png ADDED
static/index.html ADDED
@@ -0,0 +1,16 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>OBD-II Data Dashboard</title>
7
+ <link rel="icon" type="image/png" href="/static/icon.png">
8
+ <link rel="stylesheet" href="/static/styles.css">
9
+ </head>
10
+ <body>
11
+ <h1>OBD-II Data Pipeline Monitor</h1>
12
+ <div id="log-container"></div>
13
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
14
+ <script src="/static/script.js?v=2"></script>
15
+ </body>
16
+ </html>
static/script.js ADDED
@@ -0,0 +1,230 @@
1
+ const expandedItems = JSON.parse(localStorage.getItem("expandedItems") || "{}");
2
+ const renamedLabels = JSON.parse(localStorage.getItem("renamedLabels") || "{}"); // Allow card to change their name (original identified by ts)
3
+ let previousKeys = [];
4
+ let previousEvents = {}; // Track event status to avoid redundant updates
5
+
6
+ // ─────────────────────────────────────────
7
+ // Refresh event per interval
8
+ // ─────────────────────────────────────────
9
+ async function fetchEvents() {
10
+ const res = await fetch('/events');
11
+ const data = await res.json();
12
+ renderEvents(data);
13
+ }
14
+
15
+ // ─────────────────────────────────────────
16
+ // Update or Create new card
17
+ // ─────────────────────────────────────────
18
+ function renderEvents(events) {
19
+ const container = document.getElementById('log-container');
20
+ const currentKeys = Object.keys(events).sort();
21
+ const newlyAdded = currentKeys.find(k => !previousKeys.includes(k));
22
+ previousKeys = currentKeys;
23
+
24
+ currentKeys.forEach(key => {
25
+ const event = events[key];
26
+ const existing = document.getElementById(`card-${key}`);
27
+ const prevStatus = previousEvents[key]?.status;
28
+
29
+ if (!existing) {
30
+ const card = createCard(key, event);
31
+ container.appendChild(card);
32
+ if (key === newlyAdded && event.status === 'done') {
33
+ setTimeout(() => card.scrollIntoView({ behavior: 'smooth', block: 'center' }), 300);
34
+ }
35
+ } else if (event.status !== prevStatus) {
36
+ updateCard(key, event); // Only update if status changed
37
+ }
38
+
39
+ previousEvents[key] = { status: event.status }; // Cache latest status
40
+ });
41
+ }
42
+
43
+ // ─────────────────────────────────────────
44
+ // Create new card on unmatched key
45
+ // ─────────────────────────────────────────
46
+ function createCard(key, event) {
47
+ const readable = renamedLabels[key] || formatTimestamp(key);
48
+ const safeKey = key.replace(/[:.]/g, "-");
49
+ const card = document.createElement('div');
50
+ card.id = `card-${key}`;
51
+ card.className = 'card';
52
+
53
+ const removeBtn = document.createElement('button');
54
+ removeBtn.className = 'btn-remove';
55
+ removeBtn.textContent = 'X';
56
+ removeBtn.onclick = () => removeItem(key);
57
+
58
+ const tsDiv = document.createElement('div');
59
+ tsDiv.className = 'timestamp';
60
+ tsDiv.innerHTML = `<span class="label-text">${readable}</span>`;
61
+
62
+ const editIcon = document.createElement('img');
63
+ editIcon.src = '/static/edit.png';
64
+ editIcon.className = 'icon-edit';
65
+ editIcon.onclick = () => toggleEditMode(tsDiv, key);
66
+ tsDiv.appendChild(editIcon);
67
+
68
+
69
+ const statusDiv = document.createElement('div');
70
+ statusDiv.className = 'status';
71
+
72
+ const actionDiv = document.createElement('div');
73
+ actionDiv.className = 'actions';
74
+
75
+ card.appendChild(removeBtn);
76
+ card.appendChild(tsDiv);
77
+ card.appendChild(statusDiv);
78
+ card.appendChild(actionDiv);
79
+
80
+ updateCardContent(card, key, event);
81
+ return card;
82
+ }
83
+
84
+ // ─────────────────────────────────────────
85
+ // Validate existing card
86
+ // ─────────────────────────────────────────
87
+ function updateCard(key, event) {
88
+ const card = document.getElementById(`card-${key}`);
89
+ if (card) {
90
+ updateCardContent(card, key, event);
91
+ }
92
+ }
93
+
94
+ // ─────────────────────────────────────────
95
+ // Update existing card content
96
+ // ─────────────────────────────────────────
97
+ function updateCardContent(card, key, event) {
98
+ const statusDiv = card.querySelector('.status');
99
+ const actionDiv = card.querySelector('.actions');
100
+ const safeKey = key.replace(/[:.]/g, "-");
101
+
102
+ actionDiv.innerHTML = '';
103
+ if (event.status === 'started') {
104
+ statusDiv.textContent = "Received signal. Data logging started.";
105
+ card.style.backgroundColor = '#780606';
106
+ } else if (event.status === 'processed') {
107
+ statusDiv.textContent = "Data logging finished. Start cleaning process.";
108
+ card.style.backgroundColor = '#2e6930';
109
+ } else if (event.status === 'done') {
110
+ statusDiv.textContent = "Cleaned data saved. Insights are ready.";
111
+ card.style.backgroundColor = '#8a00c2';
112
+
113
+ const expandBtn = document.createElement('button');
114
+ expandBtn.className = 'btn-expand';
115
+ expandBtn.textContent = expandedItems[key] ? 'Collapse' : 'Expand';
116
+ expandBtn.onclick = () => toggleExpand(key, expandBtn);
117
+
118
+ const expandDiv = document.createElement('div');
119
+ expandDiv.id = `expand-${key}`;
120
+ expandDiv.className = 'expanded-content';
121
+ if (expandedItems[key]) expandDiv.classList.add('show');
122
+
123
+ expandDiv.innerHTML = `
124
+ <img src="/plots/heatmap_${safeKey}.png" width="100%">
125
+ <img src="/plots/trend_${safeKey}.png" width="100%">
126
+ `;
127
+
128
+ actionDiv.appendChild(expandBtn);
129
+ actionDiv.appendChild(expandDiv);
130
+ }
131
+ }
132
+
133
+ // ─────────────────────────────────────────
134
+ // Toggle card expansion
135
+ // ─────────────────────────────────────────
136
+ function toggleExpand(key, btn) {
137
+ const el = document.getElementById(`expand-${key}`);
138
+ const showing = el.classList.contains('show');
139
+ if (showing) {
140
+ el.classList.remove('show');
141
+ expandedItems[key] = false;
142
+ btn.textContent = 'Expand';
143
+ } else {
144
+ el.classList.add('show');
145
+ expandedItems[key] = true;
146
+ btn.textContent = 'Collapse';
147
+ }
148
+ localStorage.setItem("expandedItems", JSON.stringify(expandedItems));
149
+ }
150
+
151
+ // ─────────────────────────────────────────
152
+ // Toggle card edit-view mode
153
+ // ─────────────────────────────────────────
154
+ function toggleEditMode(container, key) {
155
+ const icon = container.querySelector('.icon-edit');
156
+ if (!container.classList.contains('editing')) {
157
+ const span = container.querySelector('.label-text');
158
+ if (!span) return;
159
+ const input = document.createElement('input');
160
+ input.type = 'text';
161
+ input.value = span.textContent;
162
+ input.className = 'label-input';
163
+ span.replaceWith(input);
164
+ icon.src = '/static/check.png';
165
+ container.classList.add('editing');
166
+ } else {
167
+ const input = container.querySelector('.label-input');
168
+ if (!input) return;
169
+ const newLabel = input.value.trim() || formatTimestamp(key);
170
+ renamedLabels[key] = newLabel;
171
+ localStorage.setItem("renamedLabels", JSON.stringify(renamedLabels));
172
+ const newSpan = document.createElement('span');
173
+ newSpan.className = 'label-text';
174
+ newSpan.textContent = newLabel;
175
+ input.replaceWith(newSpan);
176
+ icon.src = '/static/edit.png';
177
+ container.classList.remove('editing');
178
+ }
179
+ }
180
+
181
+ // ─────────────────────────────────────────
182
+ // Remove a card item
183
+ // ─────────────────────────────────────────
184
+ function removeItem(key) {
185
+ const card = document.getElementById(`card-${key}`);
186
+ if (card) card.remove();
187
+ delete expandedItems[key];
188
+ delete previousEvents[key];
189
+ localStorage.setItem("expandedItems", JSON.stringify(expandedItems));
190
+ fetch(`/events/remove/${key}`, { method: 'DELETE' });
191
+ }
192
+
193
+ // ─────────────────────────────────────────
194
+ // Format timestamp as hh:mm dd/mm/yyyy
195
+ // ─────────────────────────────────────────
196
+ function formatTimestamp(norm_ts) {
197
+ try {
198
+ const parts = norm_ts.split("T");
199
+ if (parts.length !== 2) throw new Error("Invalid format");
200
+ // Extract date and time parts
201
+ const datePart = parts[0]; // e.g., "2025-05-21"
202
+ const timeParts = parts[1].split("-"); // ["hh", "mm", "ss"]
203
+ if (timeParts.length < 3) throw new Error("Incomplete time");
204
+ // Reformat
205
+ const [year, month, day] = datePart.split("-").map(Number);
206
+ let [hour, minute, second] = timeParts.map(Number);
207
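+ // Shifts the parsed hour back by 2, apparently to offset the timestamp's source timezone (assumption; verify against how filenames are stamped)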
+ hour = (hour - 2 + 24) % 24;
208
+ // Create Date in local time (note: month is 0-based)
209
+ const dt = new Date(year, month - 1, day, hour, minute, second);
210
+ // Write string
211
+ const timeStr = dt.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
212
+ const dateStr = dt.toLocaleDateString('en-AU');
213
+ return `${timeStr} ${dateStr}`;
214
+ } catch (err) {
215
+ console.warn("formatTimestamp fallback:", err.message);
216
+ return norm_ts;
217
+ }
218
+ }
219
+
220
+
221
+ // ─────────────────────────────────────────
222
+ // Sanitize filenames from timestamp
223
+ // ─────────────────────────────────────────
224
+ function sanitizeFilename(ts) {
225
+ return ts.replace(/:/g, '-').replace(/ /g, 'T').replace(/\//g, '-');
226
+ }
227
+
228
+ // ─────────────────────────────────────────
229
+ fetchEvents();
230
+ setInterval(fetchEvents, 1000);
static/styles.css ADDED
@@ -0,0 +1,135 @@
1
+ body {
2
+ font-family: 'Segoe UI', sans-serif;
3
+ background: linear-gradient(to bottom right, #eef1f7, #f9fafe);
4
+ margin: 0;
5
+ padding: 2rem;
6
+ color: #333;
7
+ }
8
+
9
+ h1 {
10
+ text-align: center;
11
+ margin-bottom: 2rem;
12
+ font-size: 2rem;
13
+ color: #2c3e50;
14
+ }
15
+
16
+ #log-container {
17
+ display: flex;
18
+ flex-direction: column;
19
+ gap: 1.5rem;
20
+ max-width: 960px;
21
+ margin: auto;
22
+ }
23
+
24
+ /* Card display */
25
+ .card {
26
+ border-radius: 10px;
27
+ padding: 1.2rem 1.5rem;
28
+ color: white;
29
+ position: relative;
30
+ box-shadow: 0 4px 10px rgba(0, 0, 0, 0.08);
31
+ transition: transform 0.3s ease, background-color 0.3s ease;
32
+ overflow: hidden;
33
+ }
34
+ .card:hover {
35
+ transform: translateY(-3px);
36
+ }
37
+ .status {
38
+ font-weight: 600;
39
+ font-size: 1.1rem;
40
+ }
41
+ .timestamp {
42
+ font-size: 0.95rem;
43
+ opacity: 0.9;
44
+ margin-top: 4px;
45
+ display: flex;
46
+ align-items: center;
47
+ gap: 8px;
48
+ }
49
+
50
+ .icon-edit {
51
+ width: 18px;
52
+ height: 18px;
53
+ cursor: pointer;
54
+ margin-left: 4px;
55
+ }
56
+ .label-input {
57
+ font-size: 1rem;
58
+ padding: 2px 6px;
59
+ border-radius: 4px;
60
+ border: 1px solid #ccc;
61
+ width: 160px;
62
+ }
63
+
64
+
65
+ /* All buttons */
66
+ .btn-expand,
67
+ .btn-remove {
68
+ margin-top: 1rem;
69
+ padding: 0.4rem 1.2rem;
70
+ cursor: pointer;
71
+ font-size: 0.9rem;
72
+ border: none;
73
+ border-radius: 4px;
74
+ transition: background-color 0.2s ease;
75
+ }
76
+ .btn-expand {
77
+ background-color: rgba(255, 255, 255, 0.25);
78
+ color: white;
79
+ }
80
+ .btn-expand:hover {
81
+ background-color: rgba(255, 255, 255, 0.4);
82
+ }
83
+ .btn-remove {
84
+ position: absolute;
85
+ top: 10px;
86
+ right: 14px;
87
+ background: rgba(255, 255, 255, 0.15);
88
+ color: white;
89
+ }
90
+ .btn-remove:hover {
91
+ background: rgba(255, 255, 255, 0.3);
92
+ }
93
+
94
+ /* Expanded content */
95
+ .expanded-content {
96
+ margin-top: 1.2rem;
97
+ animation: fadeIn 0.3s ease-in-out;
98
+ max-height: 0; /* You can adjust this limit */
99
+ overflow-y: auto; /* Allow vertical scroll */
100
+ transition: max-height 0.4s ease-in-out, opacity 0.3s ease;
101
+ opacity: 0;
102
+ padding-right: 5px; /* Optional: give room for scrollbar */
103
+ }
104
+ .expanded-content.show {
105
+ max-height: 1000px;
106
+ opacity: 1;
107
+ }
108
+ .expanded-content img {
109
+ margin-top: 1rem;
110
+ border-radius: 6px;
111
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
112
+ }
113
+
114
+ /* Colors */
115
+ .card.red {
116
+ background-color: #e74c3c;
117
+ }
118
+ .card.green {
119
+ background-color: #27ae60;
120
+ }
121
+ .card.purple {
122
+ background-color: #8e44ad;
123
+ }
124
+
125
+ /* Animation */
126
+ @keyframes fadeIn {
127
+ from {
128
+ opacity: 0;
129
+ transform: translateY(10px);
130
+ }
131
+ to {
132
+ opacity: 1;
133
+ transform: translateY(0);
134
+ }
135
+ }
train/README.md ADDED
@@ -0,0 +1,140 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ pipeline_tag: tabular-classification
6
+ ---
7
+
8
+ # RLHF Training System
9
+
10
+ This directory contains the Reinforcement Learning from Human Feedback (RLHF) training pipeline for the driver behavior classification model.
11
+
12
+ ## Overview
13
+
14
+ The RLHF system enables continuous improvement of the driver behavior model through the steps below (a minimal trigger sketch follows the list):
15
+ 1. Loading human-labeled data from Firebase storage (`skyledge/labeled`)
16
+ 2. Combining it with existing model predictions for reinforcement learning
17
+ 3. Retraining the XGBoost model with the enhanced dataset
18
+ 4. Saving new model checkpoints to Hugging Face Hub
19
+
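+ A minimal programmatic trigger, mirroring what `train/rlhf.py` runs in its own `main()` (assumes execution from the repository root with `HF_TOKEN` and Firebase credentials set):
+
+ ```python
+ from train.rlhf import RLHFTrainer
+
+ trainer = RLHFTrainer()
+ result = trainer.train(max_datasets=5) # dict with "status", metrics, and a version on success
+ print(result["status"], result.get("model_version"))
+ ```
+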
20
+ ## Files
21
+
22
+ ### `loader.py`
23
+ - **Purpose**: Load labeled data from Firebase storage
24
+ - **Key Features**:
25
+ - Lists available labeled datasets from `skyledge/labeled` path
26
+ - Tracks already processed datasets in `trained.txt`
27
+ - Downloads and loads datasets into pandas DataFrames
28
+ - Prevents retraining on the same data
29
+
30
+ ### `saver.py`
31
+ - **Purpose**: Save trained models to Hugging Face Hub and local storage
32
+ - **Key Features**:
33
+ - Saves model components (XGBoost model, label encoder, scaler)
34
+ - Creates model metadata and README files
35
+ - Uploads to Hugging Face Hub with versioning
36
+ - Maintains local model directory structure
37
+
38
+ ### `rlhf.py`
39
+ - **Purpose**: Main RLHF training pipeline
40
+ - **Key Features**:
41
+ - Loads new labeled datasets
42
+ - Creates RLHF dataset by combining labeled data with model predictions
43
+ - Trains XGBoost model with enhanced dataset
44
+ - Evaluates model performance
45
+ - Coordinates with loader and saver modules
46
+
47
+ ## API Endpoints
48
+
49
+ The RLHF training system is integrated into the main FastAPI application with the following endpoints:
50
+
51
+ ### `POST /rlhf/train`
52
+ Trigger RLHF training session.
53
+
54
+ **Request Body:**
55
+ ```json
56
+ {
57
+ "max_datasets": 10,
58
+ "force_retrain": false
59
+ }
60
+ ```
61
+
62
+ **Response:**
63
+ ```json
64
+ {
65
+ "status": "success",
66
+ "model_version": "1.1",
67
+ "datasets_processed": 5,
68
+ "samples_processed": 1250,
69
+ "performance_metrics": {
70
+ "accuracy": 0.892,
71
+ "cv_mean": 0.885,
72
+ "cv_std": 0.012
73
+ },
74
+ "timestamp": "2024-12-01T14:30:22"
75
+ }
76
+ ```
77
+
78
+ ### `GET /rlhf/status`
79
+ Get status of RLHF training system and available labeled data.
80
+
81
+ ### `GET /rlhf/trained-datasets`
82
+ Get list of datasets that have already been used for training.
83
+
84
+ ## Configuration
85
+
86
+ ### Environment Variables
87
+ - `HF_TOKEN`: Hugging Face authentication token
88
+ - `HF_MODEL_REPO`: Hugging Face model repository (default: `BinKhoaLe1812/Driver_Behavior_OBD`)
89
+ - `MODEL_DIR`: Local model directory (default: `/app/models/ul`)
90
+ - `FIREBASE_ADMIN_JSON`: Firebase Admin SDK credentials
91
+ - `FIREBASE_SERVICE_ACCOUNT_JSON`: Firebase service account credentials
92
+
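+ A minimal `.env` covering the variables above (placeholder values only; `load_env()` in `organization.py` parses plain `KEY=VALUE` lines):
+
+ ```bash
+ HF_TOKEN=hf_xxxxxxxxxxxxxxxx
+ HF_MODEL_REPO=BinKhoaLe1812/Driver_Behavior_OBD
+ MODEL_DIR=/app/models/ul
+ ```
+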
93
+ ### Firebase Storage Structure
94
+ ```
95
+ skyledge-36b56.firebasestorage.app/
96
+ ├── skyledge/
97
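+ │ ├── raw/ # Original raw data (referenced by labeled filenames)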
+ │ ├── processed/ # Original processed data
98
+ │ ├── labeled/ # Human-labeled data for RLHF
99
+ │ │ ├── dataset1.csv
100
+ │ │ ├── dataset2.csv
101
+ │ │ └── trained.txt # Tracks processed datasets
102
+ │ └── logs/ # Training logs (future)
103
+ ```
104
+
105
+ ## Usage
106
+
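+ To trigger a training run from a shell (the same call `organization.py` suggests; assumes the API is served locally on port 8000):
+
+ ```bash
+ curl -X POST 'http://localhost:8000/rlhf/train' \
+ -H 'Content-Type: application/json' \
+ -d '{"max_datasets": 10, "force_retrain": false}'
+ ```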
107
+ ## Model Versioning
108
+
109
+ Each training session creates a new semantic version (1.0 → 1.1 → ... → 1.9 → 2.0), assigned by `ModelSaver._get_next_version()`
110
+
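+ A sketch of the increment rule above (hypothetical helper; the real logic lives in `ModelSaver._get_next_version()`):
+
+ ```python
+ def next_version(latest: str) -> str:
+ # "1.0" -> "1.1", ..., "1.9" -> "2.0" (minor digit rolls over at 10)
+ major, minor = (int(x) for x in latest.split("."))
+ return f"{major + 1}.0" if minor == 9 else f"{major}.{minor + 1}"
+ ```
+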
111
+ Models are saved to:
112
+ - **Local**: `/app/models/ul/v{version}/`
113
+ - **Hugging Face**: `BinKhoaLe1812/Driver_Behavior_OBD`
114
+
115
+ ## Data Flow
116
+
117
+ 1. **Data Collection**: Human-labeled data stored in `skyledge/labeled/`
118
+ 2. **Training Trigger**: API endpoint or manual trigger
119
+ 3. **Data Loading**: Load new labeled datasets (skip already processed)
120
+ 4. **RLHF Dataset**: Combine labeled data with model predictions
121
+ 5. **Model Training**: Train XGBoost with enhanced dataset
122
+ 6. **Evaluation**: Calculate performance metrics
123
+ 7. **Model Saving**: Save to local storage and Hugging Face Hub
124
+ 8. **Tracking**: Update `trained.txt` with processed datasets
125
+
126
+ ## Performance Monitoring
127
+
128
+ The system tracks:
129
+ - Number of datasets processed
130
+ - Total samples processed
131
+ - Model accuracy and cross-validation scores
132
+ - Training timestamps and metadata
133
+
134
+ ## Error Handling
135
+
136
+ - Graceful handling of missing datasets
137
+ - Firebase connection failures
138
+ - Model loading/saving errors
139
+ - XGBoost compatibility issues
140
+ - Comprehensive logging throughout the pipeline
train/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # train package
2
+ # RLHF Training System for Driver Behavior Classification
3
+
4
+ from .rlhf import RLHFTrainer
5
+ from .loader import LabeledDataLoader
6
+ from .saver import ModelSaver
7
+
8
+ __all__ = ['RLHFTrainer', 'LabeledDataLoader', 'ModelSaver']
train/loader.py ADDED
@@ -0,0 +1,370 @@
1
+ # loader.py
2
+ # Load labeled data from Firebase storage for RLHF training
3
+ import io
+ import os
4
+ import json
5
+ import logging
6
+ import pandas as pd
7
+ from datetime import datetime
8
+ from typing import List, Dict, Optional, Tuple, Any
9
+ from pathlib import Path
10
+
11
+ # Import Firebase client from the existing firebase_saver
12
+ import sys
13
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
14
+ from data.firebase_saver import _AdminClient, _GCSClient
15
+
16
+ logger = logging.getLogger("rlhf-loader")
17
+ logger.setLevel(logging.INFO)
18
+ if not logger.handlers:
19
+ _h = logging.StreamHandler()
20
+ _h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
21
+ logger.addHandler(_h)
22
+
23
+ # Firebase configuration
24
+ FIREBASE_BUCKET = "skyledge-36b56.firebasestorage.app"
25
+ LABELED_PREFIX = "skyledge/labeled"
26
+ RAW_PREFIX = "skyledge/raw"
27
+ PROCESSED_PREFIX = "skyledge/processed"
28
+ TRAINED_FILE = "trained.txt"
29
+
30
+ class LabeledDataLoader:
31
+ """
32
+ Load labeled data from Firebase storage for RLHF training.
33
+ Tracks already processed datasets to avoid retraining on the same data.
34
+ """
35
+
36
+ def __init__(self):
37
+ self.bucket_name = FIREBASE_BUCKET
38
+ self.prefix = LABELED_PREFIX
39
+ self.trained_file = TRAINED_FILE
40
+
41
+ # Initialize Firebase client
42
+ self.client = None
43
+ self.mode = None
44
+ try:
45
+ if os.getenv("FIREBASE_ADMIN_JSON"):
46
+ self.client = _AdminClient(self.bucket_name)
47
+ self.mode = "admin"
48
+ except Exception as e:
49
+ logger.warning(f"⚠️ Admin SDK init failed: {e}")
50
+
51
+ if self.client is None:
52
+ try:
53
+ self.client = _GCSClient(self.bucket_name)
54
+ self.mode = "gcs"
55
+ except Exception as e:
56
+ logger.error(f"❌ GCS client init failed: {e}")
57
+ raise
58
+
59
+ logger.info(f"📦 LabeledDataLoader ready | mode={self.mode} bucket={self.bucket_name} prefix={self.prefix}")
60
+
61
+ def _get_trained_datasets(self) -> List[str]:
62
+ """Load list of already trained datasets from trained.txt"""
63
+ try:
64
+ # Check if trained.txt exists in Firebase storage
65
+ trained_path = f"{self.prefix}/{self.trained_file}"
66
+ if self.client.blob_exists(trained_path):
67
+ # Download and read the file
68
+ blob = self.client.bucket.blob(trained_path)
69
+ content = blob.download_as_text()
70
+ trained_datasets = [line.strip() for line in content.split('\n') if line.strip()]
71
+ logger.info(f"📋 Loaded {len(trained_datasets)} already trained datasets")
72
+ return trained_datasets
73
+ else:
74
+ logger.info("📋 No trained.txt found, starting fresh")
75
+ return []
76
+ except Exception as e:
77
+ logger.warning(f"⚠️ Failed to load trained datasets: {e}")
78
+ return []
79
+
80
+ def _update_trained_datasets(self, new_datasets: List[str]):
81
+ """Update trained.txt with new dataset names"""
82
+ try:
83
+ # Get existing trained datasets
84
+ existing = self._get_trained_datasets()
85
+
86
+ # Add new datasets with timestamp
87
+ timestamp = datetime.now().isoformat()
88
+ new_entries = [f"{timestamp}:{dataset}" for dataset in new_datasets]
89
+ all_entries = existing + new_entries
90
+
91
+ # Upload updated file
92
+ trained_path = f"{self.prefix}/{self.trained_file}"
93
+ content = '\n'.join(all_entries)
94
+ self.client.upload_from_bytes(
95
+ content.encode('utf-8'),
96
+ trained_path,
97
+ "text/plain"
98
+ )
99
+ logger.info(f"✅ Updated trained.txt with {len(new_datasets)} new datasets")
100
+ except Exception as e:
101
+ logger.error(f"❌ Failed to update trained datasets: {e}")
102
+
103
+ def list_labeled_datasets(self) -> List[Dict[str, str]]:
104
+ """List all available labeled datasets in Firebase storage"""
105
+ try:
106
+ # List all blobs under the labeled prefix
107
+ blobs = self.client.bucket.list_blobs(prefix=f"{self.prefix}/")
108
+
109
+ datasets = []
110
+ trained_datasets = self._get_trained_datasets()
111
+
112
+ for blob in blobs:
113
+ # Skip the trained.txt file itself
114
+ if blob.name.endswith(f"/{self.trained_file}"):
115
+ continue
116
+
117
+ # Extract dataset name (relative to skyledge root)
118
+ dataset_name = blob.name.replace("skyledge/", "")
119
+
120
+ # Skip if already trained
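+ # (trained.txt entries are "timestamp:name", so a substring test matches the name part)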
121
+ if any(dataset_name in entry for entry in trained_datasets):
122
+ continue
123
+
124
+ # Get blob metadata
125
+ blob.reload()
126
+ datasets.append({
127
+ 'name': dataset_name,
128
+ 'path': blob.name,
129
+ 'size': blob.size,
130
+ 'created': blob.time_created.isoformat() if blob.time_created else None,
131
+ 'updated': blob.updated.isoformat() if blob.updated else None,
132
+ 'content_type': blob.content_type
133
+ })
134
+
135
+ logger.info(f"📊 Found {len(datasets)} new labeled datasets")
136
+ return datasets
137
+
138
+ except Exception as e:
139
+ logger.error(f"❌ Failed to list labeled datasets: {e}")
140
+ return []
141
+
142
+ def download_dataset(self, dataset_path: str, local_path: str) -> bool:
143
+ """Download a dataset from Firebase storage to local path"""
144
+ try:
145
+ blob = self.client.bucket.blob(dataset_path)
146
+ blob.download_to_filename(local_path)
147
+ logger.info(f"✅ Downloaded {dataset_path} to {local_path}")
148
+ return True
149
+ except Exception as e:
150
+ logger.error(f"❌ Failed to download {dataset_path}: {e}")
151
+ return False
152
+
153
+ def load_dataset(self, dataset_path: str) -> Optional[pd.DataFrame]:
154
+ """Load a dataset directly into a pandas DataFrame"""
155
+ try:
156
+ blob = self.client.bucket.blob(dataset_path)
157
+ content = blob.download_as_text()
158
+
159
+ # Try to determine file type and load accordingly
160
+ if dataset_path.endswith('.csv'):
161
+ df = pd.read_csv(io.StringIO(content))
162
+ elif dataset_path.endswith('.json'):
163
+ df = pd.read_json(io.StringIO(content))
164
+ elif dataset_path.endswith('.parquet'):
165
+ # For parquet, we need to download as bytes
166
+ blob_bytes = blob.download_as_bytes()
167
+ df = pd.read_parquet(io.BytesIO(blob_bytes))
168
+ else:
169
+ # Default to CSV
170
+ df = pd.read_csv(io.StringIO(content))
171
+
172
+ logger.info(f"✅ Loaded dataset {dataset_path} with shape {df.shape}")
173
+ return df
174
+
175
+ except Exception as e:
176
+ logger.error(f"❌ Failed to load dataset {dataset_path}: {e}")
177
+ return None
178
+
179
+ def get_new_datasets_for_training(self) -> List[Dict[str, str]]:
180
+ """Get list of new datasets that haven't been used for training yet"""
181
+ return self.list_labeled_datasets()
182
+
183
+ def mark_datasets_as_trained(self, dataset_names: List[str]):
184
+ """Mark datasets as trained to avoid retraining"""
185
+ self._update_trained_datasets(dataset_names)
186
+
187
+ def _parse_labeled_filename(self, filename: str) -> Dict[str, str]:
188
+ """
189
+ Parse labeled filename to extract original dataset information.
190
+ Format: {id}_{source}-{original_id}_{date}-labelled.csv
191
+ Example: 001_raw-002_2025-09-19-labelled.csv
192
+ """
193
+ try:
194
+ # Strip the "-labelled" suffix and the .csv extension
195
+ name = filename.replace('-labelled.csv', '').replace('.csv', '')
196
+
197
+ # Split by underscore: e.g. ["001", "raw-002", "2025-09-19"]
198
+ parts = name.split('_')
199
+ if len(parts) < 3:
200
+ return {"error": f"Invalid filename format: {filename}"}
201
+
202
+ # Extract components
203
+ labeled_id = parts[0] # 001
204
+ source_and_original = parts[1] # raw-002 or processed-002
205
+ date = parts[2] # 2025-09-19
206
+
207
+ # Parse source and original ID
208
+ if '-' in source_and_original:
209
+ source, original_id = source_and_original.split('-', 1)
210
+ else:
211
+ source = source_and_original
212
+ original_id = "unknown"
213
+
214
+ return {
215
+ "labeled_id": labeled_id,
216
+ "source": source, # raw or processed
217
+ "original_id": original_id,
218
+ "date": date,
219
+ "original_filename": f"{original_id}_{date}-{source}.csv" if source != "unknown" else None
220
+ }
221
+ except Exception as e:
222
+ logger.warning(f"⚠️ Failed to parse filename {filename}: {e}")
223
+ return {"error": str(e)}
224
+
225
+ def _find_original_dataset(self, labeled_info: Dict[str, str]) -> Optional[str]:
226
+ """Find the original dataset path based on labeled file info"""
227
+ if labeled_info.get("error") or not labeled_info.get("original_filename"):
228
+ return None
229
+
230
+ source = labeled_info["source"]
231
+ original_filename = labeled_info["original_filename"]
232
+
233
+ if source == "raw":
234
+ return f"{RAW_PREFIX}/{original_filename}" # module-level constant, not an instance attribute
235
+ elif source == "processed":
236
+ return f"{PROCESSED_PREFIX}/{original_filename}"
237
+ else:
238
+ return None
239
+
240
+ def load_labeled_with_original(self, labeled_path: str) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Dict[str, str]]:
241
+ """
242
+ Load labeled dataset along with its original dataset for RLHF comparison.
243
+ Returns: (labeled_df, original_df, metadata)
244
+ """
245
+ try:
246
+ # Load labeled dataset
247
+ labeled_df = self.load_dataset(labeled_path)
248
+ if labeled_df is None:
249
+ return None, None, {"error": "Failed to load labeled dataset"}
250
+
251
+ # Parse filename to get original dataset info
252
+ filename = labeled_path.split('/')[-1]
253
+ labeled_info = self._parse_labeled_filename(filename)
254
+
255
+ if labeled_info.get("error"):
256
+ logger.warning(f"⚠️ Could not parse labeled filename: {labeled_info['error']}")
257
+ return labeled_df, None, labeled_info
258
+
259
+ # Find and load original dataset
260
+ original_path = self._find_original_dataset(labeled_info)
261
+ original_df = None
262
+
263
+ if original_path and self.client.blob_exists(original_path):
264
+ original_df = self.load_dataset(original_path)
265
+ if original_df is not None:
266
+ logger.info(f"✅ Loaded original dataset: {original_path}")
267
+ else:
268
+ logger.warning(f"⚠️ Failed to load original dataset: {original_path}")
269
+ else:
270
+ logger.warning(f"⚠️ Original dataset not found: {original_path}")
271
+
272
+ return labeled_df, original_df, labeled_info
273
+
274
+ except Exception as e:
275
+ logger.error(f"❌ Failed to load labeled with original: {e}")
276
+ return None, None, {"error": str(e)}
277
+
278
+ def create_training_batch(self, max_datasets: int = 10) -> Tuple[List[pd.DataFrame], List[str]]:
279
+ """
280
+ Create a training batch by loading new datasets.
281
+ Returns tuple of (dataframes, dataset_names)
282
+ """
283
+ datasets = self.get_new_datasets_for_training()
284
+
285
+ if not datasets:
286
+ logger.info("📭 No new datasets available for training")
287
+ return [], []
288
+
289
+ # Limit the number of datasets
290
+ datasets = datasets[:max_datasets]
291
+
292
+ dataframes = []
293
+ dataset_names = []
294
+
295
+ for dataset in datasets:
296
+ df = self.load_dataset(dataset['path'])
297
+ if df is not None:
298
+ dataframes.append(df)
299
+ dataset_names.append(dataset['name'])
300
+ else:
301
+ logger.warning(f"⚠️ Skipping dataset {dataset['name']} due to load failure")
302
+
303
+ if dataframes:
304
+ logger.info(f"📦 Created training batch with {len(dataframes)} datasets")
305
+ # Mark these datasets as trained
306
+ self.mark_datasets_as_trained(dataset_names)
307
+
308
+ return dataframes, dataset_names
309
+
310
+ def create_rlhf_training_batch(self, max_datasets: int = 10) -> Tuple[List[Dict[str, Any]], List[str]]:
311
+ """
312
+ Create RLHF training batch with both labeled and original datasets.
313
+ Returns tuple of (training_data, dataset_names)
314
+ Each training_data item contains: {'labeled_df', 'original_df', 'metadata'}
315
+ """
316
+ datasets = self.get_new_datasets_for_training()
317
+
318
+ if not datasets:
319
+ logger.info("📭 No new datasets available for RLHF training")
320
+ return [], []
321
+
322
+ # Limit the number of datasets
323
+ datasets = datasets[:max_datasets]
324
+
325
+ training_data = []
326
+ dataset_names = []
327
+
328
+ for dataset in datasets:
329
+ labeled_df, original_df, metadata = self.load_labeled_with_original(dataset['path'])
330
+
331
+ if labeled_df is not None:
332
+ training_item = {
333
+ 'labeled_df': labeled_df,
334
+ 'original_df': original_df,
335
+ 'metadata': metadata,
336
+ 'dataset_name': dataset['name']
337
+ }
338
+ training_data.append(training_item)
339
+ dataset_names.append(dataset['name'])
340
+ logger.info(f"✅ Loaded RLHF dataset: {dataset['name']} (original: {metadata.get('original_filename', 'N/A')})")
341
+ else:
342
+ logger.warning(f"⚠️ Skipping dataset {dataset['name']} due to load failure")
343
+
344
+ if training_data:
345
+ logger.info(f"📦 Created RLHF training batch with {len(training_data)} datasets")
346
+ # Mark these datasets as trained
347
+ self.mark_datasets_as_trained(dataset_names)
348
+
349
+ return training_data, dataset_names
350
+
351
+
352
+ def main():
353
+ """Test the loader functionality"""
354
+ loader = LabeledDataLoader()
355
+
356
+ # List available datasets
357
+ datasets = loader.list_labeled_datasets()
358
+ print(f"Available datasets: {len(datasets)}")
359
+ for dataset in datasets:
360
+ print(f" - {dataset['name']} ({dataset['size']} bytes)")
361
+
362
+ # Create a training batch
363
+ dataframes, names = loader.create_training_batch(max_datasets=5)
364
+ print(f"Training batch: {len(dataframes)} datasets")
365
+ for name in names:
366
+ print(f" - {name}")
367
+
368
+
369
+ if __name__ == "__main__":
370
+ main()
train/rlhf.py ADDED
@@ -0,0 +1,420 @@
1
+ # rlhf.py
2
+ # Reinforcement Learning from Human Feedback training pipeline
3
+ import os
4
+ import json
5
+ import logging
6
+ import pickle
7
+ import joblib
8
+ from datetime import datetime
9
+ from typing import List, Dict, Any, Optional, Tuple
10
+ import warnings
11
+
12
+ import pandas as pd
13
+ import numpy as np
14
+ from sklearn.model_selection import train_test_split, cross_val_score
15
+ from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
16
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
17
+ import xgboost as xgb
18
+
19
+ # Import our custom modules
20
+ from .loader import LabeledDataLoader
21
+ from .saver import ModelSaver
22
+
23
+ # Suppress warnings for cleaner output
24
+ warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
25
+ warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
26
+ warnings.filterwarnings("ignore", category=FutureWarning)
27
+
28
+ logger = logging.getLogger("rlhf-trainer")
29
+ logger.setLevel(logging.INFO)
30
+ if not logger.handlers:
31
+ _h = logging.StreamHandler()
32
+ _h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
33
+ logger.addHandler(_h)
34
+
35
+ class RLHFTrainer:
36
+ """
37
+ Reinforcement Learning from Human Feedback trainer for driver behavior classification.
38
+
39
+ This trainer:
40
+ 1. Loads human-labeled data from Firebase storage
41
+ 2. Combines it with existing model predictions for RLHF
42
+ 3. Retrains the XGBoost model with the combined dataset
43
+ 4. Evaluates performance and saves the new model
44
+ """
45
+
46
+ def __init__(self):
47
+ self.loader = LabeledDataLoader()
48
+ self.saver = ModelSaver()
49
+
50
+ # Model parameters
51
+ self.model_params = {
52
+ 'n_estimators': 100,
53
+ 'max_depth': 6,
54
+ 'learning_rate': 0.1,
55
+ 'subsample': 0.8,
56
+ 'colsample_bytree': 0.8,
57
+ 'random_state': 42,
58
+ # ('use_label_encoder' was removed from XGBClassifier in XGBoost 2.0)
59
+ 'eval_metric': 'mlogloss'
60
+ }
61
+
62
+ # Feature columns to drop (non-predictive)
63
+ self.safe_drop = {
64
+ "timestamp", "driving_style", "ul_drivestyle", "gt_drivestyle",
65
+ "session_id", "imported_at", "record_index"
66
+ }
67
+
68
+ logger.info("🤖 RLHFTrainer initialized")
69
+
70
+ def _prepare_features(self, df: pd.DataFrame, expected_features: Optional[List[str]] = None) -> Tuple[np.ndarray, List[str]]:
71
+ """Prepare features for training"""
72
+ # Select numeric columns and drop non-feature columns
73
+ feature_cols = [c for c in df.columns
74
+ if c not in self.safe_drop and pd.api.types.is_numeric_dtype(df[c])]
75
+
76
+ X = df[feature_cols].copy()
77
+
78
+ # Ensure required features are present
79
+ if expected_features:
80
+ for col in expected_features:
81
+ if col not in X.columns:
82
+ X[col] = 0.0
83
+ X = X[expected_features] # Align order
84
+
85
+ # Handle missing values
86
+ X = X.fillna(0)
87
+
88
+ return X.values, list(X.columns) # report the post-alignment column order
89
+
90
+ def _prepare_labels(self, df: pd.DataFrame, label_column: str = "driving_style") -> np.ndarray:
91
+ """Prepare labels for training"""
92
+ if label_column not in df.columns:
93
+ raise ValueError(f"Label column '{label_column}' not found in data")
94
+
95
+ return df[label_column].values
96
+
97
+ def _load_existing_model(self) -> Tuple[Any, Any, Any, List[str]]:
98
+ """Load existing model components, downloading latest version if needed"""
99
+ try:
100
+ # First, try to download the latest model
101
+ logger.info("🔄 Checking for latest model version...")
102
+ try:
103
+ from utils.download import download_latest_models
104
+ download_latest_models()
105
+ except Exception as e:
106
+ logger.warning(f"⚠️ Failed to download latest models: {e}")
107
+
108
+ model_dir = os.getenv("MODEL_DIR", "/app/models/ul")
109
+
110
+ model_path = os.path.join(model_dir, "xgb_drivestyle_ul.pkl")
111
+ le_path = os.path.join(model_dir, "label_encoder_ul.pkl")
112
+ scaler_path = os.path.join(model_dir, "scaler_ul.pkl")
113
+
114
+ # Load with compatibility fixes
115
+ model = self._load_model_with_compatibility(model_path)
116
+ label_encoder = joblib.load(le_path)
117
+ scaler = joblib.load(scaler_path)
118
+
119
+ # Get expected features
120
+ expected_features = None
121
+ if hasattr(scaler, "feature_names_in_"):
122
+ expected_features = list(scaler.feature_names_in_)
123
+ elif hasattr(model, "feature_names_in_"):
124
+ expected_features = list(model.feature_names_in_)
125
+
126
+ logger.info(f"✅ Loaded existing model with {len(expected_features) if expected_features else 'unknown'} features")
127
+ return model, label_encoder, scaler, expected_features
128
+
129
+ except Exception as e:
130
+ logger.warning(f"⚠️ Failed to load existing model: {e}")
131
+ return None, None, None, None
132
+
133
+ def _load_model_with_compatibility(self, model_path: str) -> Any:
134
+ """Load model with XGBoost compatibility fixes"""
135
+ try:
136
+ model = joblib.load(model_path)
137
+
138
+ # Fix XGBoost compatibility issues
139
+ if hasattr(model, 'get_booster'): # This is an XGBoost model
140
+ # Remove deprecated attributes
141
+ deprecated_attrs = [
142
+ 'use_label_encoder', '_le', '_label_encoder',
143
+ 'use_label_encoder_', '_le_', '_label_encoder_'
144
+ ]
145
+ for attr in deprecated_attrs:
146
+ if hasattr(model, attr):
147
+ try:
148
+ delattr(model, attr)
149
+ except (AttributeError, TypeError):
150
+ pass
151
+
152
+ # Set use_label_encoder to False
153
+ if hasattr(model, 'set_params'):
154
+ try:
155
+ model.set_params(use_label_encoder=False)
156
+ except Exception:
157
+ pass
158
+
159
+ return model
160
+
161
+ except Exception as e:
162
+ logger.error(f"❌ Failed to load model: {e}")
163
+ raise
164
+
165
+ def _create_rlhf_dataset(self, training_data: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
166
+ """Create RLHF dataset by combining labeled data with original data and model predictions"""
167
+ try:
168
+ # Load existing model for generating predictions
169
+ existing_model, label_encoder, scaler, expected_features = self._load_existing_model()
170
+
171
+ if existing_model is None:
172
+ logger.warning("⚠️ No existing model found, using only labeled data")
173
+ return self._prepare_rlhf_from_labeled_only(training_data)
174
+
175
+ # Combine all labeled datasets
176
+ labeled_dfs = [item['labeled_df'] for item in training_data if item['labeled_df'] is not None]
177
+ original_dfs = [item['original_df'] for item in training_data if item['original_df'] is not None]
178
+
179
+ combined_labeled_df = pd.concat(labeled_dfs, ignore_index=True)
180
+
181
+ # Prepare features and labels from labeled data
182
+ X_labeled, feature_cols = self._prepare_features(combined_labeled_df, expected_features)
183
+ y_labeled = self._prepare_labels(combined_labeled_df)
184
+
185
+ # Scale features
186
+ X_labeled_scaled = scaler.transform(X_labeled)
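+ # NOTE: _train_model fits a fresh StandardScaler on this already-scaled matrix; evaluation applies the same two-step transform, but inference code must mirror it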
187
+
188
+ # Generate model predictions on original data for comparison
189
+ model_predictions = []
190
+ prediction_confidence = []
191
+
192
+ if original_dfs:
193
+ combined_original_df = pd.concat(original_dfs, ignore_index=True)
194
+ X_original, _ = self._prepare_features(combined_original_df, expected_features)
195
+ X_original_scaled = scaler.transform(X_original)
196
+
197
+ # Get model predictions on original data
198
+ original_predictions = existing_model.predict(X_original_scaled)
199
+ model_predictions.extend(original_predictions)
200
+
201
+ # Get prediction probabilities for confidence
202
+ if hasattr(existing_model, 'predict_proba'):
203
+ proba = existing_model.predict_proba(X_original_scaled)
204
+ confidence = np.max(proba, axis=1)
205
+ prediction_confidence.extend(confidence)
206
+
207
+ # Create RLHF dataset with preference learning
208
+ # The labeled data represents the "correct" behavior (human preference)
209
+ # The model predictions on original data represent what the model thought was correct
210
+
211
+ # For RLHF, we want to learn from the difference between model predictions and human labels
212
+ rlhf_metadata = {
213
+ "labeled_samples": len(X_labeled),
214
+ "original_samples": len(model_predictions) if model_predictions else 0,
215
+ "model_confidence": np.mean(prediction_confidence) if prediction_confidence else 0.0,
216
+ "datasets_processed": len(training_data)
217
+ }
218
+
219
+ logger.info(f"📊 Created RLHF dataset: {len(X_labeled)} labeled samples, {len(model_predictions)} original samples")
220
+ logger.info(f"📊 Model confidence on original data: {rlhf_metadata['model_confidence']:.3f}")
221
+
222
+ return X_labeled_scaled, y_labeled, rlhf_metadata
223
+
224
+ except Exception as e:
225
+ logger.error(f"❌ Failed to create RLHF dataset: {e}")
226
+ raise
227
+
228
+ def _prepare_rlhf_from_labeled_only(self, training_data: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
229
+ """Prepare RLHF dataset from labeled data only (when no existing model)"""
230
+ labeled_dfs = [item['labeled_df'] for item in training_data if item['labeled_df'] is not None]
231
+ combined_df = pd.concat(labeled_dfs, ignore_index=True)
232
+
233
+ # Prepare features
234
+ X, feature_cols = self._prepare_features(combined_df)
235
+ y = self._prepare_labels(combined_df)
236
+
237
+ # Create and fit scaler
238
+ scaler = StandardScaler()
239
+ X_scaled = scaler.fit_transform(X)
240
+
241
+ rlhf_metadata = {
242
+ "labeled_samples": len(X),
243
+ "original_samples": 0,
244
+ "model_confidence": 0.0,
245
+ "datasets_processed": len(training_data)
246
+ }
247
+
248
+ return X_scaled, y, rlhf_metadata
249
+
250
+
251
+ def _train_model(self, X: np.ndarray, y: np.ndarray,
252
+ existing_model: Optional[Any] = None) -> Tuple[Any, Any, Any]:
253
+ """Train the XGBoost model"""
254
+ try:
255
+ # Create label encoder
256
+ label_encoder = LabelEncoder()
257
+ y_encoded = label_encoder.fit_transform(y)
258
+
259
+ # Create scaler
260
+ scaler = StandardScaler()
261
+ X_scaled = scaler.fit_transform(X)
262
+
263
+ # Split data
264
+ X_train, X_test, y_train, y_test = train_test_split(
265
+ X_scaled, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded
266
+ )
267
+
268
+ # Create and train model
269
+ model = xgb.XGBClassifier(**self.model_params)
270
+
271
+ # If we have an existing model, we can use it for warm start or transfer learning
272
+ if existing_model is not None:
273
+ logger.info("🔄 Using existing model for warm start")
274
+ # For XGBoost, we can't directly warm start, but we can use similar parameters
275
+ # and potentially use the existing model's predictions as additional features
276
+
277
+ # Train the model
278
+ model.fit(X_train, y_train,
279
+ eval_set=[(X_test, y_test)],
280
+ # (fit-time early_stopping_rounds was removed in XGBoost 2.0; eval_set is kept for monitoring)
281
+ verbose=False)
282
+
283
+ # Evaluate
284
+ y_pred = model.predict(X_test)
285
+ accuracy = accuracy_score(y_test, y_pred)
286
+
287
+ logger.info(f"✅ Model trained with accuracy: {accuracy:.4f}")
288
+
289
+ return model, label_encoder, scaler
290
+
291
+ except Exception as e:
292
+ logger.error(f"❌ Model training failed: {e}")
293
+ raise
294
+
295
+ def _evaluate_model(self, model: Any, label_encoder: Any, scaler: Any,
296
+ X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
297
+ """Evaluate model performance"""
298
+ try:
299
+ # Prepare test data
300
+ X_scaled = scaler.transform(X)
301
+ y_encoded = label_encoder.transform(y)
302
+
303
+ # Make predictions
304
+ y_pred = model.predict(X_scaled)
305
+
306
+ # Calculate metrics
307
+ accuracy = accuracy_score(y_encoded, y_pred)
308
+
309
+ # Cross-validation score
310
+ cv_scores = cross_val_score(model, X_scaled, y_encoded, cv=5)
311
+ cv_mean = cv_scores.mean()
312
+ cv_std = cv_scores.std()
313
+
314
+ metrics = {
315
+ "accuracy": accuracy,
316
+ "cv_mean": cv_mean,
317
+ "cv_std": cv_std,
318
+ "cv_scores": cv_scores.tolist()
319
+ }
320
+
321
+ logger.info(f"📊 Model evaluation: accuracy={accuracy:.4f}, cv_mean={cv_mean:.4f}±{cv_std:.4f}")
322
+ return metrics
323
+
324
+ except Exception as e:
325
+ logger.error(f"❌ Model evaluation failed: {e}")
326
+ return {"accuracy": 0.0, "cv_mean": 0.0, "cv_std": 0.0}
327
+
328
+ def train(self, max_datasets: int = 10) -> Dict[str, Any]:
329
+ """Main training pipeline"""
330
+ try:
331
+ logger.info("🚀 Starting RLHF training pipeline")
332
+
333
+ # Load new labeled datasets with original data for RLHF
334
+ training_data, dataset_names = self.loader.create_rlhf_training_batch(max_datasets=max_datasets)
335
+
336
+ if not training_data:
337
+ logger.warning("⚠️ No new datasets available for RLHF training")
338
+ return {"status": "no_data", "message": "No new datasets available"}
339
+
340
+ logger.info(f"📦 Loaded {len(training_data)} datasets for RLHF training")
341
+
342
+ # Create RLHF dataset
343
+ X, y, rlhf_metadata = self._create_rlhf_dataset(training_data)
344
+
345
+ # Load existing model for comparison
346
+ existing_model, existing_le, existing_scaler, expected_features = self._load_existing_model()
347
+
348
+ # Train new model
349
+ model, label_encoder, scaler = self._train_model(X, y, existing_model)
350
+
351
+ # Evaluate model
352
+ metrics = self._evaluate_model(model, label_encoder, scaler, X, y)
353
+
354
+ # Generate model version using semantic versioning
355
+ model_version = self.saver._get_next_version()
356
+
357
+ # Prepare training data info
358
+ training_data_info = {
359
+ "datasets": dataset_names,
360
+ "total_samples": len(X),
361
+ "training_date": datetime.now().isoformat(),
362
+ "features_count": X.shape[1]
363
+ }
364
+
365
+ # Prepare training log
366
+ training_log = {
367
+ "datasets_used": dataset_names,
368
+ "samples_processed": len(X),
369
+ "model_parameters": self.model_params,
370
+ "performance_metrics": metrics,
371
+ "training_duration": "N/A", # Could be tracked if needed
372
+ "existing_model_used": existing_model is not None
373
+ }
374
+
375
+ # Save model
376
+ save_result = self.saver.save_complete_model(
377
+ model=model,
378
+ label_encoder=label_encoder,
379
+ scaler=scaler,
380
+ model_version=model_version,
381
+ training_data_info=training_data_info,
382
+ performance_metrics=metrics,
383
+ training_log=training_log,
384
+ rlhf_metadata=rlhf_metadata
385
+ )
386
+
387
+ result = {
388
+ "status": "success",
389
+ "model_version": model_version,
390
+ "datasets_processed": len(dataset_names),
391
+ "samples_processed": len(X),
392
+ "performance_metrics": metrics,
393
+ "save_result": save_result,
394
+ "training_log": training_log
395
+ }
396
+
397
+ logger.info(f"✅ RLHF training completed successfully: v{model_version}")
398
+ return result
399
+
400
+ except Exception as e:
401
+ logger.error(f"❌ RLHF training failed: {e}")
402
+ return {
403
+ "status": "error",
404
+ "error": str(e),
405
+ "timestamp": datetime.now().isoformat()
406
+ }
407
+
408
+
409
+ def main():
410
+ """Test the RLHF trainer"""
411
+ try:
412
+ trainer = RLHFTrainer()
413
+ result = trainer.train(max_datasets=5)
414
+ print(f"Training result: {result}")
415
+ except Exception as e:
416
+ print(f"Training failed: {e}")
417
+
418
+
419
+ if __name__ == "__main__":
420
+ main()
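`train()` reports outcomes through its status dict rather than raising, so a scheduler can branch on `status`. Below is a minimal caller sketch; the `train.trainer` import path is an assumption, since only the file body is shown in this commit.

```python
# Hypothetical caller for the RLHF pipeline; the import path is an assumption.
from train.trainer import RLHFTrainer

def run_scheduled_training() -> None:
    trainer = RLHFTrainer()
    result = trainer.train(max_datasets=10)
    if result["status"] == "success":
        metrics = result["performance_metrics"]
        print(f"Trained v{result['model_version']} (accuracy={metrics['accuracy']:.4f})")
    elif result["status"] == "no_data":
        print("No new labeled datasets yet; skipping this run.")
    else:  # "error"
        print(f"Training failed: {result.get('error')}")

if __name__ == "__main__":
    run_scheduled_training()
```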
train/saver.py ADDED
@@ -0,0 +1,381 @@
1
+ # saver.py
2
+ # Model saving functions for RLHF training
3
+ import os
4
+ import json
5
+ import logging
6
+ import pickle
7
+ import joblib
8
+ from datetime import datetime
9
+ from typing import Dict, Any, Optional
10
+ from pathlib import Path
11
+
12
+ from huggingface_hub import HfApi  # Repository import removed: deprecated upstream and unused here
13
+ import pandas as pd
14
+ import numpy as np
15
+
16
+ logger = logging.getLogger("rlhf-saver")
17
+ logger.setLevel(logging.INFO)
18
+ if not logger.handlers:
19
+ _h = logging.StreamHandler()
20
+ _h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))
21
+ logger.addHandler(_h)
22
+
23
+ class ModelSaver:
24
+ """
25
+ Save trained models to Hugging Face Hub and local storage.
26
+ Handles model artifacts, metadata, and versioning.
27
+ """
28
+
29
+ def __init__(self):
30
+ self.hf_token = os.getenv("HF_TOKEN")
31
+ if not self.hf_token:
32
+ raise RuntimeError("HF_TOKEN environment variable not set")
33
+
34
+ self.hf_api = HfApi(token=self.hf_token)
35
+ self.repo_id = os.getenv("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
36
+
37
+ # Local model directory
38
+ self.local_model_dir = Path(os.getenv("MODEL_DIR", "/app/models/ul"))
39
+ self.local_model_dir.mkdir(parents=True, exist_ok=True)
40
+
41
+ logger.info(f"📦 ModelSaver ready | repo={self.repo_id}")
42
+
43
+ def _get_next_version(self) -> str:
44
+ """Get the next version number (1.0, 1.1, 1.2, ..., 1.9, 2.0, etc.)"""
45
+ try:
46
+ # List existing versions in HF repo
47
+ repo_files = self.hf_api.list_repo_files(
48
+ repo_id=self.repo_id,
49
+ repo_type="model"
50
+ )
51
+
52
+ # Find version directories (v1.0, v1.1, etc.); list_repo_files returns file paths, not directories, so derive them from "vX.Y/..." prefixes
53
+ version_dirs = sorted({f.split('/')[0] for f in repo_files if f.startswith('v') and '/' in f})
54
+ versions = []
55
+
56
+ for v_dir in version_dirs:
57
+ try:
58
+ version_str = v_dir[1:] # Remove 'v' prefix
59
+ if '.' in version_str:
60
+ major, minor = version_str.split('.')
61
+ versions.append((int(major), int(minor)))
62
+ except (ValueError, IndexError):
63
+ continue
64
+
65
+ if not versions:
66
+ return "1.0"
67
+
68
+ # Sort versions and get the latest
69
+ versions.sort()
70
+ latest_major, latest_minor = versions[-1]
71
+
72
+ # Increment version
73
+ if latest_minor < 9:
74
+ return f"{latest_major}.{latest_minor + 1}"
75
+ else:
76
+ return f"{latest_major + 1}.0"
77
+
78
+ except Exception as e:
79
+ logger.warning(f"⚠️ Failed to get next version from HF repo: {e}")
80
+ # Fallback to timestamp-based version
81
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
82
+
83
+ def _create_model_metadata(self,
84
+ model_type: str,
85
+ training_data_info: Dict[str, Any],
86
+ performance_metrics: Dict[str, float],
87
+ model_version: str,
88
+ rlhf_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
89
+ """Create metadata for the trained model"""
90
+ metadata = {
91
+ "model_type": model_type,
92
+ "version": model_version,
93
+ "created_at": datetime.now().isoformat(),
94
+ "training_data": training_data_info,
95
+ "performance_metrics": performance_metrics,
96
+ "framework": "xgboost",
97
+ "task": "driver_behavior_classification",
98
+ "labels": ["aggressive", "normal", "conservative"], # Based on ul_label.py
99
+ "features": "obd_sensor_data",
100
+ "rlhf_metadata": rlhf_metadata or {}
101
+ }
102
+ return metadata
103
+
104
+ def save_model_locally(self,
105
+ model: Any,
106
+ label_encoder: Any,
107
+ scaler: Any,
108
+ model_version: str,
109
+ metadata: Dict[str, Any]) -> Dict[str, str]:
110
+ """Save model components locally with versioning"""
111
+ try:
112
+ # Create versioned directory
113
+ version_dir = self.local_model_dir / f"v{model_version}"
114
+ version_dir.mkdir(exist_ok=True)
115
+
116
+ # Save model components
117
+ model_path = version_dir / "xgb_drivestyle_ul.pkl"
118
+ le_path = version_dir / "label_encoder_ul.pkl"
119
+ scaler_path = version_dir / "scaler_ul.pkl"
120
+ metadata_path = version_dir / "metadata.json"
121
+
122
+ # Save using joblib for better compatibility
123
+ joblib.dump(model, model_path)
124
+ joblib.dump(label_encoder, le_path)
125
+ joblib.dump(scaler, scaler_path)
126
+
127
+ # Save metadata
128
+ with open(metadata_path, 'w') as f:
129
+ json.dump(metadata, f, indent=2)
130
+
131
+ # Also save to the main model directory (for current usage)
132
+ joblib.dump(model, self.local_model_dir / "xgb_drivestyle_ul.pkl")
133
+ joblib.dump(label_encoder, self.local_model_dir / "label_encoder_ul.pkl")
134
+ joblib.dump(scaler, self.local_model_dir / "scaler_ul.pkl")
135
+
136
+ logger.info(f"✅ Model saved locally to {version_dir}")
137
+
138
+ return {
139
+ "model_path": str(model_path),
140
+ "label_encoder_path": str(le_path),
141
+ "scaler_path": str(scaler_path),
142
+ "metadata_path": str(metadata_path)
143
+ }
144
+
145
+ except Exception as e:
146
+ logger.error(f"❌ Failed to save model locally: {e}")
147
+ raise
148
+
149
+ def save_model_to_hf(self,
150
+ model: Any,
151
+ label_encoder: Any,
152
+ scaler: Any,
153
+ model_version: str,
154
+ metadata: Dict[str, Any],
155
+ training_data_info: Dict[str, Any]) -> str:
156
+ """Save model to Hugging Face Hub"""
157
+ try:
158
+ # Create temporary directory for upload
159
+ temp_dir = Path(f"/tmp/hf_upload_{model_version}")
160
+ temp_dir.mkdir(exist_ok=True)
161
+
162
+ # Save model components
163
+ model_path = temp_dir / "xgb_drivestyle_ul.pkl"
164
+ le_path = temp_dir / "label_encoder_ul.pkl"
165
+ scaler_path = temp_dir / "scaler_ul.pkl"
166
+ metadata_path = temp_dir / "metadata.json"
167
+ readme_path = temp_dir / "README.md"
168
+
169
+ # Save using joblib
170
+ joblib.dump(model, model_path)
171
+ joblib.dump(label_encoder, le_path)
172
+ joblib.dump(scaler, scaler_path)
173
+
174
+ # Save metadata
175
+ with open(metadata_path, 'w') as f:
176
+ json.dump(metadata, f, indent=2)
177
+
178
+ # Create README
179
+ readme_content = self._create_readme(metadata, training_data_info)
180
+ with open(readme_path, 'w') as f:
181
+ f.write(readme_content)
182
+
183
+ # Upload to Hugging Face Hub
184
+ self.hf_api.upload_folder(
185
+ folder_path=str(temp_dir),
186
+ repo_id=self.repo_id,
187
+ repo_type="model",
188
+ commit_message=f"RLHF training update v{model_version}",
189
+ ignore_patterns=["*.tmp", "*.log"]
190
+ )
191
+
192
+ # Clean up temp directory
193
+ import shutil
194
+ shutil.rmtree(temp_dir)
195
+
196
+ logger.info(f"✅ Model uploaded to Hugging Face Hub: {self.repo_id}")
197
+ return f"https://huggingface.co/{self.repo_id}"
198
+
199
+ except Exception as e:
200
+ logger.error(f"❌ Failed to save model to HF: {e}")
201
+ raise
202
+
203
+ def _create_readme(self, metadata: Dict[str, Any], training_data_info: Dict[str, Any]) -> str:
204
+ """Create README content for the model"""
205
+ readme = f"""---
206
+ license: mit
207
+ tags:
208
+ - driver-behavior
209
+ - obd-data
210
+ - xgboost
211
+ - rlhf
212
+ - reinforcement-learning
213
+ ---
214
+
215
+ # Driver Behavior Classification Model (RLHF v{metadata['version']})
216
+
217
+ This model classifies driver behavior based on OBD (On-Board Diagnostics) sensor data using XGBoost.
218
+
219
+ ## Model Information
220
+
221
+ - **Model Type**: {metadata['model_type']}
222
+ - **Version**: {metadata['version']}
223
+ - **Created**: {metadata['created_at']}
224
+ - **Framework**: {metadata['framework']}
225
+ - **Task**: {metadata['task']}
226
+
227
+ ## Performance Metrics
228
+
229
+ """
230
+
231
+ for metric, value in metadata['performance_metrics'].items():
232
+ if isinstance(value, (int, float)): # skip list-valued entries such as cv_scores, which cannot be formatted with :.4f
+ readme += f"- **{metric}**: {value:.4f}\n"
233
+
234
+ readme += f"""
235
+ ## Training Data
236
+
237
+ - **Datasets Used**: {len(training_data_info.get('datasets', []))}
238
+ - **Total Samples**: {training_data_info.get('total_samples', 'N/A')}
239
+ - **Training Date**: {training_data_info.get('training_date', 'N/A')}
240
+
241
+ ## Labels
242
+
243
+ The model predicts one of the following driver behavior categories:
244
+ """
245
+
246
+ for label in metadata['labels']:
247
+ readme += f"- {label}\n"
248
+
249
+ readme += """
250
+ ## Usage
251
+
252
+ ```python
253
+ import joblib
254
+ import pandas as pd
255
+
256
+ # Load the model
257
+ model = joblib.load('xgb_drivestyle_ul.pkl')
258
+ label_encoder = joblib.load('label_encoder_ul.pkl')
259
+ scaler = joblib.load('scaler_ul.pkl')
260
+
261
+ # Prepare your OBD data (feature columns must match the training format)
262
+ obd_data = pd.read_csv('your_obd_data.csv')
+ scaled_data = scaler.transform(obd_data)
263
+
264
+ # Make predictions
265
+ predictions = model.predict(scaled_data)
266
+ behavior_labels = label_encoder.inverse_transform(predictions)
267
+ ```
268
+
269
+ ## Files
270
+
271
+ - `xgb_drivestyle_ul.pkl`: Main XGBoost model
272
+ - `label_encoder_ul.pkl`: Label encoder for behavior categories
273
+ - `scaler_ul.pkl`: Feature scaler
274
+ - `metadata.json`: Model metadata and performance metrics
275
+
276
+ ## RLHF Training
277
+
278
+ This model was trained using Reinforcement Learning from Human Feedback (RLHF) to improve performance based on human-labeled data and feedback.
279
+ """
280
+
281
+ return readme
282
+
283
+ def save_training_log(self,
284
+ training_log: Dict[str, Any],
285
+ model_version: str) -> str:
286
+ """Save training log to Firebase storage"""
287
+ try:
288
+ # Import Firebase client
289
+ import sys
290
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
291
+ from data.firebase_saver import FirebaseSaver
292
+
293
+ # Create log entry
294
+ log_entry = {
295
+ "version": model_version,
296
+ "timestamp": datetime.now().isoformat(),
297
+ "log": training_log
298
+ }
299
+
300
+ # Save to Firebase
301
+ saver = FirebaseSaver()
302
+ # Note: We'll need to modify FirebaseSaver to support different prefixes
303
+ # For now, we'll save to a logs subdirectory
304
+ log_filename = f"training_log_v{model_version}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
305
+
306
+ # Create temporary file
307
+ temp_path = f"/tmp/{log_filename}"
308
+ with open(temp_path, 'w') as f:
309
+ json.dump(log_entry, f, indent=2)
310
+
311
+ # Upload to Firebase (we'll need to extend FirebaseSaver for this)
312
+ # For now, just log locally
313
+ logger.info(f"📝 Training log saved: {log_entry}")
314
+
315
+ return temp_path
316
+
317
+ except Exception as e:
318
+ logger.error(f"❌ Failed to save training log: {e}")
319
+ return ""
320
+
321
+ def save_complete_model(self,
322
+ model: Any,
323
+ label_encoder: Any,
324
+ scaler: Any,
325
+ model_version: str,
326
+ training_data_info: Dict[str, Any],
327
+ performance_metrics: Dict[str, float],
328
+ training_log: Dict[str, Any],
329
+ rlhf_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
330
+ """Complete model saving workflow"""
331
+ try:
332
+ # Create metadata
333
+ metadata = self._create_model_metadata(
334
+ model_type="xgboost_classifier",
335
+ training_data_info=training_data_info,
336
+ performance_metrics=performance_metrics,
337
+ model_version=model_version,
338
+ rlhf_metadata=rlhf_metadata
339
+ )
340
+
341
+ # Save locally
342
+ local_paths = self.save_model_locally(
343
+ model, label_encoder, scaler, model_version, metadata
344
+ )
345
+
346
+ # Save to Hugging Face Hub
347
+ hf_url = self.save_model_to_hf(
348
+ model, label_encoder, scaler, model_version, metadata, training_data_info
349
+ )
350
+
351
+ # Save training log
352
+ log_path = self.save_training_log(training_log, model_version)
353
+
354
+ result = {
355
+ "local_paths": local_paths,
356
+ "hf_url": hf_url,
357
+ "log_path": log_path,
358
+ "version": model_version,
359
+ "metadata": metadata
360
+ }
361
+
362
+ logger.info(f"✅ Complete model save successful: v{model_version}")
363
+ return result
364
+
365
+ except Exception as e:
366
+ logger.error(f"❌ Complete model save failed: {e}")
367
+ raise
368
+
369
+
370
+ def main():
371
+ """Test the saver functionality"""
372
+ try:
373
+ saver = ModelSaver()
374
+ print(f"ModelSaver initialized for repo: {saver.repo_id}")
375
+ print(f"Local model directory: {saver.local_model_dir}")
376
+ except Exception as e:
377
+ print(f"Failed to initialize ModelSaver: {e}")
378
+
379
+
380
+ if __name__ == "__main__":
381
+ main()
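For reference, the version-increment rule implemented in `_get_next_version` (minor bumps up to .9, then a major roll-over) can be restated as a small pure function. This is an illustrative sketch, not code from the repository:

```python
# Illustrative restatement of ModelSaver._get_next_version's increment rule.
def next_version(existing_dirs: list[str]) -> str:
    versions = []
    for name in existing_dirs:  # entries look like "v1.0", "v1.1", ...
        try:
            major, minor = name.lstrip("v").split(".")
            versions.append((int(major), int(minor)))
        except ValueError:
            continue
    if not versions:
        return "1.0"
    major, minor = max(versions)
    return f"{major}.{minor + 1}" if minor < 9 else f"{major + 1}.0"

assert next_version([]) == "1.0"
assert next_version(["v1.0", "v1.1"]) == "1.2"
assert next_version(["v1.9"]) == "2.0"  # minor rolls the major at .9
```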
utils/download.py ADDED
@@ -0,0 +1,154 @@
1
+ # download.py
2
+ # Download latest models from Hugging Face
3
+ import os, shutil, pathlib, sys
4
+ import json
5
+ from huggingface_hub import hf_hub_download, HfApi
6
+
7
+ def load_env_file():
8
+ """Load environment variables from .env file if it exists"""
9
+ env_path = pathlib.Path(__file__).parent.parent / ".env"
10
+ if env_path.exists():
11
+ with open(env_path, 'r') as f:
12
+ for line in f:
13
+ line = line.strip()
14
+ if line and not line.startswith('#') and '=' in line:
15
+ key, value = line.split('=', 1)
16
+ os.environ[key.strip()] = value.strip().strip('"').strip("'")  # tolerate whitespace and quoted values
17
+ return True
18
+ return False
19
+
20
+ # Load .env file first before setting any environment variables
21
+ load_env_file()
22
+
23
+ REPO_ID = os.getenv("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
24
+ MODEL_DIR = pathlib.Path(os.getenv("MODEL_DIR", "/app/models/ul")).resolve()
25
+ FILES = ["label_encoder_ul.pkl", "scaler_ul.pkl", "xgb_drivestyle_ul.pkl"]
26
+
27
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
28
+
29
+ def get_latest_version():
30
+ """Get the latest model version from Hugging Face repo"""
31
+ try:
32
+ hf_token = os.getenv("HF_TOKEN")
33
+ if not hf_token:
34
+ print("⚠️ HF_TOKEN not set, using default model files")
35
+ return None
36
+
37
+ api = HfApi(token=hf_token)
38
+ repo_files = api.list_repo_files(
39
+ repo_id=REPO_ID,
40
+ repo_type="model"
41
+ )
42
+
43
+ print("🔍 Checking repository files...")
44
+ print(f"📁 Found {len(repo_files)} files in repository")
45
+
46
+ # Find version directories (v1.0, v1.1, etc.)
47
+ version_dirs = [f for f in repo_files if f.startswith('v') and '/' not in f]
48
+ print(f"📦 Found version directories: {version_dirs}")
49
+
50
+ # Also check for version directories with files inside
51
+ version_dirs_with_files = []
52
+ for f in repo_files:
53
+ if f.startswith('v') and '/' in f:
54
+ version_dir = f.split('/')[0]
55
+ if version_dir not in version_dirs_with_files:
56
+ version_dirs_with_files.append(version_dir)
57
+
58
+ if version_dirs_with_files:
59
+ print(f"📦 Found version directories with files: {version_dirs_with_files}")
60
+ version_dirs.extend(version_dirs_with_files)
61
+
62
+ versions = []
63
+
64
+ for v_dir in version_dirs:
65
+ try:
66
+ version_str = v_dir[1:] # Remove 'v' prefix
67
+ if '.' in version_str:
68
+ major, minor = version_str.split('.')
69
+ versions.append((int(major), int(minor), v_dir))
70
+ print(f"✅ Found version: {v_dir} (major={major}, minor={minor})")
71
+ except (ValueError, IndexError):
72
+ print(f"⚠️ Could not parse version: {v_dir}")
73
+ continue
74
+
75
+ if not versions:
76
+ print("📦 No versioned models found, checking for root files...")
77
+ # Check if files exist in root
78
+ root_files = [f for f in repo_files if f in FILES]
79
+ if root_files:
80
+ print(f"📁 Found root files: {root_files}")
81
+ return None # Use root files
82
+ else:
83
+ print("❌ No model files found in repository")
84
+ print("💡 Available files in repository:")
85
+ for f in sorted(repo_files):
86
+ print(f" - {f}")
87
+ return None
88
+
89
+ # Sort versions and get the latest
90
+ versions.sort()
91
+ latest_version = versions[-1][2] # Get the directory name
92
+ print(f"📦 Latest model version: {latest_version}")
93
+ return latest_version
94
+
95
+ except Exception as e:
96
+ print(f"⚠️ Failed to get latest version: {e}")
97
+ return None
98
+
99
+ def fetch_latest(fname: str, version_dir: str = None):
100
+ """Download the latest version of a model file"""
101
+ try:
102
+ if version_dir:
103
+ # Download from versioned directory
104
+ versioned_path = f"{version_dir}/{fname}"
105
+ print(f"📥 Downloading {fname} from {versioned_path}...")
106
+ src = hf_hub_download(repo_id=REPO_ID, filename=versioned_path, repo_type="model")
107
+ else:
108
+ # Download from root directory (fallback)
109
+ print(f"📥 Downloading {fname} from root directory...")
110
+ src = hf_hub_download(repo_id=REPO_ID, filename=fname, repo_type="model")
111
+
112
+ dst = MODEL_DIR / fname
113
+ shutil.copy2(src, dst)
114
+ print(f"✅ Downloaded {fname} → {dst}")
115
+ return True
116
+ except Exception as e:
117
+ print(f"❌ Failed to fetch {fname}: {e}")
118
+ if version_dir:
119
+ print(f" Tried path: {version_dir}/{fname}")
120
+ else:
121
+ print(f" Tried path: {fname}")
122
+ return False
123
+
124
+ def download_latest_models():
125
+ """Download the latest version of all model files"""
126
+ print("🔄 Checking for latest model version...")
127
+ latest_version = get_latest_version()
128
+
129
+ success_count = 0
130
+ for f in FILES:
131
+ if fetch_latest(f, latest_version):
132
+ success_count += 1
133
+
134
+ if success_count == len(FILES):
135
+ print(f"✅ Successfully downloaded all {len(FILES)} model files")
136
+ if latest_version:
137
+ print(f"📦 Using version: {latest_version}")
138
+ return True
139
+ else:
140
+ print(f"⚠️ Only {success_count}/{len(FILES)} files downloaded successfully")
141
+ return False
142
+
143
+ def fetch(fname: str):
144
+ """Legacy function for backward compatibility"""
145
+ return fetch_latest(fname)
146
+
147
+ def main():
148
+ """Download latest models"""
149
+ success = download_latest_models()
150
+ if not success:
151
+ sys.exit(1)
152
+
153
+ if __name__ == "__main__":
154
+ main()
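A minimal startup sketch showing how a service might wire this module in before serving, assuming `utils/` is importable as a package; the env values shown mirror the defaults above, and `HF_TOKEN` must come from the environment or `.env`:

```python
# Startup sketch: fetch the newest model files before serving.
# Assumes utils/ is on the import path; env values mirror the defaults above.
import os

os.environ.setdefault("HF_MODEL_REPO", "BinKhoaLe1812/Driver_Behavior_OBD")
os.environ.setdefault("MODEL_DIR", "/app/models/ul")
# HF_TOKEN is read from the environment (or .env); never hardcode it.

from utils.download import download_latest_models

if not download_latest_models():
    raise SystemExit("Model download failed; refusing to start with stale models")
```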
utils/mount_drive.py ADDED
@@ -0,0 +1,44 @@
1
+ # mount_drive.py
+ # Upload cleaned OBD CSVs to Google Drive (via gspread) using service-account credentials
+ import os
2
+ import json
3
+ import gspread
4
+ import logging
5
+ from oauth2client.service_account import ServiceAccountCredentials
6
+
7
+ # Setup logging
8
+ logger = logging.getLogger("upload")
9
+ logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(asctime)s - %(message)s")
10
+
11
+ # Authenticate with GDrive using secret
12
+ logger.info("Authenticating to Google Drive...")
13
+ creds_json = os.getenv("GDRIVE_CREDENTIALS_JSON")
14
+ if not creds_json:
15
+ logger.error("GDRIVE_CREDENTIALS_JSON not found!")
16
+ exit(1)
17
+
18
+ try:
19
+ creds_dict = json.loads(creds_json)
20
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
21
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
22
+ client = gspread.authorize(creds)
23
+ logger.info("Authenticated with Google Drive")
24
+ except Exception as e:
25
+ logger.error(f"Failed to authenticate: {e}")
26
+ exit(1)
27
+
28
+ # Folder and files
29
+ upload_dir = "./cache/obd_data/cleaned"
30
+ if not os.path.exists(upload_dir):
31
+ logger.warning(f"Directory {upload_dir} does not exist.")
32
+ exit(0)
33
+
34
+ # Upload all .csv files
35
+ for file in os.listdir(upload_dir):
36
+ if file.endswith(".csv"):
37
+ try:
38
+ path = os.path.join(upload_dir, file)
39
+ logger.info(f"Uploading {file}...")
40
+ with open(path, "rb") as f:
41
+ client.import_csv(client.create(file).id, f.read())
42
+ logger.info(f"Uploaded {file}")
43
+ except Exception as e:
44
+ logger.error(f"Failed to upload {file}: {e}")
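The script reads the whole service-account key from `GDRIVE_CREDENTIALS_JSON`. Below is a sketch of exporting a key of the expected shape; every field value is a placeholder, and the real key comes from the Google Cloud console:

```python
# Sketch: export a service-account key for mount_drive.py (placeholder values only).
import json
import os

service_account_key = {
    "type": "service_account",
    "project_id": "your-project-id",
    "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
    "client_email": "obd-uploader@your-project-id.iam.gserviceaccount.com",
    "token_uri": "https://oauth2.googleapis.com/token",
}
os.environ["GDRIVE_CREDENTIALS_JSON"] = json.dumps(service_account_key)
```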
utils/ul_label.py ADDED
@@ -0,0 +1,206 @@
1
+ # ul_label.py
2
+ # Load UL models and predict driving style
3
+ import os, logging, pickle
4
+ import warnings
5
+ import joblib
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ # Import download functionality
10
+ import sys
11
+ sys.path.append(os.path.dirname(__file__))
12
+ from download import download_latest_models
13
+
14
+ log = logging.getLogger("ul-labeler")
15
+ log.setLevel(logging.INFO)
16
+
17
+ # Suppress version compatibility warnings in production
18
+ warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.base")
19
+ warnings.filterwarnings("ignore", category=UserWarning, module="xgboost.core")
20
+
21
+ MODEL_DIR = os.getenv("MODEL_DIR", "/app/models/ul")
22
+ LE_PATH = os.path.join(MODEL_DIR, "label_encoder_ul.pkl")
23
+ SC_PATH = os.path.join(MODEL_DIR, "scaler_ul.pkl")
24
+ XGB_PATH = os.path.join(MODEL_DIR, "xgb_drivestyle_ul.pkl")
25
+
26
+ SAFE_DROP = {
27
+ "timestamp","driving_style","ul_drivestyle","gt_drivestyle",
28
+ "session_id","imported_at","record_index"
29
+ }
30
+
31
+ def _load_any(path):
32
+ # Suppress version compatibility warnings for production
33
+ with warnings.catch_warnings():
34
+ warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
35
+ warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
36
+ warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
37
+ warnings.filterwarnings("ignore", category=FutureWarning, module="xgboost")
38
+ try:
39
+ model = joblib.load(path)
40
+ except Exception:
41
+ with open(path, "rb") as f:
42
+ model = pickle.load(f)
43
+
44
+ # Fix XGBoost compatibility issues for older trained models
45
+ if hasattr(model, 'get_booster'): # This is an XGBoost model
46
+ # Remove deprecated use_label_encoder attribute that causes issues in newer XGBoost versions
47
+ if hasattr(model, '__dict__'):
48
+ # Remove all deprecated attributes that cause issues
49
+ deprecated_attrs = [
50
+ 'use_label_encoder', '_le', '_label_encoder',
51
+ 'use_label_encoder_', '_le_', '_label_encoder_'
52
+ ]
53
+ for attr in deprecated_attrs:
54
+ model.__dict__.pop(attr, None)
55
+
56
+ # Set use_label_encoder to False for newer XGBoost versions
57
+ if hasattr(model, 'set_params'):
58
+ try:
59
+ model.set_params(use_label_encoder=False)
60
+ except Exception:
61
+ pass
62
+
63
+ return model
64
+
65
+ class ULLabeler:
66
+ _instance = None
67
+
68
+ def __init__(self, auto_download: bool = True):
69
+ # Auto-download latest models if enabled
70
+ if auto_download:
71
+ log.info("🔄 Checking for latest model version...")
72
+ try:
73
+ download_latest_models()
74
+ except Exception as e:
75
+ log.warning(f"⚠️ Failed to download latest models: {e}")
76
+
77
+ if not (os.path.exists(LE_PATH) and os.path.exists(SC_PATH) and os.path.exists(XGB_PATH)):
78
+ raise FileNotFoundError("Model files not found. Ensure download.py ran successfully.")
79
+ self.le = _load_any(LE_PATH)
80
+ self.scal = _load_any(SC_PATH)
81
+ self.clf = _load_any(XGB_PATH)
82
+
83
+ # Additional XGBoost compatibility fixes
84
+ self._fix_xgb_compatibility()
85
+
86
+ # Try to discover expected feature names from scaler or model
87
+ self.expected = None
88
+ if hasattr(self.scal, "feature_names_in_"):
89
+ self.expected = list(self.scal.feature_names_in_)
90
+ elif hasattr(self.clf, "feature_names_in_"):
91
+ self.expected = list(self.clf.feature_names_in_)
92
+
93
+ log.info(f"ULLabeler ready | expected_features={len(self.expected) if self.expected else 'unknown'}")
94
+
95
+ def _fix_xgb_compatibility(self):
96
+ """Fix XGBoost compatibility issues with older trained models."""
97
+ try:
98
+ # Check if this is an XGBoost classifier
99
+ if hasattr(self.clf, 'get_booster'):
100
+ # Remove deprecated attributes that cause issues in newer XGBoost versions
101
+ deprecated_attrs = [
102
+ 'use_label_encoder', '_le', '_label_encoder',
103
+ 'use_label_encoder_', '_le_', '_label_encoder_'
104
+ ]
105
+ for attr in deprecated_attrs:
106
+ if hasattr(self.clf, attr):
107
+ try:
108
+ delattr(self.clf, attr)
109
+ except (AttributeError, TypeError):
110
+ pass
111
+
112
+ # Set use_label_encoder to False for newer XGBoost versions
113
+ if hasattr(self.clf, 'set_params'):
114
+ try:
115
+ self.clf.set_params(use_label_encoder=False)
116
+ except Exception:
117
+ pass
118
+
119
+ # Ensure the model is properly configured for prediction
120
+ if hasattr(self.clf, 'n_classes_') and self.clf.n_classes_ is None:
121
+ # Try to infer number of classes from the label encoder
122
+ if hasattr(self.le, 'classes_'):
123
+ self.clf.n_classes_ = len(self.le.classes_)
124
+
125
+ # Nothing else needs re-initializing here; the deprecated label-encoder attributes were removed above
128
+
129
+ log.info("XGBoost compatibility fixes applied successfully")
130
+ except Exception as e:
131
+ log.warning(f"XGBoost compatibility fix failed: {e}")
132
+
133
+ @classmethod
134
+ def get(cls, auto_download: bool = True):
135
+ if cls._instance is None:
136
+ cls._instance = ULLabeler(auto_download=auto_download)
137
+ return cls._instance
138
+
139
+ def _prepare(self, df: pd.DataFrame):
140
+ # numeric only + drop non-feature columns
141
+ cols = [c for c in df.columns if c not in SAFE_DROP and pd.api.types.is_numeric_dtype(df[c])]
142
+ X = df[cols].copy()
143
+
144
+ # ensure required features
145
+ if self.expected:
146
+ for c in self.expected:
147
+ if c not in X.columns:
148
+ X[c] = 0.0
149
+ X = X[self.expected] # align order
150
+ X = X.fillna(0)
151
+
152
+ # scale
153
+ try:
154
+ Xs = self.scal.transform(X if hasattr(self.scal, "feature_names_in_") else X.values)
155
+ except Exception as e:
156
+ log.warning(f"Scaler transform failed ({e}); using raw features.")
157
+ Xs = X.values
158
+ return Xs
159
+
160
+ def predict_df(self, df: pd.DataFrame) -> np.ndarray:
161
+ Xs = self._prepare(df)
162
+ try:
163
+ yhat = self.clf.predict(Xs)
164
+ except (AttributeError, TypeError) as e:
165
+ if 'use_label_encoder' in str(e) or 'label_encoder' in str(e):
166
+ # Last resort: try to fix the model and retry
167
+ log.warning("XGBoost compatibility issue detected, attempting fix...")
168
+ try:
169
+ # Remove all problematic attributes
170
+ deprecated_attrs = [
171
+ 'use_label_encoder', '_le', '_label_encoder',
172
+ 'use_label_encoder_', '_le_', '_label_encoder_'
173
+ ]
174
+ for attr in deprecated_attrs:
175
+ if hasattr(self.clf, attr):
176
+ try:
177
+ delattr(self.clf, attr)
178
+ except (AttributeError, TypeError):
179
+ pass
180
+
181
+ # Set use_label_encoder to False
182
+ if hasattr(self.clf, 'set_params'):
183
+ try:
184
+ self.clf.set_params(use_label_encoder=False)
185
+ except Exception:
186
+ pass
187
+
188
+ # Retry prediction
189
+ yhat = self.clf.predict(Xs)
190
+ except Exception as retry_e:
191
+ log.error(f"Failed to fix XGBoost compatibility: {retry_e}")
192
+ raise e
193
+ else:
194
+ raise e
195
+
196
+ try:
197
+ return self.le.inverse_transform(yhat)
198
+ except Exception:
199
+ return yhat
200
+
201
+ def predict_csv(self, csv_path: str) -> pd.DataFrame:
202
+ df = pd.read_csv(csv_path)
203
+ y = self.predict_df(df)
204
+ out = df.copy()
205
+ out["driving_style"] = y
206
+ return out
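End-to-end, the labeler is typically used through its singleton. A usage sketch follows; the import path and CSV location are assumptions based on the cache layout created in the Dockerfile:

```python
# Usage sketch for ULLabeler; paths are illustrative.
from utils.ul_label import ULLabeler

labeler = ULLabeler.get()  # downloads the newest model version on first use
labeled = labeler.predict_csv("cache/obd_data/cleaned/session_001.csv")
print(labeled["driving_style"].value_counts())
```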