Aayan Mishra committed on
Commit
7bfc1d5
·
verified ·
1 Parent(s): 7e1c436

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +825 -0
main.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ATAR Prediction System with ML Ensemble
3
+ All-in-one Gradio app with training, inference, and HF Model Repo integration
4
+ Optimized for ZeroGPU (no persistent storage needed)
5
+
6
+ Author: Victor Academy
7
+ """
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import pandas as pd
12
+ import json
13
+ import os
14
+ from typing import List, Dict, Any, Tuple
15
+ import warnings
16
+ warnings.filterwarnings('ignore')
17
+
18
+ # ZeroGPU support for Hugging Face Spaces
19
+ try:
20
+ import spaces
21
+ ZEROGPU_AVAILABLE = True
22
+ print("โœ… ZeroGPU support enabled")
23
+ except ImportError:
24
+ ZEROGPU_AVAILABLE = False
25
+ print("โ„น๏ธ Running without ZeroGPU (local mode)")
26
+
27
+ # ML Libraries
28
# ML libraries: scikit-learn estimators plus joblib for model persistence.
# If they are missing, install them once and retry the very same imports.
try:
    from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
    from sklearn.linear_model import Ridge
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split
    import joblib
except ImportError:
    print("โš ๏ธ Installing scikit-learn...")
    os.system("pip install scikit-learn joblib")
    from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
    from sklearn.linear_model import Ridge
    from sklearn.preprocessing import StandardScaler
    # train_test_split lives in sklearn.model_selection; the retry previously
    # named a module that does not exist, so the fallback could never succeed.
    from sklearn.model_selection import train_test_split
    import joblib
42
+
43
+ # Hugging Face Hub for model storage
44
+ try:
45
+ from huggingface_hub import HfApi, login, hf_hub_download
46
+ except ImportError:
47
+ print("โš ๏ธ Installing huggingface_hub...")
48
+ os.system("pip install huggingface_hub")
49
+ from huggingface_hub import HfApi, login, hf_hub_download
50
+
51
+ # ============================================
52
+ # CONFIGURATION
53
+ # ============================================
54
+
55
+ HF_MODEL_REPO = "Spestly/VAML-ATAR" # Your HF model repo
56
+ FEATURE_COUNT = 18
57
+ MODEL_VERSION = "v1.0.0" # Semantic versioning: major.minor.patch
58
+
59
+ # HF Token - REQUIRED for training (set as environment variable in HF Space settings)
60
+ # Get from: https://huggingface.co/settings/tokens (write access needed)
61
+ # In HF Space: Settings โ†’ Variables and secrets โ†’ Add: HF_TOKEN = hf_xxxxx
62
+ HF_TOKEN = os.environ.get('HF_TOKEN', None)
63
+
64
+ if not HF_TOKEN:
65
+ print("โš ๏ธ Warning: HF_TOKEN not set! Training will fail.")
66
+ print(" Set HF_TOKEN environment variable in Space settings.")
67
+ else:
68
+ print("โœ… HF_TOKEN found")
69
+
70
+ # Subject scaling data (2024 HSC data)
71
+ SUBJECT_SCALING_DATA = {
72
+ 'Mathematics Extension 2': {'scaling_factor': 1.1943, 'mean': 71.2, 'std': 12.5, 'difficulty': 'very_hard'},
73
+ 'Mathematics Extension 1': {'scaling_factor': 1.1547, 'mean': 69.8, 'std': 13.1, 'difficulty': 'hard'},
74
+ 'Mathematics Advanced': {'scaling_factor': 1.0821, 'mean': 72.5, 'std': 11.8, 'difficulty': 'medium'},
75
+ 'Physics': {'scaling_factor': 1.1037, 'mean': 70.3, 'std': 12.2, 'difficulty': 'hard'},
76
+ 'Chemistry': {'scaling_factor': 1.0956, 'mean': 71.1, 'std': 11.9, 'difficulty': 'hard'},
77
+ 'Biology': {'scaling_factor': 1.0234, 'mean': 73.8, 'std': 10.5, 'difficulty': 'medium'},
78
+ 'English Advanced': {'scaling_factor': 1.0000, 'mean': 75.2, 'std': 9.8, 'difficulty': 'medium'},
79
+ 'English Standard': {'scaling_factor': 0.9234, 'mean': 68.5, 'std': 11.2, 'difficulty': 'easy'},
80
+ 'Economics': {'scaling_factor': 1.0645, 'mean': 72.8, 'std': 11.3, 'difficulty': 'medium'},
81
+ 'Business Studies': {'scaling_factor': 0.9856, 'mean': 71.2, 'std': 10.8, 'difficulty': 'medium'},
82
+ 'Legal Studies': {'scaling_factor': 0.9923, 'mean': 72.5, 'std': 10.2, 'difficulty': 'medium'},
83
+ 'Modern History': {'scaling_factor': 1.0112, 'mean': 73.1, 'std': 10.6, 'difficulty': 'medium'},
84
+ 'Ancient History': {'scaling_factor': 1.0089, 'mean': 72.9, 'std': 10.4, 'difficulty': 'medium'},
85
+ 'PDHPE': {'scaling_factor': 0.9639, 'mean': 70.8, 'std': 11.5, 'difficulty': 'easy'},
86
+ 'Software Design & Development': {'scaling_factor': 1.0423, 'mean': 71.6, 'std': 12.1, 'difficulty': 'medium'},
87
+ 'Visual Arts': {'scaling_factor': 0.9734, 'mean': 76.2, 'std': 8.9, 'difficulty': 'easy'},
88
+ 'Music 2': {'scaling_factor': 1.0567, 'mean': 77.5, 'std': 9.2, 'difficulty': 'medium'},
89
+ 'Geography': {'scaling_factor': 0.9912, 'mean': 72.3, 'std': 10.7, 'difficulty': 'medium'},
90
+ 'Industrial Technology': {'scaling_factor': 0.9523, 'mean': 69.7, 'std': 11.8, 'difficulty': 'easy'},
91
+ }
92
+
93
+ # ============================================
94
+ # FEATURE ENGINEERING
95
+ # ============================================
96
+
97
def extract_features(subjects: List[Dict]) -> np.ndarray:
    """
    Build the 18-element feature vector consumed by the ensemble.

    Layout (index: meaning):
    - 0-9: ten highest raw marks, descending, divided by 100 (zero-padded)
    - 10:  mean of the non-zero marks, divided by 100
    - 11:  standard deviation of the non-zero marks / 20, capped at 1
    - 12:  number of subjects with scaling factor > 1.05, divided by 10
    - 13:  mean trend score (improving=1.0, stable=0.5, declining=0.0)
    - 14:  mean assessment count / 10, capped at 1
    - 15:  highest mark / 90, capped at 1
    - 16:  tenth mark / 90, capped at 1 (0 when fewer than ten subjects)
    - 17:  1.0 if the best-marked English subject has raw_mark >= 80

    Args:
        subjects: Dicts with optional keys ``subject_name``, ``raw_mark``,
            ``trend`` and ``assessment_count``.

    Returns:
        A float32 array of shape (18,). An empty ``subjects`` list yields a
        finite vector (all zeros except the neutral trend score of 0.5).
    """
    # Keep the ten best subjects, highest mark first.
    top_subjects = sorted(subjects, key=lambda s: s.get('raw_mark', 0), reverse=True)[:10]

    marks = [s.get('raw_mark', 0) for s in top_subjects]
    marks += [0] * (10 - len(marks))
    marks_normalized = [m / 100.0 for m in marks]

    # Statistics ignore the zero padding.
    valid_marks = [m for m in marks if m > 0]
    avg_mark = float(np.mean(valid_marks)) if valid_marks else 0.0
    std_dev = float(np.std(valid_marks)) if len(valid_marks) > 1 else 0.0

    high_scaling_count = sum(
        1 for s in top_subjects
        if SUBJECT_SCALING_DATA.get(s.get('subject_name', ''), {}).get('scaling_factor', 1.0) > 1.05
    )

    trend_map = {'improving': 1.0, 'stable': 0.5, 'declining': 0.0}
    if top_subjects:
        trend_score = float(np.mean(
            [trend_map.get(s.get('trend', 'stable'), 0.5) for s in top_subjects]
        ))
        mean_assessments = float(np.mean(
            [s.get('assessment_count', 1) for s in top_subjects]
        ))
        assessment_score = min(mean_assessments / 10.0, 1.0)
    else:
        # np.mean([]) is NaN; fall back to neutral values instead.
        trend_score = 0.5
        assessment_score = 0.0

    # Capped at 1.0 so the values match what generate_synthetic_data produces
    # for training (it uses min(mark / 90, 1)).
    top_mark_quality = min(marks[0] / 90.0, 1.0) if marks[0] > 0 else 0.0
    bottom_mark_quality = min(marks[-1] / 90.0, 1.0) if marks[-1] > 0 else 0.0

    # top_subjects is sorted, so the first English entry is the best one.
    english_subjects = [s for s in top_subjects if 'English' in s.get('subject_name', '')]
    has_good_english = 1.0 if english_subjects and english_subjects[0].get('raw_mark', 0) >= 80 else 0.0

    features = marks_normalized + [
        avg_mark / 100.0,
        min(std_dev / 20.0, 1.0),
        high_scaling_count / 10.0,
        trend_score,
        assessment_score,
        top_mark_quality,
        bottom_mark_quality,
        has_good_english,
    ]

    return np.array(features, dtype=np.float32)
162
+
163
+ # ============================================
164
+ # DATA GENERATION (for training)
165
+ # ============================================
166
+
167
def generate_synthetic_data(n_samples: int = 10000) -> Tuple[np.ndarray, np.ndarray]:
    """
    Generate synthetic ATAR training data from a simplified aggregate formula.

    Each sample draws ten marks (normal, mean 73, sd 10, clipped to 40-100),
    converts them to an aggregate out of 500 and maps that to an ATAR with
    small adjustments for high-scaling subject count, trend and noise.

    Args:
        n_samples: Number of samples to generate. Coerced with ``int`` so a
            float coming from a UI slider is accepted.

    Returns:
        ``(X, y)`` where ``X`` has one 18-feature row per sample, laid out
        like ``extract_features``, and ``y`` holds ATARs clipped to 30-99.95.
    """
    n_samples = int(n_samples)

    # A private generator gives the same stream as np.random.seed(42) without
    # resetting the process-wide NumPy random state.
    rng = np.random.RandomState(42)

    X = []
    y = []

    for _ in range(n_samples):
        # Ten marks, best first.
        subject_marks = np.sort(np.clip(rng.normal(73, 10, 10), 40, 100))[::-1]

        avg_mark = np.mean(subject_marks)
        std_dev = np.std(subject_marks)
        high_scaling_count = rng.randint(0, 6)
        trend_score = rng.uniform(0, 1)
        assessment_count = rng.uniform(0, 1)
        top_mark_quality = min(subject_marks[0] / 90, 1)
        bottom_mark_quality = min(subject_marks[-1] / 90, 1)
        has_good_english = 1 if subject_marks[0] >= 80 else 0

        # Each mark is out of 100 and contributes up to 50 to an aggregate out
        # of 500. The previous m * 2 / 50 capped the aggregate at 40, so every
        # target was clipped to the floor of 30.
        aggregate = float(np.sum(subject_marks)) * 50.0 / 100.0

        base_atar = 99.95 * (aggregate / 500) ** 0.85

        atar = base_atar + (high_scaling_count - 2.5) * 0.5
        atar += (trend_score - 0.5) * 2
        atar += rng.normal(0, 0.5)  # observation noise
        atar = np.clip(atar, 30, 99.95)

        features = list(subject_marks / 100) + [
            avg_mark / 100,
            min(std_dev / 20, 1),
            high_scaling_count / 10,
            trend_score,
            assessment_count,
            top_mark_quality,
            bottom_mark_quality,
            has_good_english,
        ]

        X.append(features)
        y.append(atar)

    return np.array(X), np.array(y)
221
+
222
+ # ============================================
223
+ # MODEL TRAINING
224
+ # ============================================
225
+
226
class ATARMLEnsemble:
    """
    Weighted ensemble for ATAR prediction.

    Combines Gradient Boosting, Random Forest and Ridge regressors that share
    one StandardScaler. Predictions are the weighted sum of the three models.
    """

    def __init__(self):
        self.scaler = StandardScaler()
        self.models = {
            'gb': GradientBoostingRegressor(n_estimators=200, max_depth=5, learning_rate=0.1, random_state=42),
            'rf': RandomForestRegressor(n_estimators=200, max_depth=10, random_state=42),
            'ridge': Ridge(alpha=1.0, random_state=42)
        }
        self.weights = {'gb': 0.5, 'rf': 0.3, 'ridge': 0.2}  # Ensemble weights
        self.is_trained = False
        # Metrics of the last train() call. Defined up front so they can be
        # read before training; None means "not available".
        self.training_samples = None
        self.train_mae = None
        self.test_mae = None

    def train(self, X, y, X_test=None, y_test=None):
        """
        Fit the scaler and every model, then record MAE metrics.

        Args:
            X: Training feature matrix.
            y: Training targets.
            X_test: Optional held-out features for ``test_mae``.
            y_test: Optional held-out targets for ``test_mae``.
        """
        print(f"๐Ÿš€ Training on {len(X)} samples...")

        X_scaled = self.scaler.fit_transform(X)

        for name, model in self.models.items():
            print(f"Training {name}...")
            model.fit(X_scaled, y)

        # predict() refuses to run until this flag is set.
        self.is_trained = True
        self.training_samples = len(X)

        train_pred = self.predict(X)
        self.train_mae = np.mean(np.abs(train_pred - y))

        if X_test is not None and y_test is not None:
            test_pred = self.predict(X_test)
            self.test_mae = np.mean(np.abs(test_pred - y_test))
        else:
            self.test_mae = None

        print("โœ… Ensemble training complete!")

    def predict(self, X):
        """
        Return the weighted ensemble prediction for each row of ``X``.

        Raises:
            ValueError: If the ensemble has been neither trained nor loaded.
        """
        if not self.is_trained:
            raise ValueError("Model not trained! Train first or load from HF.")

        X_scaled = self.scaler.transform(X)

        predictions = {
            name: model.predict(X_scaled) for name, model in self.models.items()
        }

        return sum(predictions[name] * self.weights[name] for name in self.models)

    def save_local(self, path='models'):
        """Write the scaler, each model and the weights as .pkl files in ``path``."""
        os.makedirs(path, exist_ok=True)
        joblib.dump(self.scaler, os.path.join(path, 'scaler.pkl'))
        for name, model in self.models.items():
            joblib.dump(model, os.path.join(path, f'{name}.pkl'))
        joblib.dump(self.weights, os.path.join(path, 'weights.pkl'))
        print(f"โœ… Models saved to {path}/")

    def load_local(self, path='models'):
        """
        Load the scaler, models and weights written by ``save_local``.

        SECURITY: joblib.load unpickles the files and can execute arbitrary
        code; only load artifacts from a source you trust.
        """
        self.scaler = joblib.load(os.path.join(path, 'scaler.pkl'))
        for name in list(self.models):
            self.models[name] = joblib.load(os.path.join(path, f'{name}.pkl'))
        self.weights = joblib.load(os.path.join(path, 'weights.pkl'))
        self.is_trained = True
        # The loaded files carry no metrics, so drop any left from an earlier
        # train() call rather than attributing them to these models.
        self.training_samples = None
        self.train_mae = None
        self.test_mae = None
        print(f"โœ… Models loaded from {path}/")
302
+
303
+ # Global model instance
304
+ ensemble = ATARMLEnsemble()
305
+
306
+ # ============================================
307
+ # HUGGING FACE INTEGRATION
308
+ # ============================================
309
+
310
def upload_to_hf(version: str = None, repo_name: str = HF_MODEL_REPO):
    """
    Upload the locally saved model files to the HF model repo, versioned.

    Repo layout:
    - models/{version}/ -> this specific version (e.g. models/v1.0.0/)
    - models/latest/    -> copy of the newest upload
    - metadata.json     -> list of all versions with their metrics

    Args:
        version: Version label; a timestamp label is generated when None.
        repo_name: Target model repo id.

    Returns:
        A human-readable status string. Failures are reported in the returned
        string rather than raised.
    """
    try:
        import tempfile
        from datetime import datetime
        from huggingface_hub.utils import EntryNotFoundError

        if not HF_TOKEN:
            return "โŒ HF_TOKEN not set! Please set it as environment variable in Space settings."

        login(token=HF_TOKEN)
        api = HfApi()

        if version is None:
            version = datetime.now().strftime("v%Y%m%d_%H%M%S")

        # exist_ok makes an existing repo a no-op while letting real errors
        # (bad token, no permission) reach the handler below instead of being
        # reported as "already exists".
        api.create_repo(repo_id=repo_name, repo_type="model", private=False, exist_ok=True)

        files = ['scaler.pkl', 'gb.pkl', 'rf.pkl', 'ridge.pkl', 'weights.pkl']

        print(f"๐Ÿ“ค Uploading version: {version}")

        # Versioned folder first, then the 'latest' copy.
        for folder in (version, 'latest'):
            for file in files:
                api.upload_file(
                    path_or_fileobj=f'models/{file}',
                    path_in_repo=f'models/{folder}/{file}',
                    repo_id=repo_name,
                    repo_type="model"
                )

        # Start a fresh history only when metadata.json does not exist yet.
        # Any other download failure aborts the upload so an existing history
        # is never overwritten with an empty one.
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                metadata_path = hf_hub_download(
                    repo_id=repo_name,
                    filename="metadata.json",
                    repo_type="model",
                    cache_dir=temp_dir,
                    token=HF_TOKEN
                )
                with open(metadata_path, 'r') as f:
                    metadata = json.load(f)
        except EntryNotFoundError:
            metadata = {
                "versions": [],
                "latest_version": None,
                "model_type": "ML Ensemble (Gradient Boosting + Random Forest + Ridge)",
                "feature_count": FEATURE_COUNT
            }

        train_mae = getattr(ensemble, 'train_mae', None)
        test_mae = getattr(ensemble, 'test_mae', None)
        new_version_info = {
            "version": version,
            "timestamp": datetime.now().isoformat(),
            "training_samples": getattr(ensemble, 'training_samples', "unknown"),
            # Plain floats keep the metadata JSON-serializable.
            "train_mae": None if train_mae is None else float(train_mae),
            "test_mae": None if test_mae is None else float(test_mae),
            "model_files": files
        }

        # Re-uploading a version replaces its entry instead of duplicating it.
        previous = [v for v in metadata.get("versions", []) if v.get("version") != version]
        metadata["versions"] = previous + [new_version_info]
        metadata["latest_version"] = version
        metadata["total_versions"] = len(metadata["versions"])

        os.makedirs('models', exist_ok=True)
        with open('models/metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        api.upload_file(
            path_or_fileobj='models/metadata.json',
            path_in_repo='metadata.json',
            repo_id=repo_name,
            repo_type="model"
        )

        return f"""โœ… Models uploaded successfully!

๐Ÿ“ฆ Version: {version}
๐Ÿ”— Repo: https://huggingface.co/{repo_name}
๐Ÿ“Š Total versions: {len(metadata['versions'])}

Access:
- Latest: models/latest/
- This version: models/{version}/
- All versions: See metadata.json
"""
    except Exception as e:
        return f"โŒ Upload failed: {str(e)}"
423
+
424
def download_from_hf(version: str = "latest", repo_name: str = HF_MODEL_REPO, token: str = None):
    """
    Download a model version from the HF model repo and load it.

    Args:
        version: Version folder to load ('latest', 'v1.0.0', ...). A blank
            value is treated as 'latest'.
        repo_name: HF model repo id.
        token: HF token; falls back to HF_TOKEN. Only needed for private repos.

    Returns:
        A human-readable status string. Failures are reported in the returned
        string rather than raised.

    SECURITY: the downloaded .pkl files are unpickled by
    ``ensemble.load_local``; only point this at a repo you trust.
    """
    try:
        import shutil

        # An emptied version textbox would otherwise request "models//...".
        version = (version or "").strip() or "latest"

        os.makedirs('models', exist_ok=True)

        auth_token = token or HF_TOKEN

        path_prefix = f"models/{version}/"

        files = ['scaler.pkl', 'gb.pkl', 'rf.pkl', 'ridge.pkl', 'weights.pkl']

        print(f"๐Ÿ“ฅ Downloading version: {version}")
        if auth_token:
            print("๐Ÿ”’ Using authentication (private repo)")
        else:
            print("๐ŸŒ No token - assuming public repo")

        for file in files:
            local_path = hf_hub_download(
                repo_id=repo_name,
                filename=f"{path_prefix}{file}",
                repo_type="model",
                cache_dir='models',
                token=auth_token
            )
            # Copy out of the hub cache to the flat layout load_local expects.
            shutil.copy(local_path, f'models/{file}')

        ensemble.load_local('models')

        # Version details are best-effort: the models are already loaded, so a
        # metadata problem only shortens the status message.
        try:
            metadata_path = hf_hub_download(
                repo_id=repo_name,
                filename="metadata.json",
                repo_type="model",
                cache_dir='models',
                token=auth_token
            )
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            # 'latest' is a folder name, not a recorded version; resolve it to
            # the version metadata.json marks as latest.
            lookup = metadata.get("latest_version") if version == "latest" else version
            version_info = next(
                (v for v in metadata["versions"] if v["version"] == lookup),
                None
            ) or {}

            timestamp = version_info.get('timestamp', 'Unknown')
            train_mae = version_info.get('train_mae', 'N/A')
            test_mae = version_info.get('test_mae', 'N/A')

            info_str = f"""โœ… Models loaded successfully!

๐Ÿ“ฆ Version: {version}
๐Ÿ“… Trained: {timestamp}
๐Ÿ“Š Train MAE: {train_mae} ATAR points
๐Ÿ“Š Test MAE: {test_mae} ATAR points
๐Ÿ”— Repo: https://huggingface.co/{repo_name}
"""
            return info_str
        except Exception:
            return f"โœ… Models loaded from https://huggingface.co/{repo_name} ({version})"

    except Exception as e:
        return f"โŒ Download failed: {str(e)}\nTrain the model first or check version name!"
493
+
494
+ # ============================================
495
+ # PREDICTION LOGIC
496
+ # ============================================
497
+
498
def predict_atar(subjects: List[Dict]) -> Dict[str, Any]:
    """
    Predict an ATAR for a list of subject dicts using the global ensemble.

    If the ensemble is not loaded yet, the default version is downloaded
    from the HF model repo first.

    Args:
        subjects: Subject dicts as accepted by ``extract_features``.

    Returns:
        On success a dict with ``predicted_atar``, ``confidence``,
        ``insights`` and ``recommendations``. On failure a dict with
        ``error``, ``predicted_atar`` = 0 and ``confidence`` = 0.
    """
    if not ensemble.is_trained:
        download_from_hf()
        # Check the ensemble state rather than the wording of the status
        # message, so rewording that message cannot mask a failed load.
        if not ensemble.is_trained:
            return {
                'error': 'Model not trained or available. Please train first!',
                'predicted_atar': 0,
                'confidence': 0
            }

    # Input arrives from user-supplied JSON, so reject unusable shapes early.
    if not isinstance(subjects, list) or not subjects:
        return {
            'error': 'No subjects provided. Expected a non-empty list of subjects.',
            'predicted_atar': 0,
            'confidence': 0
        }

    features = extract_features(subjects)
    X = features.reshape(1, -1)

    # Clip to the valid ATAR range; float() keeps the result a plain number.
    predicted_atar = float(np.clip(ensemble.predict(X)[0], 30, 99.95))

    confidence = calculate_confidence(subjects)

    insights = generate_insights(subjects, predicted_atar)
    recommendations = generate_recommendations(subjects, predicted_atar)

    return {
        'predicted_atar': round(predicted_atar, 2),
        'confidence': round(confidence, 2),
        'insights': insights,
        'recommendations': recommendations
    }
534
+
535
def calculate_confidence(subjects: List[Dict]) -> float:
    """
    Score how much input data backs a prediction.

    Weighted sum of three factors, each at most 1:
    - 0.4 x assessment completeness (total assessments / 5 per subject)
    - 0.3 x subject coverage (number of subjects / 10)
    - 0.3 x share of subjects that carry a ``trend`` key

    Returns 0.0 when no subjects are given.
    """
    if not subjects:
        return 0.0

    n_subjects = len(subjects)

    total_assessments = sum(entry.get('assessment_count', 0) for entry in subjects)
    completeness = min(total_assessments / (n_subjects * 5), 1.0)

    coverage = min(n_subjects / 10, 1.0)

    with_trend = sum(1 for entry in subjects if 'trend' in entry)
    trend_share = with_trend / n_subjects

    return 0.4 * completeness + 0.3 * coverage + 0.3 * trend_share
547
+
548
def generate_insights(subjects: List[Dict], predicted_atar: float) -> List[str]:
    """
    Produce short insight messages for a prediction.

    The first message depends on the ATAR band (>=95, >=85, >=75, otherwise).
    A second message is added when at least three subjects have a scaling
    factor above 1.05.
    """
    # Bands are checked from highest to lowest; the first match wins.
    bands = (
        (95, "๐ŸŽฏ Excellent performance! You're on track for elite universities."),
        (85, "๐Ÿ“ˆ Strong performance! Many competitive courses within reach."),
        (75, "โœ… Solid foundation! Focus on improvement areas for better outcomes."),
    )
    fallback = "๐Ÿ’ช Room for growth! Strategic improvement can boost your ATAR significantly."
    headline = next(
        (message for floor, message in bands if predicted_atar >= floor),
        fallback,
    )
    messages = [headline]

    # Count subjects whose scaling factor exceeds 1.05 (unknown names use 1.0).
    strong_scalers = 0
    for entry in subjects:
        scaling = SUBJECT_SCALING_DATA.get(entry.get('subject_name', ''), {})
        if scaling.get('scaling_factor', 1.0) > 1.05:
            strong_scalers += 1

    if strong_scalers >= 3:
        messages.append(f"โญ Your {strong_scalers} high-scaling subjects will boost your ATAR!")

    return messages
568
+
569
def generate_recommendations(subjects: List[Dict], predicted_atar: float) -> List[str]:
    """
    Suggest where to improve.

    Points at the lowest-marked subject and, if any subject has a scaling
    factor below 0.98, asks whether the first such subject belongs in the
    best 10 units. ``predicted_atar`` is accepted but not used.
    """
    tips = []

    if subjects:
        # min() returns the first lowest-marked subject, as a stable sort would.
        weakest = min(subjects, key=lambda entry: entry.get('raw_mark', 0))
        weakest_name = weakest.get('subject_name', 'weakest subject')
        tips.append(f"๐ŸŽฏ Focus on {weakest_name} - raising this by 5 marks could add ~1 ATAR point")

    first_low_scaler = next(
        (
            entry for entry in subjects
            if SUBJECT_SCALING_DATA.get(entry.get('subject_name', ''), {}).get('scaling_factor', 1.0) < 0.98
        ),
        None,
    )
    if first_low_scaler is not None:
        tips.append(f"โš–๏ธ Consider if {first_low_scaler.get('subject_name')} is in your best 10 units")

    return tips
585
+
586
+ # ============================================
587
+ # GRADIO INTERFACE
588
+ # ============================================
589
+
590
@spaces.GPU(duration=120) if ZEROGPU_AVAILABLE else lambda x: x
def train_model_interface(n_samples: int, version: str = None):
    """
    Train the ensemble on synthetic data, save it and upload it to HF.

    This is a generator: each yielded string replaces the progress text shown
    in the UI.

    Args:
        n_samples: Number of synthetic samples. Coerced with ``int`` because
            a slider may deliver a float.
        version: Optional version label; a timestamp label is generated when
            it is empty.
    """
    try:
        yield "๐Ÿ“Š Generating synthetic training data..."
        X, y = generate_synthetic_data(int(n_samples))

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        yield "๐Ÿš€ Training ML ensemble (Gradient Boosting + Random Forest + Ridge)..."
        ensemble.train(X_train, y_train, X_test, y_test)

        # train() already evaluated both splits; reuse its metrics instead of
        # running every model over the data a second time.
        train_mae = ensemble.train_mae
        test_mae = ensemble.test_mae

        summary = (
            f"โœ… Training complete!\n\n๐Ÿ“Š Results:\n"
            f"- Train MAE: {train_mae:.2f} ATAR points\n"
            f"- Test MAE: {test_mae:.2f} ATAR points\n"
            f"- Training samples: {len(X_train):,}"
        )

        yield f"{summary}\n\n๐Ÿ’พ Saving models locally..."

        ensemble.save_local('models')

        yield f"โœ… Models saved!\n\nโ˜๏ธ Uploading to Hugging Face with versioning..."

        # Auto-generate a timestamp version when none was supplied.
        if not version or version.strip() == "":
            from datetime import datetime
            version = datetime.now().strftime("v%Y%m%d_%H%M%S")

        result = upload_to_hf(version=version)
        yield f"{summary}\n\n{result}\n\n๐ŸŽ‰ Model ready for inference!"

    except Exception as e:
        yield f"โŒ Training failed: {str(e)}"
630
+
631
@spaces.GPU(duration=5) if ZEROGPU_AVAILABLE else lambda x: x
def predict_interface(subjects_json: str):
    """
    Predict an ATAR from a JSON string and return the result as JSON text.

    Args:
        subjects_json: JSON array of subject objects.

    Returns:
        Pretty-printed JSON of the prediction, or ``{"error": ...}`` when the
        input cannot be parsed or the prediction fails.
    """
    try:
        subjects = json.loads(subjects_json)
        # predict_atar iterates over subject dicts; reject other JSON shapes
        # with a clear message instead of an attribute error.
        if not isinstance(subjects, list):
            return json.dumps({'error': 'Input must be a JSON array of subject objects'})
        result = predict_atar(subjects)
        # ensure_ascii=False keeps non-ASCII text in the insight messages
        # readable instead of emitting \u escape sequences.
        return json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as e:
        return json.dumps({'error': str(e)})
640
+
641
+ # ============================================
642
+ # GRADIO APP
643
+ # ============================================
644
+
645
# Gradio UI: four tabs (training, JSON API, simple calculator, scaling table).
with gr.Blocks(title="ATAR Prediction ML Ensemble", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # ๐ŸŽ“ ATAR Prediction System (ML Ensemble)
    **Powered by Gradient Boosting + Random Forest + Ridge Regression**

    ### Features:
    - ๐Ÿš€ Train on ZeroGPU with automatic HF Model Repo upload
    - ๐Ÿ”ฎ Predict ATAR from subject marks (auto-loads model from HF)
    - โ˜๏ธ No persistent storage needed - models live in HF Model Repo
    """)

    with gr.Tabs():
        # Tab 1: Training
        with gr.Tab("๐Ÿ‹๏ธ Train Model"):
            gr.Markdown("### Train ML Ensemble & Upload to Hugging Face")

            with gr.Row():
                n_samples_input = gr.Slider(1000, 50000, value=10000, step=1000, label="Training Samples")
                version_input = gr.Textbox(
                    label="Version (optional - auto-generated if empty)",
                    placeholder="v1.0.0 or leave empty for timestamp",
                    value=""
                )

            train_btn = gr.Button("๐Ÿš€ Train & Upload to HF", variant="primary", size="lg")
            train_output = gr.Textbox(label="Training Progress", lines=12)

            train_btn.click(
                fn=train_model_interface,
                inputs=[n_samples_input, version_input],
                outputs=train_output
            )

            # The repo name is interpolated from HF_MODEL_REPO so the
            # instructions always name the repo the upload really targets.
            gr.Markdown(f"""
            **Instructions:**
            1. Set `HF_TOKEN` environment variable in Space settings (write access)
               - Go to Space Settings โ†’ Variables and secrets
               - Add secret: `HF_TOKEN` = your token from https://huggingface.co/settings/tokens
            2. (Optional) Specify version like `v1.0.0`, `v1.1.0`, etc. or leave empty for auto timestamp
            3. Click "Train & Upload to HF"
            4. Model will be uploaded to `{HF_MODEL_REPO}`
            5. Each training creates a new version - no overwrites!

            **Versioning:**
            - `models/latest/` - Always the newest model
            - `models/v1.0.0/` - Specific version you can roll back to
            - `metadata.json` - Tracks all versions with metrics

            **ZeroGPU:**
            - Training uses GPU for 120 seconds (free tier)
            - Inference uses GPU for 5 seconds per request
            - All model storage handled via HF Model Repo
            """)

        # Tab 2: JSON API
        with gr.Tab("๐Ÿ”Œ JSON API"):
            gr.Markdown("### Predict ATAR (JSON API)")

            with gr.Row():
                load_version_input = gr.Textbox(
                    label="Model Version to Load (optional)",
                    placeholder="latest (default), v1.0.0, v20241007_143022, etc.",
                    value="latest"
                )
                load_btn = gr.Button("๐Ÿ“ฅ Load Model", variant="secondary")

            load_status = gr.Textbox(label="Load Status", lines=3)

            def load_model_interface(version):
                """Load the requested model version and return the status text."""
                return download_from_hf(version=version)

            load_btn.click(
                fn=load_model_interface,
                inputs=load_version_input,
                outputs=load_status
            )

            gr.Markdown("---")

            subjects_input = gr.Code(
                label="Input: Subjects JSON",
                language="json",
                value=json.dumps([
                    {"subject_name": "Mathematics Extension 2", "raw_mark": 88.5, "trend": "improving", "assessment_count": 4},
                    {"subject_name": "Physics", "raw_mark": 85.0, "trend": "stable", "assessment_count": 5},
                    {"subject_name": "Chemistry", "raw_mark": 84.0, "trend": "stable", "assessment_count": 5},
                    {"subject_name": "English Advanced", "raw_mark": 82.0, "trend": "improving", "assessment_count": 4},
                    {"subject_name": "Software Design & Development", "raw_mark": 86.0, "trend": "improving", "assessment_count": 3}
                ], indent=2)
            )

            predict_btn = gr.Button("๐Ÿ”ฎ Predict ATAR", variant="primary")
            prediction_output = gr.Code(label="Output: Prediction JSON", language="json")

            predict_btn.click(
                fn=predict_interface,
                inputs=subjects_input,
                outputs=prediction_output
            )

            gr.Markdown("""
            **Note:**
            - Model auto-loads `latest` version on first API call if not manually loaded
            - Manually load a specific version to test different models
            - All versions are preserved in HF Model Repo
            - **Public repos**: No token needed for downloads
            - **Private repos**: Set `HF_TOKEN` environment variable in Space settings
            """)

        # Tab 3: Simple Calculator
        with gr.Tab("๐Ÿ“ Simple Calculator"):
            gr.Markdown("### Quick ATAR Estimate")

            subject_choices = list(SUBJECT_SCALING_DATA.keys())

            # Two rows of two (subject, mark) pairs. calc_inputs ends up as
            # [subj1, mark1, subj2, mark2, subj3, mark3, subj4, mark4].
            calc_inputs = []
            for first_index in (1, 3):
                with gr.Row():
                    for index in (first_index, first_index + 1):
                        with gr.Column():
                            calc_inputs.append(gr.Dropdown(choices=subject_choices, label=f"Subject {index}"))
                            calc_inputs.append(gr.Slider(0, 100, 85, label="Mark"))

            calc_btn = gr.Button("Calculate ATAR", variant="primary")
            calc_output = gr.Textbox(label="Result", lines=8)

            def simple_calc(s1, m1, s2, m2, s3, m3, s4, m4):
                """Predict from up to four (subject, mark) pairs; unselected subjects are skipped."""
                subjects = [
                    {"subject_name": s, "raw_mark": m, "trend": "stable", "assessment_count": 3}
                    for s, m in [(s1, m1), (s2, m2), (s3, m3), (s4, m4)]
                    if s
                ]

                if not subjects:
                    return "โš ๏ธ Please select at least one subject"

                result = predict_atar(subjects)

                if 'error' in result:
                    return f"โŒ {result['error']}"

                output = f"๐ŸŽฏ Predicted ATAR: {result['predicted_atar']}\n"
                output += f"๐Ÿ“Š Confidence: {result['confidence']*100:.0f}%\n\n"
                output += "๐Ÿ’ก Insights:\n" + "\n".join(result['insights'])
                return output

            calc_btn.click(
                fn=simple_calc,
                inputs=calc_inputs,
                outputs=calc_output
            )

        # Tab 4: Scaling Reference
        with gr.Tab("๐Ÿ“Š Scaling Reference"):
            gr.Markdown("### 2024 HSC Subject Scaling Data")

            # One row per subject, highest scaling factor first.
            scaling_df = pd.DataFrame([
                {
                    'Subject': name,
                    'Scaling Factor': f"{data['scaling_factor']:.4f}",
                    'Mean Mark': data['mean'],
                    'Difficulty': data['difficulty']
                }
                for name, data in sorted(SUBJECT_SCALING_DATA.items(),
                                         key=lambda x: x[1]['scaling_factor'],
                                         reverse=True)
            ])

            gr.Dataframe(scaling_df, label="Subject Scaling Factors (sorted by scaling)")
819
+
820
+ # ============================================
821
+ # LAUNCH
822
+ # ============================================
823
+
824
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and request a public share link.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    app.launch(**launch_options)