Spaces:
Sleeping
Sleeping
added option to choose tracks
Browse files- src/curriculum_analyzer.py +82 -32
- src/curriculum_optimizer.py +657 -513
- src/neu_graph_clean10.pkl +3 -0
- src/ui.py +621 -0
src/curriculum_analyzer.py
CHANGED
|
@@ -1,11 +1,19 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
import pickle
|
| 5 |
import argparse
|
| 6 |
import networkx as nx
|
| 7 |
import re
|
| 8 |
from typing import Set, Dict
|
|
|
|
| 9 |
|
| 10 |
def get_course_level(cid):
|
| 11 |
"""Extracts the numerical part of a course ID for level checking."""
|
|
@@ -13,13 +21,27 @@ def get_course_level(cid):
|
|
| 13 |
return int(match.group(0)) if match else 9999
|
| 14 |
|
| 15 |
class CurriculumAnalyzer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def __init__(self, graph_path, courses_path):
|
| 17 |
print("📚 Loading raw curriculum data...")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Merge course data into graph nodes
|
| 24 |
for course_id, course_data in self.courses.items():
|
| 25 |
if self.graph.has_node(course_id):
|
|
@@ -28,44 +50,65 @@ class CurriculumAnalyzer:
|
|
| 28 |
print(f"✅ Loaded {self.graph.number_of_nodes()} courses, {self.graph.number_of_edges()} edges")
|
| 29 |
|
| 30 |
def pre_filter_graph(self):
|
| 31 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 32 |
print("\n🧹 Pre-filtering graph...")
|
| 33 |
|
| 34 |
-
KEEP_SUBJECTS = {"CS", "DS", "IS", "CY", "MATH", "PHYS", "ENGW", "STAT", "EECE"}
|
| 35 |
-
|
| 36 |
nodes_to_remove = set()
|
| 37 |
for node, data in self.graph.nodes(data=True):
|
| 38 |
subject = data.get('subject', '')
|
| 39 |
name = data.get('name', '').lower()
|
| 40 |
level = get_course_level(node)
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
nodes_to_remove.add(node)
|
| 50 |
|
|
|
|
| 51 |
self.graph.remove_nodes_from(nodes_to_remove)
|
| 52 |
-
|
| 53 |
-
print(f"
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def fix_chains(self):
|
| 56 |
"""Adds critical prerequisite chains that might be missing."""
|
| 57 |
print("\n🔗 Validating and fixing critical prerequisite chains...")
|
| 58 |
|
| 59 |
critical_chains = {
|
|
|
|
| 60 |
("CS1800", "CS2800", "Discrete → Logic"),
|
| 61 |
("CS2500", "CS2510", "Fundies 1 → Fundies 2"),
|
|
|
|
| 62 |
("CS2510", "CS3500", "Fundies 2 → OOD"),
|
| 63 |
("CS2510", "CS3000", "Fundies 2 → Algorithms"),
|
| 64 |
-
("
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
("DS2000", "DS2500", "Prog w/ Data → Intermediate"),
|
| 67 |
("DS2500", "DS3500", "Intermediate → Advanced"),
|
| 68 |
-
("DS3500", "DS4400", "Advanced → ML1"),
|
|
|
|
|
|
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
added = 0
|
|
@@ -80,11 +123,18 @@ class CurriculumAnalyzer:
|
|
| 80 |
print(" ✅ All critical chains present")
|
| 81 |
|
| 82 |
def remove_spurious_chains(self):
|
| 83 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 84 |
print("\n🗑️ Removing spurious prerequisite chains...")
|
| 85 |
|
|
|
|
| 86 |
spurious_chains = {
|
| 87 |
-
("
|
|
|
|
|
|
|
|
|
|
| 88 |
}
|
| 89 |
|
| 90 |
removed = 0
|
|
@@ -102,6 +152,7 @@ class CurriculumAnalyzer:
|
|
| 102 |
print("\n🧮 Calculating complexity scores...")
|
| 103 |
|
| 104 |
for node in self.graph.nodes():
|
|
|
|
| 105 |
in_degree = self.graph.in_degree(node)
|
| 106 |
out_degree = self.graph.out_degree(node)
|
| 107 |
|
|
@@ -115,13 +166,14 @@ class CurriculumAnalyzer:
|
|
| 115 |
"""Check if all critical courses exist in the graph."""
|
| 116 |
print("\n🎯 Validating critical course coverage...")
|
| 117 |
|
|
|
|
| 118 |
required_courses = {
|
| 119 |
"foundations": {"CS1800", "CS2500", "CS2510", "CS2800"},
|
| 120 |
-
"core": {"CS3000", "CS3500", "CS3650", "
|
| 121 |
"ai_ml": {"CS4100", "DS4400", "CS4120", "DS4420", "CS4180", "DS4440"},
|
| 122 |
-
"systems": {"CS4730", "CS4400", "CS4500"},
|
| 123 |
"security": {"CY2550", "CY3740", "CY4740", "CY4760"},
|
| 124 |
-
"math": {"MATH1341", "MATH1342", "MATH2331", "MATH3081"},
|
| 125 |
}
|
| 126 |
|
| 127 |
missing = {}
|
|
@@ -150,14 +202,12 @@ class CurriculumAnalyzer:
|
|
| 150 |
f.write(f"Total courses: {self.graph.number_of_nodes()}\n")
|
| 151 |
f.write(f"Total prerequisites: {self.graph.number_of_edges()}\n\n")
|
| 152 |
|
| 153 |
-
# Subject breakdown
|
| 154 |
-
from collections import defaultdict
|
| 155 |
subject_counts = defaultdict(int)
|
| 156 |
for node in self.graph.nodes():
|
| 157 |
subject = self.graph.nodes[node].get('subject', 'UNKNOWN')
|
| 158 |
subject_counts[subject] += 1
|
| 159 |
|
| 160 |
-
f.write("Subject breakdown:\n")
|
| 161 |
for subject in sorted(subject_counts.keys()):
|
| 162 |
f.write(f" {subject}: {subject_counts[subject]}\n")
|
| 163 |
|
|
@@ -183,8 +233,8 @@ def main(args):
|
|
| 183 |
|
| 184 |
if __name__ == "__main__":
|
| 185 |
parser = argparse.ArgumentParser(description="NEU Curriculum Analyzer - Cleans and validates data")
|
| 186 |
-
parser.add_argument('--graph', required=True, help="Path to RAW curriculum graph")
|
| 187 |
-
parser.add_argument('--courses', required=True, help="Path to RAW courses data")
|
| 188 |
-
parser.add_argument('--output-graph', default='neu_graph_clean.pkl', help="Output path")
|
| 189 |
args = parser.parse_args()
|
| 190 |
main(args)
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
FIXED Curriculum Analyzer - Production Version
|
| 4 |
+
|
| 5 |
+
Synchronized with optimizer logic:
|
| 6 |
+
1. Filters subjects to ONLY: CS, DS, CY, MATH, PHYS, ENGW.
|
| 7 |
+
2. Removes IS, EECE, STAT, and other irrelevant subjects.
|
| 8 |
+
3. ADDS exception for undergrad-accessible 5000-level courses (CS5700).
|
| 9 |
+
4. FIXES bad prerequisite data (e.g., CS2500 -> CS2800).
|
| 10 |
"""
|
| 11 |
import pickle
|
| 12 |
import argparse
|
| 13 |
import networkx as nx
|
| 14 |
import re
|
| 15 |
from typing import Set, Dict
|
| 16 |
+
from collections import defaultdict
|
| 17 |
|
| 18 |
def get_course_level(cid):
|
| 19 |
"""Extracts the numerical part of a course ID for level checking."""
|
|
|
|
| 21 |
return int(match.group(0)) if match else 9999
|
| 22 |
|
| 23 |
class CurriculumAnalyzer:
|
| 24 |
+
|
| 25 |
+
# Class-level filter configuration, kept in step with the optimizer.

# Subject prefixes that survive pre-filtering; every other subject is dropped.
# ENGW and PHYS are kept only because the hardcoded Year 1 plan needs them.
KEEP_SUBJECTS = {
    "CS",
    "DS",
    "CY",
    "MATH",
    "PHYS",
    "ENGW",
}

# 5000-level course IDs exempted from the graduate-level cut-off.
UNDERGRAD_ACCESSIBLE_GRAD = {
    "CS5700",
    "CY5700",
    "DS5110",
    "CS5010",
}
|
| 33 |
+
|
| 34 |
def __init__(self, graph_path, courses_path):
|
| 35 |
print("📚 Loading raw curriculum data...")
|
| 36 |
+
try:
|
| 37 |
+
with open(graph_path, 'rb') as f:
|
| 38 |
+
self.graph = pickle.load(f)
|
| 39 |
+
with open(courses_path, 'rb') as f:
|
| 40 |
+
self.courses = pickle.load(f)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"❌ ERROR: Could not load files. {e}")
|
| 43 |
+
exit(1)
|
| 44 |
+
|
| 45 |
# Merge course data into graph nodes
|
| 46 |
for course_id, course_data in self.courses.items():
|
| 47 |
if self.graph.has_node(course_id):
|
|
|
|
| 50 |
print(f"✅ Loaded {self.graph.number_of_nodes()} courses, {self.graph.number_of_edges()} edges")
|
| 51 |
|
| 52 |
def pre_filter_graph(self):
    """Drop every node the optimizer should never see.

    A node is removed from ``self.graph`` when any of these holds:

    * its ``subject`` attribute is not in ``self.KEEP_SUBJECTS``;
    * its lower-cased ``name`` contains one of the skip terms below;
    * its numeric level is 5000 or above and its ID is not listed in
      ``self.UNDERGRAD_ACCESSIBLE_GRAD``.

    The graph is modified in place and a short summary is printed.
    """
    print("\n🧹 Pre-filtering graph...")

    # NOTE(review): these are plain substring tests, so 'lab' also matches
    # names such as "Scalable ..." or "Collaborative ..." -- confirm intended.
    skip_terms = ['lab', 'recitation', 'seminar', 'practicum', 'co-op']

    nodes_to_remove = set()
    for node, attrs in self.graph.nodes(data=True):
        subject = attrs.get('subject', '')
        title = attrs.get('name', '').lower()
        level = get_course_level(node)

        # A node is kept only when it passes all three filters.
        keep = (
            subject in self.KEEP_SUBJECTS
            and not any(term in title for term in skip_terms)
            and (level < 5000 or node in self.UNDERGRAD_ACCESSIBLE_GRAD)
        )
        if not keep:
            nodes_to_remove.add(node)

    # Capture the size before mutating so the summary can show both counts.
    original_count = self.graph.number_of_nodes()
    self.graph.remove_nodes_from(nodes_to_remove)

    print(f"✅ Removed {len(nodes_to_remove)} irrelevant courses (IS, EECE, etc.)")
    print(f" Original nodes: {original_count}")
    print(f" Remaining nodes: {self.graph.number_of_nodes()}")
|
| 84 |
|
| 85 |
def fix_chains(self):
|
| 86 |
"""Adds critical prerequisite chains that might be missing."""
|
| 87 |
print("\n🔗 Validating and fixing critical prerequisite chains...")
|
| 88 |
|
| 89 |
critical_chains = {
|
| 90 |
+
# Foundations
|
| 91 |
("CS1800", "CS2800", "Discrete → Logic"),
|
| 92 |
("CS2500", "CS2510", "Fundies 1 → Fundies 2"),
|
| 93 |
+
# Core CS
|
| 94 |
("CS2510", "CS3500", "Fundies 2 → OOD"),
|
| 95 |
("CS2510", "CS3000", "Fundies 2 → Algorithms"),
|
| 96 |
+
("CS2800", "CS3000", "Logic → Algorithms"),
|
| 97 |
+
|
| 98 |
+
# --- THIS IS THE FIX ---
|
| 99 |
+
("CS3000", "CS3650", "Algorithms -> Systems"),
|
| 100 |
+
# ---------------------
|
| 101 |
+
|
| 102 |
+
# Core AI/ML
|
| 103 |
+
("CS3000", "CS4100", "Algorithms → AI"),
|
| 104 |
+
("CS3500", "CS4100", "OOD → AI"),
|
| 105 |
+
# Core DS Path
|
| 106 |
("DS2000", "DS2500", "Prog w/ Data → Intermediate"),
|
| 107 |
("DS2500", "DS3500", "Intermediate → Advanced"),
|
| 108 |
+
("DS3500", "DS4400", "Advanced → ML1"),
|
| 109 |
+
("CS3500", "DS4400", "OOD → ML1"),
|
| 110 |
+
# Math
|
| 111 |
+
("MATH1341", "MATH1342", "Calc 1 → Calc 2"),
|
| 112 |
}
|
| 113 |
|
| 114 |
added = 0
|
|
|
|
| 123 |
print(" ✅ All critical chains present")
|
| 124 |
|
| 125 |
def remove_spurious_chains(self):
|
| 126 |
+
"""
|
| 127 |
+
--- FIX 3: REMOVE BAD DATA ---
|
| 128 |
+
Removes known incorrect prerequisite edges from scraper.
|
| 129 |
+
"""
|
| 130 |
print("\n🗑️ Removing spurious prerequisite chains...")
|
| 131 |
|
| 132 |
+
# Based on your inspect_graph output and catalog knowledge
|
| 133 |
spurious_chains = {
|
| 134 |
+
("CS2500", "CS2800"), # Fundies 1 is NOT a prereq for Logic
|
| 135 |
+
("MATH1365", "CS2800"), # Not a real prereq
|
| 136 |
+
("EECE2160", "CS3000"), # Irrelevant prereq
|
| 137 |
+
("EECE2560", "CS3500"), # Irrelevant prereq
|
| 138 |
}
|
| 139 |
|
| 140 |
removed = 0
|
|
|
|
| 152 |
print("\n🧮 Calculating complexity scores...")
|
| 153 |
|
| 154 |
for node in self.graph.nodes():
|
| 155 |
+
# Use predecessors/successors on the *cleaned* graph
|
| 156 |
in_degree = self.graph.in_degree(node)
|
| 157 |
out_degree = self.graph.out_degree(node)
|
| 158 |
|
|
|
|
| 166 |
"""Check if all critical courses exist in the graph."""
|
| 167 |
print("\n🎯 Validating critical course coverage...")
|
| 168 |
|
| 169 |
+
# This list MUST match the optimizer's requirements
|
| 170 |
required_courses = {
|
| 171 |
"foundations": {"CS1800", "CS2500", "CS2510", "CS2800"},
|
| 172 |
+
"core": {"CS3000", "CS3500", "CS3650", "CS3200", "CS5700"}, # Added CS5700
|
| 173 |
"ai_ml": {"CS4100", "DS4400", "CS4120", "DS4420", "CS4180", "DS4440"},
|
| 174 |
+
"systems": {"CS4730", "CS4700", "CS4400", "CS4500"},
|
| 175 |
"security": {"CY2550", "CY3740", "CY4740", "CY4760"},
|
| 176 |
+
"math": {"MATH1341", "MATH1342", "MATH2331", "MATH3081"},
|
| 177 |
}
|
| 178 |
|
| 179 |
missing = {}
|
|
|
|
| 202 |
f.write(f"Total courses: {self.graph.number_of_nodes()}\n")
|
| 203 |
f.write(f"Total prerequisites: {self.graph.number_of_edges()}\n\n")
|
| 204 |
|
|
|
|
|
|
|
| 205 |
subject_counts = defaultdict(int)
|
| 206 |
for node in self.graph.nodes():
|
| 207 |
subject = self.graph.nodes[node].get('subject', 'UNKNOWN')
|
| 208 |
subject_counts[subject] += 1
|
| 209 |
|
| 210 |
+
f.write("Subject breakdown (Filtered):\n")
|
| 211 |
for subject in sorted(subject_counts.keys()):
|
| 212 |
f.write(f" {subject}: {subject_counts[subject]}\n")
|
| 213 |
|
|
|
|
| 233 |
|
| 234 |
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and hand them to main().
    parser = argparse.ArgumentParser(description="NEU Curriculum Analyzer - Cleans and validates data")
    for flag, options in (
        ('--graph', {'required': True, 'help': "Path to RAW curriculum graph (e.g., neu_merged_graph_...pkl)"}),
        ('--courses', {'required': True, 'help': "Path to RAW courses data (e.g., neu_merged_courses_...pkl)"}),
        ('--output-graph', {'default': 'neu_graph_clean.pkl', 'help': "Output path for the new, clean graph"}),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)
|
src/curriculum_optimizer.py
CHANGED
|
@@ -1,513 +1,657 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
"""
|
| 5 |
-
import torch
|
| 6 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 7 |
-
from sentence_transformers import SentenceTransformer, util
|
| 8 |
-
import networkx as nx
|
| 9 |
-
import numpy as np
|
| 10 |
-
from typing import Dict, List, Set,
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
import re
|
| 13 |
-
import
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
"
|
| 53 |
-
"required": ["
|
| 54 |
-
"
|
| 55 |
-
},
|
| 56 |
-
"
|
| 57 |
-
"required": ["
|
| 58 |
-
"
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
"
|
| 63 |
-
"
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
"
|
| 74 |
-
"
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
self.
|
| 86 |
-
self.
|
| 87 |
-
self.
|
| 88 |
-
self.
|
| 89 |
-
self.
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
self.
|
| 94 |
-
|
| 95 |
-
def
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
if
|
| 133 |
-
continue
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
score = 0.0
|
| 188 |
-
|
| 189 |
-
#
|
| 190 |
-
if
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
#
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
#
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
#
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
if
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Curriculum Optimizer - PRODUCTION VERSION
|
| 3 |
+
All redundant code removed, all critical issues fixed
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 7 |
+
from sentence_transformers import SentenceTransformer, util
|
| 8 |
+
import networkx as nx
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import Dict, List, Set, Optional
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import re
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
@dataclass
class StudentProfile:
    """Student inputs handed to the optimizer's plan generators.

    The first five fields are required; the last two have defaults.
    """

    # Course IDs already finished; the planner turns this into a set.
    completed_courses: List[str]
    # Passed to the optimizer's course-load calculation (units not shown here).
    time_commitment: int
    # Passed to the optimizer's difficulty mapping.
    preferred_difficulty: str
    # Free-text goals; presumably used for track identification -- confirm.
    career_goals: str
    # Interest keywords; presumably used for semantic matching -- confirm.
    interests: List[str]
    current_gpa: float = 3.5
    learning_style: str = "Visual"
|
| 24 |
+
|
| 25 |
+
class HybridOptimizer:
|
| 26 |
+
|
| 27 |
+
# Each set lists course IDs treated as interchangeable: completing any one
# member counts as completing all of them when prerequisites are checked.
EQUIVALENCY_GROUPS = [
    {"MATH1341", "MATH1241", "MATH1231"},
    {"MATH1342", "MATH1242"},
    {"PHYS1151", "PHYS1161", "PHYS1145"},
    {"PHYS1155", "PHYS1165", "PHYS1147"},
]

# Mutually exclusive course sequences, grouped by area. Sequence validation
# rejects a candidate when a course from a different track of the same area
# is already selected.
COURSE_TRACKS = {
    "physics": {
        "engineering": ["PHYS1151", "PHYS1155"],
        "science": ["PHYS1161", "PHYS1165"],
        "life_sciences": ["PHYS1145", "PHYS1147"],
    },
    "calculus": {
        "standard": ["MATH1341", "MATH1342"],
        # NOTE(review): these IDs have three digits while every other ID in
        # this file has four -- confirm they are not typos.
        "computational": ["MATH156", "MATH256"],
    },
}

# Degree requirements per concentration. Each category maps rule names
# ("required", "pick_N_from", ...) to lists of course IDs; non-list values
# such as "sequence" are flags.
CONCENTRATION_REQUIREMENTS = {
    "ai_ml": {
        "foundations": {
            "required": ["CS1800", "CS2500", "CS2510", "CS2800"],
            "sequence": True,
        },
        "core": {
            "required": ["CS3000", "CS3500"],
            "pick_1_from": ["CS3200", "CS3650", "CS5700"],
        },
        "concentration_specific": {
            "required": ["CS4100", "DS4400"],
            "pick_2_from": ["CS4120", "CS4180", "DS4420", "DS4440"],
            "pick_1_systems": ["CS4730", "CS4700"],
        },
        "math": {
            "required": ["MATH1341", "MATH1342"],
            "pick_1_from": ["MATH2331", "MATH3081"],
        },
    },
    "systems": {
        "foundations": {
            "required": ["CS1800", "CS2500", "CS2510", "CS2800"],
        },
        "core": {
            "required": ["CS3000", "CS3500", "CS3650"],
            "pick_1_from": ["CS5700", "CS3200"],
        },
        "concentration_specific": {
            "required": ["CS4700"],
            # NOTE(review): "pick 2" from a single-entry list -- confirm.
            "pick_2_from": ["CS4730"],
            "pick_1_from": ["CS4400", "CS4500", "CS4520"],
        },
        "math": {
            "required": ["MATH1341", "MATH1342"],
        },
    },
    "security": {
        "foundations": {
            "required": ["CS1800", "CS2500", "CS2510", "CS2800"],
        },
        "core": {
            "required": ["CS3000", "CS3650", "CY2550"],
            "pick_1_from": ["CS5700", "CS3500"],
        },
        "concentration_specific": {
            "required": ["CY3740"],
            "pick_2_from": ["CY4740", "CY4760", "CY4770"],
            "pick_1_from": ["CS4700", "CS4730"],
        },
        "math": {
            "required": ["MATH1342"],
            "pick_1_from": ["MATH3527", "MATH3081"],
        },
    },
}
|
| 79 |
+
|
| 80 |
+
def __init__(self):
    """Set up an optimizer with no models or curriculum data loaded yet.

    The device is CUDA when ``torch.cuda.is_available()`` and CPU otherwise.
    Models and data are loaded separately via ``load_models``, ``load_llm``
    and ``load_data``.
    """
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
    self.embedding_model_name = 'BAAI/bge-large-en-v1.5'
    self.llm = None
    self.tokenizer = None
    self.embedding_model = None
    self.curriculum_graph = None
    self.courses = {}
    self.current_student = None
    # Populated by load_data(); initialised here so that reading them before
    # data is loaded gives an empty/None value instead of an AttributeError.
    self.valid_courses = []
    self.course_embeddings = None
|
| 90 |
+
|
| 91 |
+
def load_models(self):
    """Instantiate the sentence-embedding model on ``self.device``.

    The result is stored in ``self.embedding_model``; the LLM is loaded
    separately by ``load_llm``.
    """
    print("Loading embedding model...")
    encoder = SentenceTransformer(self.embedding_model_name, device=self.device)
    self.embedding_model = encoder
|
| 94 |
+
|
| 95 |
+
def load_llm(self):
    """Load the 4-bit quantised LLM and its tokenizer, once, on CUDA only.

    Does nothing when the device is not CUDA or when ``self.llm`` is already
    set, so callers must check ``self.llm`` afterwards.
    """
    if self.device.type != 'cuda' or self.llm is not None:
        return

    print("Loading LLM for intelligent planning...")
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(self.model_name)
    self.tokenizer = tokenizer
    # Reuse the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    self.llm = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=quant_config,
        device_map="auto",
    )
|
| 110 |
+
|
| 111 |
+
def load_data(self, graph: nx.DiGraph):
    """Attach the curriculum graph and pre-compute course embeddings.

    Stores ``graph`` and a ``{course_id: attributes}`` dict, builds
    ``self.valid_courses`` by filtering that dict, and encodes
    "<name> <description>" for each valid course with
    ``self.embedding_model`` into ``self.course_embeddings``.

    A course is skipped when its name is missing or blank, when its
    lower-cased name contains one of the skip words, or when its level is
    5000+ and it is not one of the listed graduate exceptions. A warning is
    printed for any course referenced by ``CONCENTRATION_REQUIREMENTS`` that
    did not make it into ``self.valid_courses``.
    """
    self.curriculum_graph = graph
    self.courses = dict(graph.nodes(data=True))

    grad_exceptions = {"CS5700", "CY5700", "DS5110", "CS5010"}
    skip_words = ['lab', 'recitation', 'seminar', 'practicum']

    # Flatten every list-valued rule of every concentration into one set.
    concentration_courses = {
        course
        for track_reqs in self.CONCENTRATION_REQUIREMENTS.values()
        for reqs in track_reqs.values()
        if isinstance(reqs, dict)
        for listed in reqs.values()
        if isinstance(listed, list)
        for course in listed
    }

    self.valid_courses = []
    course_texts = []
    for cid, data in self.courses.items():
        name = data.get('name', '')
        if not name or not name.strip():
            continue
        lowered = name.lower()
        if any(word in lowered for word in skip_words):
            continue
        if self._get_level(cid) >= 5000 and cid not in grad_exceptions:
            continue

        # valid_courses and course_texts stay index-aligned.
        self.valid_courses.append(cid)
        course_texts.append(f"{name} {data.get('description', '')}")

    missing_required = concentration_courses - set(self.valid_courses)
    if missing_required:
        print(f"\n⚠️ WARNING: {len(missing_required)} required courses missing from graph: {sorted(missing_required)}\n")

    print(f"Computing embeddings for {len(self.valid_courses)} courses...")
    self.course_embeddings = self.embedding_model.encode(course_texts, convert_to_tensor=True, show_progress_bar=True)
    print(f"\nTotal valid courses: {len(self.valid_courses)}")
|
| 145 |
+
|
| 146 |
+
def _get_level(self, course_id: str) -> int:
    """Return the first run of digits in ``course_id`` as an int.

    IDs with no digits yield 9999.
    """
    digits = re.search(r'\d+', course_id)
    if digits is None:
        return 9999
    return int(digits.group(0))
|
| 149 |
+
|
| 150 |
+
def _get_completed_with_equivalents(self, completed: Set[str]) -> Set[str]:
    """Return ``completed`` plus every course equivalent to one of its members.

    Any group in ``self.EQUIVALENCY_GROUPS`` that shares at least one course
    with ``completed`` is added in full. The input set is not modified.
    """
    widened = completed.copy()
    for group in self.EQUIVALENCY_GROUPS:
        if not group.isdisjoint(completed):
            widened.update(group)
    return widened
|
| 157 |
+
|
| 158 |
+
def _can_take_course(self, course_id: str, completed: Set[str]) -> bool:
    """Report whether every prerequisite of ``course_id`` has been met.

    Prerequisites are the course's predecessors in ``self.curriculum_graph``;
    ``completed`` is first widened with equivalent courses. A course that is
    not in the graph is treated as takeable.
    """
    satisfied = self._get_completed_with_equivalents(completed)
    graph = self.curriculum_graph
    if course_id in graph:
        return all(prereq in satisfied for prereq in graph.predecessors(course_id))
    return True
|
| 164 |
+
|
| 165 |
+
def _validate_sequence(self, selected: List[str], candidate: str) -> bool:
    """Return False when ``candidate`` would mix two tracks of one area.

    For each area in ``self.COURSE_TRACKS``, if ``candidate`` belongs to one
    track and ``selected`` already holds a course from a different track of
    the same area, the candidate is rejected. Otherwise returns True.
    """
    for tracks in self.COURSE_TRACKS.values():
        for track_name, sequence in tracks.items():
            if candidate not in sequence:
                continue
            clashes = any(
                course in selected
                for other_name, other_seq in tracks.items()
                if other_name != track_name
                for course in other_seq
            )
            if clashes:
                return False
    return True
|
| 173 |
+
|
| 174 |
+
def _score_course(self, course_id: str, semantic_scores: Dict[str, float], required_set: Set[str], picklist_set: Set[str], year: int, track: str) -> float:
    """Compute a track-aware priority score for one course; higher is better.

    Returns -10000.0 outright for a course that is unknown or has a blank
    name. Otherwise the score is accumulated from, in order: weighted
    semantic similarity, a keyword penalty, a subject bonus or penalty, a
    track-specific boost, required and pick-list bonuses, a small bonus for
    courses with many successors in the graph, a mild preference for lower
    course numbers, and a penalty for low-level courses in years 3 and 4.
    """
    if course_id not in self.courses:
        return -10000.0
    record = self.courses[course_id]
    if not record.get('name', '').strip():
        return -10000.0

    subject = record.get('subject', '')
    level = self._get_level(course_id)
    title = record.get('name', '').lower()

    # Semantic similarity counts three times as much in year 4.
    weight = 15.0 if year == 4 else 5.0
    score = 0.0
    score += semantic_scores.get(course_id, 0.0) * weight

    # Keyword penalty, applied after semantics. 'game' and 'mobile' titles
    # are exempt when the game_dev track is selected.
    penalised_words = ['society', 'ethics', 'law', 'policy', 'mobile', 'game', 'visualiz', 'web']
    if any(word in title for word in penalised_words):
        exempt = track == "game_dev" and any(word in title for word in ['game', 'mobile'])
        if not exempt:
            score -= 10000.0

    # Subject-aware adjustment; introductory CY courses are penalised.
    if subject in ("CS", "DS"):
        subject_delta = 300.0
    elif subject == "CY":
        subject_delta = 300.0 if level >= 3000 else -500.0
    elif subject == "MATH":
        subject_delta = 100.0
    else:
        subject_delta = -1000.0
    score += subject_delta

    # Critical-path boosts per track; tracks not listed get no boost.
    track_boosts = {
        "ai_ml": (("DS2500", "DS3000", "DS3500"), 7000.0),
        "security": (("CY2550", "CY3740"), 7000.0),
        "systems": (("CS3650",), 7000.0),
        "game_dev": (("CS3540",), 8000.0),
    }
    boosted_ids, boost = track_boosts.get(track, ((), 0.0))
    if course_id in boosted_ids:
        score += boost

    # Hard requirements outrank pick-list options.
    if course_id in required_set:
        score += 10000.0
    if course_id in picklist_set:
        score += 5000.0

    # Courses that unlock more successors get a small capped bonus.
    if course_id in self.curriculum_graph:
        score += min(self.curriculum_graph.out_degree(course_id), 5) * 2.0

    # Slight preference for lower-numbered courses.
    score -= (level / 100.0)

    # Discourage low-level courses late in the program.
    if year == 4:
        if level < 4000:
            score -= 3000.0
    elif year == 3 and level < 3000:
        score -= 2000.0

    return score
|
| 253 |
+
|
| 254 |
+
def generate_simple_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Build a plan without the LLM by delegating to the rule-based planner.

    Records ``student`` as the current student and returns whatever
    ``generate_enhanced_rule_plan`` returns for the same arguments.
    """
    print("--- Generating Enhanced Rule-Based Plan ---")
    self.current_student = student
    result = self.generate_enhanced_rule_plan(student, track_override)
    return result
|
| 258 |
+
|
| 259 |
+
def generate_enhanced_rule_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Build, validate and finalise a rule-based plan for ``student``.

    Any truthy ``track_override`` (including "general") is used as the
    track. Otherwise the track is auto-identified, falling back to "general"
    when identification yields nothing. If validation reports errors, one
    repair pass is made and the plan is validated again.
    """
    self.current_student = student

    track = track_override
    if track:
        print(f"--- Using user-selected track: {track} ---")
    else:
        track = self._identify_track(student)
        # Printed before the fallback, so an empty result is shown as-is.
        print(f"--- Auto-identified track: {track} ---")
        track = track if track else "general"

    plan = self._build_structured_plan(student, track, None)
    validation = self.validate_plan(plan, student)
    if validation["errors"]:
        plan = self._fix_plan_errors(plan, validation, student)
        validation = self.validate_plan(plan, student)

    difficulty_level = self._map_difficulty(student.preferred_difficulty)
    courses_per_semester = self._calculate_course_load(student.time_commitment)
    track_name = track.replace("_", " ").title()

    explanation = f"Personalized {track_name} track ({difficulty_level} difficulty, {courses_per_semester} courses/semester)"
    return self._finalize_plan(plan, explanation, validation)
|
| 286 |
+
|
| 287 |
+
def generate_llm_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Build an LLM-assisted plan, falling back to the rule-based planner.

    After ``load_llm()``, a falsy ``self.llm`` means no model is available and
    the call is delegated to ``generate_enhanced_rule_plan`` with the same
    override.

    Track selection mirrors ``generate_enhanced_rule_plan``: any truthy
    ``track_override`` is honored, including ``"general"``. Previously a
    ``"general"`` override was discarded here and replaced by an
    auto-identified track, so the two planners disagreed on the same input.

    Returns whatever ``_finalize_plan`` returns.
    """
    print("--- Generating AI-Optimized Plan ---")
    self.current_student = student
    self.load_llm()
    if not self.llm:
        return self.generate_enhanced_rule_plan(student, track_override)  # Pass override

    if track_override:
        track = track_override
        print(f"--- Using user-selected track: {track} ---")
    else:
        track = self._identify_track(student)
        print(f"--- Auto-identified track: {track} ---")
        if not track:
            track = "general"

    llm_suggestions = self._get_llm_course_suggestions(student, track)
    plan = self._build_structured_plan(student, track, llm_suggestions)
    validation = self.validate_plan(plan, student)
    if validation["errors"]:
        # Single repair pass, then re-validate so the report matches the plan.
        plan = self._fix_plan_errors(plan, validation, student)
        validation = self.validate_plan(plan, student)

    track_name = track.replace("_", " ").title()
    explanation = self._generate_explanation(student, plan, track, f"AI-optimized {track_name}")
    return self._finalize_plan(plan, explanation, validation)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _build_structured_plan(self, student: StudentProfile, track: str, llm_suggestions: Optional[List[str]] = None) -> Dict:
|
| 318 |
+
"""
|
| 319 |
+
PRODUCTION PLANNER - NOW FULLY TRACK-AWARE
|
| 320 |
+
Uses different priority lists based on the selected track.
|
| 321 |
+
"""
|
| 322 |
+
completed = set(student.completed_courses)
|
| 323 |
+
plan = {}
|
| 324 |
+
|
| 325 |
+
# --- FIX: TRACK-AWARE REQUIREMENTS ---
|
| 326 |
+
if track == "general":
|
| 327 |
+
print("--- Using General CS requirements ---")
|
| 328 |
+
requirements = {
|
| 329 |
+
"foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
|
| 330 |
+
"core": {"required": ["CS3000", "CS3500", "CS3650"]},
|
| 331 |
+
"math": {"required": ["MATH1341", "MATH1342"], "pick_1_from": ["MATH2331", "MATH3081"]}
|
| 332 |
+
}
|
| 333 |
+
elif track == "game_dev":
|
| 334 |
+
print("--- Using Game Dev (AI/ML base) requirements ---")
|
| 335 |
+
# Use ai_ml as a base, scoring/priorities will handle the rest
|
| 336 |
+
requirements = self.CONCENTRATION_REQUIREMENTS["ai_ml"]
|
| 337 |
+
else:
|
| 338 |
+
requirements = self.CONCENTRATION_REQUIREMENTS.get(track, self.CONCENTRATION_REQUIREMENTS["ai_ml"])
|
| 339 |
+
|
| 340 |
+
courses_per_semester = self._calculate_course_load(student.time_commitment)
|
| 341 |
+
|
| 342 |
+
# Build required and pick sets
|
| 343 |
+
required_set = set()
|
| 344 |
+
picklist_set = set()
|
| 345 |
+
for category, reqs in requirements.items():
|
| 346 |
+
if "required" in reqs:
|
| 347 |
+
required_set.update(reqs["required"])
|
| 348 |
+
for key, courses in reqs.items():
|
| 349 |
+
if key.startswith("pick_"):
|
| 350 |
+
picklist_set.update(courses)
|
| 351 |
+
|
| 352 |
+
semantic_scores = self._compute_semantic_scores(student)
|
| 353 |
+
|
| 354 |
+
# --- FIX: TRACK-AWARE PRIORITIES ---
|
| 355 |
+
TRACK_YEAR_PRIORITIES = {
|
| 356 |
+
"general": {
|
| 357 |
+
2: ["CS3000", "CS3500", "CS3650", "MATH2331", "MATH3081", "CS3200"],
|
| 358 |
+
3: ["CS4700", "CS4400", "CS4500", "CS4100"],
|
| 359 |
+
4: ["CS5700", "CS4730", "CS4530", "CS4550", "CS4410"]
|
| 360 |
+
},
|
| 361 |
+
"ai_ml": {
|
| 362 |
+
2: ["CS3000", "CS3500", "DS2500", "DS3000", "DS3500", "MATH2331", "MATH3081", "CS3650"],
|
| 363 |
+
3: ["CS4100", "DS4400", "CS4120", "DS4420", "DS4440", "CS4180"],
|
| 364 |
+
4: ["CS4730", "CS4700", "CS5700", "DS4300", "CS4400", "CS4500"]
|
| 365 |
+
},
|
| 366 |
+
"security": {
|
| 367 |
+
2: ["CS3000", "CS3650", "CY2550", "MATH2331", "MATH3081", "CS3500"],
|
| 368 |
+
3: ["CY3740", "CS4700", "CS5700", "CS4730"],
|
| 369 |
+
4: ["CY4740", "CY4760", "CS4400"] # CY4770 is missing from graph
|
| 370 |
+
},
|
| 371 |
+
"systems": {
|
| 372 |
+
2: ["CS3000", "CS3500", "CS3650", "MATH2331", "CS3200"],
|
| 373 |
+
3: ["CS4700", "CS5700", "CS4730", "CS4500", "CS4400"],
|
| 374 |
+
4: ["CS4520", "CS4410"]
|
| 375 |
+
},
|
| 376 |
+
"game_dev": {
|
| 377 |
+
2: ["CS3000", "CS3500", "CS3540", "MATH2331", "MATH3081", "CS3650"],
|
| 378 |
+
3: ["CS4520", "CS4300", "CS4100", "CS4700"],
|
| 379 |
+
4: ["CS4550", "CS4410", "CS4180"]
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
for sem_num in range(1, 9):
|
| 384 |
+
year = ((sem_num - 1) // 2) + 1
|
| 385 |
+
|
| 386 |
+
available_courses = self._get_available_courses(completed, year, sem_num, track)
|
| 387 |
+
|
| 388 |
+
schedulable = [
|
| 389 |
+
c for c in available_courses
|
| 390 |
+
if c not in completed and self._can_take_course(c, completed)
|
| 391 |
+
]
|
| 392 |
+
|
| 393 |
+
# Use track-specific priorities, default to "general" if track is unknown
|
| 394 |
+
current_year_priorities = TRACK_YEAR_PRIORITIES.get(track, TRACK_YEAR_PRIORITIES["general"]).get(year)
|
| 395 |
+
|
| 396 |
+
if current_year_priorities:
|
| 397 |
+
priority_courses = [c for c in current_year_priorities if c in schedulable]
|
| 398 |
+
other_courses = [c for c in schedulable if c not in current_year_priorities]
|
| 399 |
+
|
| 400 |
+
scored_priority = sorted(
|
| 401 |
+
priority_courses,
|
| 402 |
+
# --- FIX: Pass 'track' to score_course ---
|
| 403 |
+
key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set, year, track),
|
| 404 |
+
reverse=True
|
| 405 |
+
)
|
| 406 |
+
scored_others = sorted(
|
| 407 |
+
other_courses,
|
| 408 |
+
key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set, year, track),
|
| 409 |
+
reverse=True
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
scored_courses = scored_priority + scored_others
|
| 413 |
+
else:
|
| 414 |
+
# Year 1: normal scoring
|
| 415 |
+
scored_courses = sorted(
|
| 416 |
+
schedulable,
|
| 417 |
+
key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set, year, track),
|
| 418 |
+
reverse=True
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
# Select top N courses
|
| 422 |
+
selected = []
|
| 423 |
+
for course in scored_courses:
|
| 424 |
+
if len(selected) >= courses_per_semester:
|
| 425 |
+
break
|
| 426 |
+
if self._validate_sequence(selected, course):
|
| 427 |
+
selected.append(course)
|
| 428 |
+
|
| 429 |
+
if selected:
|
| 430 |
+
year_key = f"year_{year}"
|
| 431 |
+
if year_key not in plan:
|
| 432 |
+
plan[year_key] = {}
|
| 433 |
+
|
| 434 |
+
sem_type = 'fall' if (sem_num % 2) == 1 else 'spring'
|
| 435 |
+
plan[year_key][sem_type] = selected
|
| 436 |
+
completed.update(selected)
|
| 437 |
+
|
| 438 |
+
return plan
|
| 439 |
+
|
| 440 |
+
def _get_available_courses(self, completed: Set[str], year: int, sem_num: int = None, track: str = "ai_ml") -> List[str]:
|
| 441 |
+
"""
|
| 442 |
+
PRODUCTION COURSE FILTER - Strict level enforcement
|
| 443 |
+
"""
|
| 444 |
+
# Year 1: Hardcoded foundation
|
| 445 |
+
if year == 1:
|
| 446 |
+
if not completed or len(completed) < 2:
|
| 447 |
+
return [c for c in ["CS1800", "CS2500", "MATH1341", "ENGW1111"] if c in self.valid_courses]
|
| 448 |
+
else:
|
| 449 |
+
next_courses = []
|
| 450 |
+
prereq_map = [
|
| 451 |
+
("CS2800", "CS1800"),
|
| 452 |
+
("CS2510", "CS2500"),
|
| 453 |
+
("MATH1342", "MATH1341"),
|
| 454 |
+
("DS2000", None),
|
| 455 |
+
("DS2500", "DS2000")
|
| 456 |
+
]
|
| 457 |
+
|
| 458 |
+
for course, prereq in prereq_map:
|
| 459 |
+
if course in self.valid_courses and course not in completed:
|
| 460 |
+
if prereq is None or prereq in completed:
|
| 461 |
+
next_courses.append(course)
|
| 462 |
+
return next_courses
|
| 463 |
+
|
| 464 |
+
# Years 2-4: Strict filtering by subject and level
|
| 465 |
+
available = []
|
| 466 |
+
ALLOWED_SUBJECTS = {"CS", "DS", "CY", "MATH"}
|
| 467 |
+
|
| 468 |
+
for cid in self.valid_courses:
|
| 469 |
+
if cid in completed:
|
| 470 |
+
continue
|
| 471 |
+
|
| 472 |
+
course_data = self.courses.get(cid, {})
|
| 473 |
+
subject = course_data.get('subject')
|
| 474 |
+
|
| 475 |
+
if subject not in ALLOWED_SUBJECTS:
|
| 476 |
+
continue
|
| 477 |
+
|
| 478 |
+
course_level = self._get_level(cid)
|
| 479 |
+
|
| 480 |
+
# FIX: Strict year-based level filtering
|
| 481 |
+
if year == 2:
|
| 482 |
+
if course_level < 2000 or course_level > 3999:
|
| 483 |
+
continue # Year 2: only 2000-3999
|
| 484 |
+
elif year == 3:
|
| 485 |
+
if course_level < 3000:
|
| 486 |
+
continue # Year 3: 3000+ only
|
| 487 |
+
elif year == 4:
|
| 488 |
+
if course_level < 4000:
|
| 489 |
+
continue # Year 4: 4000+ only (including CS5700)
|
| 490 |
+
|
| 491 |
+
available.append(cid)
|
| 492 |
+
|
| 493 |
+
return available
|
| 494 |
+
|
| 495 |
+
def _fix_plan_errors(self, plan: Dict, validation: Dict, student: StudentProfile) -> Dict:
|
| 496 |
+
if any("Mixed" in error for error in validation["errors"]):
|
| 497 |
+
return self._build_structured_plan(student, self._identify_track(student), None)
|
| 498 |
+
return plan
|
| 499 |
+
|
| 500 |
+
def _get_llm_course_suggestions(self, student: StudentProfile, track: str) -> List[str]:
|
| 501 |
+
requirements = self.CONCENTRATION_REQUIREMENTS.get(track, {})
|
| 502 |
+
all_options = set()
|
| 503 |
+
for reqs in requirements.values():
|
| 504 |
+
for key, courses in reqs.items():
|
| 505 |
+
if key.startswith("pick_"):
|
| 506 |
+
all_options.update(courses)
|
| 507 |
+
|
| 508 |
+
course_options_text = [
|
| 509 |
+
f"{cid}: {self.courses[cid].get('name', cid)} - {self.courses[cid].get('description', '')[:100].strip()}"
|
| 510 |
+
for cid in list(all_options)[:15] if cid in self.courses
|
| 511 |
+
]
|
| 512 |
+
|
| 513 |
+
prompt = f"""Expert curriculum advisor ranking courses for student.
|
| 514 |
+
|
| 515 |
+
Student Profile:
|
| 516 |
+
- Career Goal: {student.career_goals}
|
| 517 |
+
- Interests: {', '.join(student.interests)}
|
| 518 |
+
- Difficulty: {student.preferred_difficulty}
|
| 519 |
+
|
| 520 |
+
Available Courses:
|
| 521 |
+
{chr(10).join(course_options_text)}
|
| 522 |
+
|
| 523 |
+
Return ONLY top 5 course IDs, one per line."""
|
| 524 |
+
|
| 525 |
+
try:
|
| 526 |
+
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(self.device)
|
| 527 |
+
with torch.no_grad():
|
| 528 |
+
outputs = self.llm.generate(
|
| 529 |
+
**inputs,
|
| 530 |
+
max_new_tokens=100,
|
| 531 |
+
temperature=0.2,
|
| 532 |
+
do_sample=True,
|
| 533 |
+
pad_token_id=self.tokenizer.eos_token_id
|
| 534 |
+
)
|
| 535 |
+
response = self.tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
|
| 536 |
+
suggested_courses = re.findall(r'([A-Z]{2,4}\d{4})', response)
|
| 537 |
+
return suggested_courses[:5]
|
| 538 |
+
except Exception as e:
|
| 539 |
+
print(f"LLM suggestion failed: {e}")
|
| 540 |
+
return list(all_options)[:5]
|
| 541 |
+
|
| 542 |
+
def _map_difficulty(self, preferred_difficulty: str) -> str:
|
| 543 |
+
return {"easy": "easy", "moderate": "medium", "challenging": "hard"}.get(preferred_difficulty.lower(), "medium")
|
| 544 |
+
|
| 545 |
+
def _calculate_course_load(self, time_commitment: int) -> int:
|
| 546 |
+
if time_commitment <= 20:
|
| 547 |
+
return 3
|
| 548 |
+
if time_commitment <= 40:
|
| 549 |
+
return 4
|
| 550 |
+
return 5
|
| 551 |
+
|
| 552 |
+
def _identify_track(self, student: StudentProfile) -> str:
|
| 553 |
+
if not hasattr(self, 'embedding_model') or self.embedding_model is None:
|
| 554 |
+
combined = f"{student.career_goals.lower()} {' '.join(student.interests).lower()}"
|
| 555 |
+
if any(word in combined for word in ['ai', 'ml', 'machine learning', 'data']):
|
| 556 |
+
return "ai_ml"
|
| 557 |
+
if any(word in combined for word in ['systems', 'distributed', 'backend']):
|
| 558 |
+
return "systems"
|
| 559 |
+
if any(word in combined for word in ['security', 'cyber']):
|
| 560 |
+
return "security"
|
| 561 |
+
return "ai_ml"
|
| 562 |
+
|
| 563 |
+
profile_text = f"{student.career_goals} {' '.join(student.interests)}"
|
| 564 |
+
profile_emb = self.embedding_model.encode(profile_text, convert_to_tensor=True)
|
| 565 |
+
|
| 566 |
+
track_descriptions = {
|
| 567 |
+
"ai_ml": "artificial intelligence machine learning deep learning neural networks data science",
|
| 568 |
+
"systems": "operating systems distributed systems networks compilers databases performance backend",
|
| 569 |
+
"security": "cybersecurity cryptography network security ethical hacking vulnerabilities"
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
best_track, best_score = "ai_ml", -1.0
|
| 573 |
+
for track, description in track_descriptions.items():
|
| 574 |
+
track_emb = self.embedding_model.encode(description, convert_to_tensor=True)
|
| 575 |
+
score = float(util.cos_sim(profile_emb, track_emb))
|
| 576 |
+
if score > best_score:
|
| 577 |
+
best_score, best_track = score, track
|
| 578 |
+
|
| 579 |
+
return best_track
|
| 580 |
+
|
| 581 |
+
def _compute_semantic_scores(self, student: StudentProfile) -> Dict[str, float]:
    """Score every valid course against the student's goals and interests.

    The profile text is encoded with ``self.embedding_model`` and compared to
    ``self.course_embeddings`` via ``util.cos_sim``. The first row of that
    result is read by position, so position ``i`` is paired with
    ``self.valid_courses[i]``.

    Returns a ``{course_id: similarity}`` dict with plain ``float`` values.
    """
    profile_text = f"{student.career_goals} {' '.join(student.interests)}"
    profile_vec = self.embedding_model.encode(profile_text, convert_to_tensor=True)
    row = util.cos_sim(profile_vec, self.course_embeddings)[0]

    scores = {}
    for position, course_id in enumerate(self.valid_courses):
        scores[course_id] = float(row[position])
    return scores
|
| 586 |
+
|
| 587 |
+
def _generate_explanation(self, student: StudentProfile, plan: Dict, track: str, plan_type: str) -> str:
|
| 588 |
+
return f"{plan_type.title()} plan for the {track} track, tailored to your goal of becoming a {student.career_goals}."
|
| 589 |
+
|
| 590 |
+
def validate_plan(self, plan: Dict, student: StudentProfile = None) -> Dict[str, List[str]]:
    """Check a plan for mixed track sequences and unmet prerequisites.

    Args:
        plan: ``{"year_N": {"fall": [...], "spring": [...]}}``. Entries that
            are not dicts, and semester values that are not lists, are ignored
            when collecting the plan's courses.
        student: Optional profile; its ``completed_courses`` seed the set of
            courses treated as already taken.

    Returns:
        ``{"errors": [...], "warnings": [...], "info": [...]}``. Only
        ``errors`` is populated here:

        * one error per ``COURSE_TRACKS`` group when the plan uses courses
          from more than one track of that group;
        * one error per course whose graph predecessors are not all covered
          by earlier semesters, after expanding the covered set through
          ``_get_completed_with_equivalents``. The message names only the
          prerequisites that are really unmet; it used to list prerequisites
          that were satisfied through an equivalent course.

        Track and prerequisite names inside messages are sorted so the output
        is deterministic.
    """
    issues = {"errors": [], "warnings": [], "info": []}
    all_courses = {
        course
        for year in plan.values() if isinstance(year, dict)
        for sem in year.values() if isinstance(sem, list)
        for course in sem
    }

    # Check for mixed tracks
    for track_type, tracks in self.COURSE_TRACKS.items():
        tracks_used = sorted(
            name for name, courses in tracks.items()
            if any(c in all_courses for c in courses)
        )
        if len(tracks_used) > 1:
            issues["errors"].append(f"Mixed {track_type} tracks: {', '.join(tracks_used)}. Choose one sequence.")

    # Validate prerequisites semester by semester
    completed_for_validation = set(student.completed_courses) if student else set()
    for year in range(1, 5):
        year_key = f"year_{year}"
        for sem in ["fall", "spring"]:
            sem_courses = plan.get(year_key, {}).get(sem, [])
            if sem_courses:
                # Same-semester courses do not satisfy each other, so the
                # covered set is fixed for the whole semester.
                satisfied = self._get_completed_with_equivalents(completed_for_validation)
                for course in sem_courses:
                    if course in self.curriculum_graph:
                        prereqs = set(self.curriculum_graph.predecessors(course))
                        missing = prereqs.difference(satisfied)
                        if missing:
                            issues["errors"].append(
                                f"{course} in Year {year} {sem} is missing prereqs: {', '.join(sorted(missing))}"
                            )
            completed_for_validation.update(sem_courses)

    return issues
|
| 615 |
+
|
| 616 |
+
def _finalize_plan(self, plan: Dict, explanation: str, validation: Dict = None) -> Dict:
|
| 617 |
+
structured_plan = {
|
| 618 |
+
"reasoning": explanation,
|
| 619 |
+
"validation": validation or {"errors": [], "warnings": [], "info": []}
|
| 620 |
+
}
|
| 621 |
+
|
| 622 |
+
complexities = []
|
| 623 |
+
for year in range(1, 5):
|
| 624 |
+
year_key = f"year_{year}"
|
| 625 |
+
structured_plan[year_key] = {
|
| 626 |
+
"fall": plan.get(year_key, {}).get("fall", []),
|
| 627 |
+
"spring": plan.get(year_key, {}).get("spring", []),
|
| 628 |
+
"summer": "co-op" if year in [2, 3] else []
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
for sem in ["fall", "spring"]:
|
| 632 |
+
courses = structured_plan[year_key][sem]
|
| 633 |
+
if courses:
|
| 634 |
+
sem_complexity = sum(self.courses.get(c, {}).get('complexity', 50) for c in courses)
|
| 635 |
+
complexities.append(sem_complexity)
|
| 636 |
+
|
| 637 |
+
structured_plan["complexity_analysis"] = {
|
| 638 |
+
"average_semester_complexity": float(np.mean(complexities)) if complexities else 0,
|
| 639 |
+
"peak_semester_complexity": float(np.max(complexities)) if complexities else 0,
|
| 640 |
+
"total_complexity": float(np.sum(complexities)) if complexities else 0,
|
| 641 |
+
"balance_score (std_dev)": float(np.std(complexities)) if complexities else 0
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
structured_plan["metadata"] = {
|
| 645 |
+
"generated": datetime.now().isoformat(),
|
| 646 |
+
"valid": len(validation.get("errors", [])) == 0 if validation else True,
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
return {"pathway": structured_plan}
|
| 650 |
+
|
| 651 |
+
class CurriculumOptimizer(HybridOptimizer):
    """Thin subclass of ``HybridOptimizer`` that exposes ``generate_plan``."""

    def __init__(self):
        """Initialize exactly as ``HybridOptimizer`` does."""
        super().__init__()

    def generate_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
        """Return the rule-based plan from ``generate_enhanced_rule_plan``."""
        rule_based_plan = self.generate_enhanced_rule_plan(student, track_override)
        return rule_based_plan
|
src/neu_graph_clean10.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ebd3c024667aedd28338a61e7e87bdd78c7db3dce94fb1a920e11eb7bdc985d
|
| 3 |
+
size 156590
|
src/ui.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pickle
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import yaml
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Dict, Set, Optional
|
| 9 |
+
|
| 10 |
+
# Import the optimizer and visualizer
|
| 11 |
+
from curriculum_optimizer import HybridOptimizer, StudentProfile
|
| 12 |
+
from interactive_visualizer import CurriculumVisualizer
|
| 13 |
+
|
| 14 |
+
# --- Page Configuration ---
|
| 15 |
+
st.set_page_config(page_title="Curriculum Optimizer", layout="wide", initial_sidebar_state="expanded")

# Seed each session-state key with its default unless it is already present.
_SESSION_DEFAULTS = {
    "display_plan": None,
    "metrics": None,
    "reasoning": "",
    "graph_data_loaded": False,
    "last_profile": None,
    "visualizer": None,
    "selected_track": "general",  # Default to general
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
| 32 |
+
|
| 33 |
+
# Title
|
| 34 |
+
st.title("🧑🎓 Next-Gen Curriculum Optimizer")
|
| 35 |
+
|
| 36 |
+
# --- Caching and Initialization ---
|
| 37 |
+
@st.cache_resource
def get_optimizer():
    """Create a ``HybridOptimizer``, load its models and return it.

    Wrapped in ``st.cache_resource``. If construction or model loading raises,
    the error and an installation hint are shown in the UI and ``st.stop()``
    is called.
    """
    try:
        instance = HybridOptimizer()
        instance.load_models()
    except Exception as exc:
        # Top-level UI boundary: report the failure to the user, then halt.
        st.error(f"Fatal error during model loading: {exc}")
        st.info("Please ensure you have the required libraries installed.")
        st.stop()
        return None
    return instance
|
| 49 |
+
|
| 50 |
+
optimizer = get_optimizer()
|
| 51 |
+
|
| 52 |
+
# --- DYNAMIC HELPER FUNCTIONS ---
|
| 53 |
+
|
| 54 |
+
def check_requirements_satisfaction(plan: Dict, track: str) -> Dict:
    """
    Report which of the track's requirements the planned courses cover.

    Courses are collected from the ``fall`` and ``spring`` lists of every
    ``year_*`` entry of ``plan``. The requirement table is a built-in one for
    ``"general"``, the optimizer's ``ai_ml`` table for ``"game_dev"``, and the
    optimizer's table for ``track`` otherwise (empty if unknown).

    Returns ``{category: report}``. A report may contain ``required``,
    ``completed`` and ``is_satisfied`` for the mandatory list, plus one entry
    per ``pick_N_...`` key with ``options``, ``completed``, ``count`` and
    ``is_satisfied``. Returns ``{}`` when the module-level ``optimizer`` is falsy.
    """
    if not optimizer:
        return {}

    planned = [
        course
        for year_key, year_data in plan.items()
        if year_key.startswith("year_")
        for term in ("fall", "spring")
        for course in year_data.get(term, [])
    ]
    planned_set = set(planned)

    # Select the requirement table for the track.
    if track == "general":
        req_data = {
            "foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
            "core": {"required": ["CS3000", "CS3500", "CS3650"]},
            "math": {"required": ["MATH1341", "MATH1342"], "pick_1_from": ["MATH2331", "MATH3081"]},
        }
    else:
        # game_dev is evaluated against the ai_ml requirements.
        lookup_track = "ai_ml" if track == "game_dev" else track
        req_data = optimizer.CONCENTRATION_REQUIREMENTS.get(lookup_track, {})

    satisfaction_report = {}
    for category, reqs in req_data.items():
        entry = {}

        if "required" in reqs:
            mandatory = reqs["required"]
            entry["required"] = mandatory
            entry["completed"] = list(planned_set & set(mandatory))
            entry["is_satisfied"] = planned_set.issuperset(mandatory)

        for key, options in reqs.items():
            if not key.startswith("pick_"):
                continue
            # "pick_<N>_..." carries the required count; default to 1 if unparsable.
            try:
                needed = int(key.split("_")[1])
            except (IndexError, ValueError):
                needed = 1

            hits = list(planned_set & set(options))
            entry[key] = {
                "options": options,
                "completed": hits,
                "count": f"{len(hits)} of {needed}",
                "is_satisfied": len(hits) >= needed,
            }

        satisfaction_report[category] = entry

    return satisfaction_report
|
| 108 |
+
|
| 109 |
+
def export_plan_yaml(plan: Dict, profile: StudentProfile, validation: Dict = None, track: str = "general") -> str:
    """Export a degree plan as a structured YAML document for verification.

    Args:
        plan: Plan mapping keyed by ``year_1`` .. ``year_4``; each year holds
            ``fall`` and ``spring`` lists of course IDs. Other keys are ignored.
        profile: Student profile whose fields are copied into the export.
        validation: Validation result with ``errors`` / ``warnings`` entries.
            ``None`` or an empty mapping is treated as "no errors, no warnings".
        track: Track key written to the metadata and passed to the
            requirement-satisfaction check.

    Returns:
        The YAML text produced by ``yaml.dump`` (keys kept in insertion order).
    """
    # Normalise up front: the status lookup below calls ``.get`` on this value,
    # which used to raise AttributeError when the default ``None`` was passed.
    if not validation:
        validation = {"errors": [], "warnings": []}

    structured_plan = {
        "student_profile": {
            "name": getattr(profile, "name", "Student"),
            "gpa": profile.current_gpa,
            "career_goal": profile.career_goals,
            "interests": profile.interests,
            "completed_courses": profile.completed_courses,
            "time_commitment": profile.time_commitment,
            "preferred_difficulty": profile.preferred_difficulty,
        },
        "plan_metadata": {
            "generated": datetime.now().isoformat(),
            "track": track,
            "total_credits": 0,  # filled in once all semesters are built
            "validation_status": "has_errors" if validation.get("errors") else "valid",
        },
        "validation": validation,
        "semesters": [],
        "course_details": {},
    }

    def _semester_entry(year: int, term: str, course_ids: list) -> dict:
        """Build one semester record and register its courses in ``course_details``."""
        graph = optimizer.curriculum_graph
        details = []
        for course_id in course_ids:
            info = optimizer.courses.get(course_id, {})
            detail = {
                "id": course_id,
                "name": info.get("name", "Unknown"),
                "credits": info.get("maxCredits", 4),
                "complexity": info.get("complexity", 0),
                # Courses missing from the graph are exported with no prerequisites.
                "prerequisites": list(graph.predecessors(course_id)) if course_id in graph else [],
            }
            details.append(detail)
            # NOTE(review): the same dict object is stored in the semester list and
            # in ``course_details`` (as before); confirm how the YAML dumper
            # represents the shared object is acceptable for consumers.
            structured_plan["course_details"][course_id] = detail
        return {
            "year": year,
            "term": term,
            "courses": details,
            "semester_credits": sum(d["credits"] for d in details),
            "semester_complexity": sum(d["complexity"] for d in details),
        }

    total_credits = 0
    for year in range(1, 5):
        year_key = f"year_{year}"
        if year_key not in plan:
            continue

        for term in ("fall", "spring"):
            course_ids = plan[year_key].get(term, [])
            if not course_ids:
                continue  # empty terms are left out of the export
            semester = _semester_entry(year, term, course_ids)
            total_credits += semester["semester_credits"]
            structured_plan["semesters"].append(semester)

        # Years 2 and 3 get a co-op placeholder after the spring term.
        if year in (2, 3):
            structured_plan["semesters"].append({
                "year": year, "term": "summer", "activity": "co-op", "courses": []
            })

    structured_plan["plan_metadata"]["total_credits"] = total_credits

    # Requirement satisfaction is evaluated for the same track recorded above.
    structured_plan["requirements_satisfaction"] = check_requirements_satisfaction(plan, track=track)

    return yaml.dump(structured_plan, default_flow_style=False, sort_keys=False)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --- UI TABS ---
tab1, tab2, tab3 = st.tabs(["📝 Plan Generator", "🗺️ Curriculum Map", "📊 Analytics"])


def _split_csv(raw):
    """Split a comma-separated string into stripped, non-empty items."""
    return [piece.strip() for piece in raw.split(',') if piece.strip()]


def _load_label(weekly_hours):
    """Return the course-load label for a weekly study-hour budget."""
    if weekly_hours < 20:
        return "Part-time"
    if weekly_hours >= 40:
        return "Intensive"
    return "Standard"


def _record_plan(outcome, plan_type, elapsed):
    """Copy a generator result into session state for display and export."""
    pathway = outcome.get('pathway', {})
    st.session_state.reasoning = pathway.get("reasoning", "")
    st.session_state.metrics = pathway.get("complexity_analysis", {})
    st.session_state.display_plan = pathway
    st.session_state.plan_type = plan_type
    st.session_state.generation_time = elapsed


def _render_term(column, heading, course_ids):
    """List one term's courses in ``column`` and return how many were listed."""
    with column:
        st.markdown(heading)
        if not course_ids:
            st.write("*No courses scheduled*")
            return 0
        for cid in course_ids:
            if cid in optimizer.courses:
                title = optimizer.courses[cid].get("name", cid)
                st.write(f"• **{cid}**: {title}")
            else:
                # Unknown to the optimizer: show the bare ID.
                st.write(f"• {cid}")
        return len(course_ids)


with tab1:
    # --- SIDEBAR FOR STUDENT PROFILE ---
    with st.sidebar:
        st.header("Student Profile")
        name = st.text_input("Name", "John, son of Jane")
        gpa = st.slider("GPA", 0.0, 4.0, 3.0, 0.1)
        career_goal = st.text_area("Career Goal", " ")
        interests = st.text_input("Interests (comma-separated)", " ")
        learning_style = st.selectbox("Learning Style", ["Visual", "Hands-on", "Auditory"])
        time_commit = st.number_input("Weekly Study Hours", 10, 60, 40, 5)
        difficulty = st.selectbox("Preferred Difficulty", ["easy", "moderate", "challenging"])
        completed_courses_input = st.text_area("Completed Courses (comma-separated)", " ")

        # Explain how the chosen hours and difficulty are described to the user.
        st.markdown("---")
        st.markdown("**Profile Impact:**")

        if time_commit < 20:
            load_note = "🕒 Part-time load (3 courses/semester)"
        elif time_commit >= 40:
            load_note = "🔥 Intensive load (up to 5 courses/semester)"
        else:
            load_note = "📚 Standard load (4 courses/semester)"
        st.info(load_note)

        if difficulty == "easy":
            difficulty_note = "😌 Focuses on foundational courses"
        elif difficulty == "challenging":
            difficulty_note = "🚀 Includes advanced/specialized courses"
        else:
            difficulty_note = "⚖️ Balanced difficulty progression"
        st.info(difficulty_note)

    # --- MAIN PAGE CONTENT ---

    # 1. LOAD DATA
    st.subheader("1. Load Curriculum Data")
    uploaded_file = st.file_uploader("Upload `.pkl` file in the files section of this project", type=["pkl"])

    if uploaded_file and not st.session_state.graph_data_loaded:
        with st.spinner("Loading curriculum data and preparing embeddings..."):
            try:
                # NOTE(review): unpickling an uploaded file executes whatever the
                # pickle contains; only safe if uploads come from trusted users.
                graph_data = pickle.load(uploaded_file)
                optimizer.load_data(graph_data)
                st.session_state.visualizer = CurriculumVisualizer(graph_data)
                st.session_state.graph_data = graph_data
                st.session_state.graph_data_loaded = True
                st.success(f"Successfully loaded and processed '{uploaded_file.name}'!")
                time.sleep(1)
                st.rerun()
            except Exception as e:
                st.error(f"Error processing .pkl file: {e}")
                st.session_state.graph_data_loaded = False
    elif st.session_state.graph_data_loaded:
        st.success("Curriculum data is loaded and ready.")

    # 2. SELECT TRACK
    st.subheader("2. Select a Specialization")
    if st.session_state.graph_data_loaded:
        # Internal track key -> label shown in the dropdown.
        track_options = {
            "general": "🤖 General CS (Broadest Focus)",
            "ai_ml": "🧠 Artificial Intelligence & ML",
            "security": "🔒 Cybersecurity",
            "systems": "⚙️ Systems & Networks",
            "game_dev": "🎮 Game Design & Development"
        }

        selected_track_key = st.selectbox(
            "Choose your focus area (optional):",
            options=track_options.keys(),
            format_func=lambda track_key: track_options[track_key],
            index=0  # first entry ("general") is the default
        )
        st.session_state.selected_track = selected_track_key
    else:
        st.info("Please load a curriculum file first.")

    # 3. GENERATE PLAN
    st.subheader("3. Generate a Plan")
    if not st.session_state.graph_data_loaded:
        st.info("Please load a curriculum file above to enable plan generation.")
    else:
        # Build the profile from the sidebar widgets on every rerun.
        profile = StudentProfile(
            completed_courses=[code.upper() for code in _split_csv(completed_courses_input)],
            current_gpa=gpa,
            interests=_split_csv(interests),
            career_goals=career_goal,
            learning_style=learning_style,
            time_commitment=time_commit,
            preferred_difficulty=difficulty
        )

        selected_track = st.session_state.get("selected_track", "general")

        # Detect a change against the previous rerun, then remember the new values.
        profile_changed = (
            st.session_state.last_profile != profile
            or st.session_state.last_track != selected_track
        )
        if profile_changed:
            st.session_state.last_profile = profile
            st.session_state.last_track = selected_track

        col1, col2, col3 = st.columns(3)

        if col1.button("🧠 AI-Optimized Plan", use_container_width=True, type="primary"):
            with st.spinner(f"🚀 Performing AI-optimization for '{track_options[selected_track]}' track..."):
                started = time.time()
                outcome = optimizer.generate_llm_plan(profile, selected_track)
                elapsed = time.time() - started
                _record_plan(outcome, "AI-Optimized", elapsed)
                st.success(f"🎉 AI-optimized plan generated in {elapsed:.1f}s!")

        if col2.button("⚡ Smart Rule-Based Plan", use_container_width=True):
            with st.spinner(f"Generating rule-based plan for '{track_options[selected_track]}' track..."):
                started = time.time()
                outcome = optimizer.generate_simple_plan(profile, selected_track)
                elapsed = time.time() - started
                _record_plan(outcome, "Smart Rule-Based", elapsed)
                st.success(f"🎉 Smart rule-based plan generated in {elapsed:.1f}s!")

        if col3.button("🔄 Clear Plan", use_container_width=True):
            st.session_state.display_plan = None
            st.session_state.metrics = None
            st.session_state.reasoning = ""
            st.rerun()

        # Tell the user the displayed plan no longer matches the inputs.
        if st.session_state.display_plan and profile_changed:
            st.warning("⚠️ Student profile or track changed! Generate a new plan to see updated recommendations.")

    # DISPLAY RESULTS
    if st.session_state.display_plan:
        st.subheader(f"📚 {st.session_state.get('plan_type', 'Optimized')} Degree Plan")

        # Generation summary row
        col_info1, col_info2, col_info3 = st.columns(3)
        col_info1.metric("Generation Time", f"{st.session_state.get('generation_time', 0):.1f}s")
        col_info2.metric("Plan Type", st.session_state.get('plan_type', 'Unknown'))
        col_info3.metric("Course Load", _load_label(time_commit))

        # Reasoning text and complexity metrics, when the generator supplied them
        if st.session_state.reasoning or st.session_state.metrics:
            st.markdown("##### 📊 Plan Analysis")

            if st.session_state.reasoning:
                st.info(f"**Strategy:** {st.session_state.reasoning}")

            if st.session_state.metrics:
                m = st.session_state.metrics
                metric_specs = (
                    ("Avg Complexity", 'average_semester_complexity', ".1f"),
                    ("Peak Complexity", 'peak_semester_complexity', ".1f"),
                    ("Total Complexity", 'total_complexity', ".0f"),
                    ("Balance Score", 'balance_score (std_dev)', ".2f"),
                )
                for cell, (label, key, spec) in zip(st.columns(4), metric_specs):
                    cell.metric(label, format(m.get(key, 0), spec))

            st.divider()

        # Year-by-year schedule
        plan = st.session_state.display_plan
        total_courses = 0

        for year_num in range(1, 5):
            year_data = plan.get(f"year_{year_num}", {})

            st.markdown(f"### Year {year_num}")
            col_fall, col_spring, col_summer = st.columns(3)

            total_courses += _render_term(col_fall, "**🍂 Fall Semester**", year_data.get("fall", []))
            total_courses += _render_term(col_spring, "**🌸 Spring Semester**", year_data.get("spring", []))

            with col_summer:
                summer = year_data.get("summer", [])
                st.markdown("**☀️ Summer**")
                if summer == "co-op":
                    st.write("🏢 *Co-op Experience*")
                elif summer:
                    # Any other non-empty value is shown as summer classes.
                    st.write("*Summer Classes*")
                else:
                    st.write("*Break*")

        # Summary and export
        st.divider()
        col_export1, col_export2 = st.columns(2)
        col_export1.metric("Total Courses", total_courses)

        file_stem = f"curriculum_plan_{name.replace(' ', '_')}"
        chosen_track = st.session_state.get("selected_track", "general")

        with col_export2:
            col_yaml, col_json = st.columns(2)

            with col_yaml:
                # Reuse the validation stored on the plan; validation is not re-run here.
                validation = st.session_state.display_plan.get("validation", {"errors": [], "warnings": []})

                yaml_data = export_plan_yaml(
                    st.session_state.display_plan,
                    profile,
                    validation,
                    chosen_track
                )
                st.download_button(
                    label="📥 Export as YAML",
                    data=yaml_data,
                    file_name=f"{file_stem}.yaml",
                    mime="text/yaml",
                    use_container_width=True
                )

            with col_json:
                export_data = {
                    "student_profile": {
                        "name": name,
                        "gpa": gpa,
                        "career_goals": career_goal,
                        "interests": interests,
                        "learning_style": learning_style,
                        "time_commitment": time_commit,
                        "preferred_difficulty": difficulty,
                        "completed_courses": completed_courses_input
                    },
                    "plan": st.session_state.display_plan,
                    "metrics": st.session_state.metrics,
                    "generation_info": {
                        "plan_type": st.session_state.get('plan_type', 'Unknown'),
                        "generation_time": st.session_state.get('generation_time', 0),
                        "selected_track": chosen_track
                    }
                }
                plan_json = json.dumps(export_data, indent=2)
                st.download_button(
                    label="📥 Export as JSON",
                    data=plan_json,
                    file_name=f"{file_stem}.json",
                    mime="application/json",
                    use_container_width=True
                )
|
| 490 |
+
|
| 491 |
+
# --- TAB 2: CURRICULUM MAP ---
# Renders the prerequisite graph with its critical path highlighted.
with tab2:
    st.subheader("🗺️ Interactive Curriculum Dependency Graph")

    if not st.session_state.graph_data_loaded:
        st.info("Please load curriculum data in the Plan Generator tab first.")
    else:
        if st.session_state.visualizer:
            critical_path = st.session_state.visualizer.find_critical_path()
            if critical_path:
                # Show at most the first 7 courses; add the ellipsis only when
                # the path was actually truncated.
                preview = ' → '.join(critical_path[:7])
                ellipsis = "..." if len(critical_path) > 7 else ""
                st.info(f"Global Critical Path ({len(critical_path)} courses): {preview}{ellipsis}")

            # The plot is drawn whether or not a critical path was found.
            fig = st.session_state.visualizer.create_interactive_plot(critical_path)
            st.plotly_chart(fig, use_container_width=True)

            # Legend
            with st.expander("📖 How to Read This Graph"):
                st.markdown("""
                **Node (Circle) Size**: Blocking factor - larger circles block more future courses
                **Node Color**: Complexity score - darker = more complex
                **Lines**: Prerequisite relationships
                **Red Path**: Critical path (longest chain)
                **Hover over nodes**: See detailed metrics for each course
                """)
|
| 517 |
+
|
| 518 |
+
# --- TAB 3: ANALYTICS ---
# Curriculum-wide statistics, the most complex / most blocking courses, and a
# comparison of global metrics against the currently displayed plan.
with tab3:
    st.subheader("📊 Curriculum Analytics Dashboard")

    if not st.session_state.graph_data_loaded:
        st.info("Please load curriculum data in the Plan Generator tab first.")
    else:
        graph = st.session_state.graph_data
        visualizer = st.session_state.visualizer
        total_courses = graph.number_of_nodes()
        total_prereqs = graph.number_of_edges()

        # Compute per-course metrics once and reuse them for the headline total
        # and for both ranking columns below.
        complexities = []
        if visualizer:
            for node in graph.nodes():
                node_metrics = visualizer.calculate_metrics(node)
                complexities.append({
                    'course': node,
                    'name': graph.nodes[node].get('name', ''),
                    'complexity': node_metrics['complexity'],
                    'blocking': node_metrics['blocking']
                })

        # Overall metrics
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total Courses", total_courses)
        col2.metric("Total Prerequisites", total_prereqs)
        # An empty graph has no average; show 0.0 instead of dividing by zero.
        avg_prereqs = total_prereqs / total_courses if total_courses else 0.0
        col3.metric("Avg Prerequisites", f"{avg_prereqs:.1f}")

        if visualizer:
            total_complexity = sum(item['complexity'] for item in complexities)
            col4.metric("Curriculum Complexity", f"{total_complexity:,.0f}")

        st.divider()

        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Most Complex Courses")
            if visualizer:
                complexities.sort(key=lambda x: x['complexity'], reverse=True)

                for item in complexities[:10]:
                    st.write(f"**{item['course']}**: {item['name']}")
                    prog_col1, prog_col2 = st.columns([3, 1])
                    with prog_col1:
                        # Progress bar is capped at 1.0; complexity is scaled by 100.
                        st.progress(min(item['complexity']/100, 1.0))
                    with prog_col2:
                        st.caption(f"Blocks: {item['blocking']}")

        with col2:
            st.subheader("Bottleneck Courses")
            st.caption("(High blocking factor)")

            if visualizer:
                # Stable sort over the complexity-ordered list, so ties in
                # blocking keep the higher-complexity course first.
                bottlenecks = sorted(complexities, key=lambda x: x['blocking'], reverse=True)

                for item in bottlenecks[:10]:
                    st.write(f"**{item['course']}**: {item['name']}")
                    st.info(f"Blocks {item['blocking']} future courses")

        # Plan vs Global Comparison (needs both a plan and a visualizer)
        if st.session_state.display_plan and visualizer:
            st.divider()
            st.subheader("📊 Metric System Comparison")
            st.caption("Comparing metrics for the entire curriculum vs. metrics only within your generated plan.")

            # Collect every fall/spring course ID from the year_* entries.
            plan_courses: Set[str] = set()
            for year_key, year_data in st.session_state.display_plan.items():
                if year_key.startswith("year_"):
                    plan_courses.update(year_data.get("fall", []))
                    plan_courses.update(year_data.get("spring", []))

            comparison = visualizer.compare_metric_systems(plan_courses)

            col1, col2 = st.columns(2)

            with col1:
                st.metric(
                    "Critical Path Match",
                    "✅ Yes" if comparison['critical_path_match'] else "❌ No"
                )
                st.caption("Global critical path (first 5):")
                st.code(' → '.join(comparison['global_critical']))

            with col2:
                st.metric(
                    "Major Metric Differences",
                    len(comparison['major_differences'])
                )
                st.caption("Plan-specific critical path (first 5):")
                st.code(' → '.join(comparison['plan_critical']))

            if comparison['major_differences']:
                with st.expander(f"View {len(comparison['major_differences'])} courses with >50% metric difference"):
                    for diff in comparison['major_differences']:
                        st.write(f"**{diff['course']}**: Global blocking={diff['global_blocking']}, Plan blocking={diff['plan_blocking']}")

# Footer
st.divider()
st.caption("🚀 Powered by Students, For Students")
|