ckharche committed on
Commit
5360228
·
verified ·
1 Parent(s): 1a5edf3

added option to choose tracks

Browse files
src/curriculum_analyzer.py CHANGED
@@ -1,11 +1,19 @@
 
1
  """
2
- Fixed Curriculum Analyzer - Better handling of incomplete data
 
 
 
 
 
 
3
  """
4
  import pickle
5
  import argparse
6
  import networkx as nx
7
  import re
8
  from typing import Set, Dict
 
9
 
10
  def get_course_level(cid):
11
  """Extracts the numerical part of a course ID for level checking."""
@@ -13,13 +21,27 @@ def get_course_level(cid):
13
  return int(match.group(0)) if match else 9999
14
 
15
  class CurriculumAnalyzer:
 
 
 
 
 
 
 
 
 
 
16
  def __init__(self, graph_path, courses_path):
17
  print("📚 Loading raw curriculum data...")
18
- with open(graph_path, 'rb') as f:
19
- self.graph = pickle.load(f)
20
- with open(courses_path, 'rb') as f:
21
- self.courses = pickle.load(f)
22
-
 
 
 
 
23
  # Merge course data into graph nodes
24
  for course_id, course_data in self.courses.items():
25
  if self.graph.has_node(course_id):
@@ -28,44 +50,65 @@ class CurriculumAnalyzer:
28
  print(f"✅ Loaded {self.graph.number_of_nodes()} courses, {self.graph.number_of_edges()} edges")
29
 
30
  def pre_filter_graph(self):
31
- """Keeps only relevant subjects and removes labs/high-level courses."""
 
 
 
32
  print("\n🧹 Pre-filtering graph...")
33
 
34
- KEEP_SUBJECTS = {"CS", "DS", "IS", "CY", "MATH", "PHYS", "ENGW", "STAT", "EECE"}
35
-
36
  nodes_to_remove = set()
37
  for node, data in self.graph.nodes(data=True):
38
  subject = data.get('subject', '')
39
  name = data.get('name', '').lower()
40
  level = get_course_level(node)
41
 
42
- # Remove if:
43
- # - Not in whitelist
44
- # - Too advanced (5000+)
45
- # - Lab/recitation/etc
46
- if (subject not in KEEP_SUBJECTS or
47
- level >= 5000 or
48
- any(skip in name for skip in ['lab', 'recitation', 'seminar', 'practicum', 'co-op'])):
 
 
 
 
49
  nodes_to_remove.add(node)
50
 
 
51
  self.graph.remove_nodes_from(nodes_to_remove)
52
- print(f"✅ Removed {len(nodes_to_remove)} irrelevant courses")
53
- print(f" Remaining: {self.graph.number_of_nodes()} courses")
 
 
54
 
55
  def fix_chains(self):
56
  """Adds critical prerequisite chains that might be missing."""
57
  print("\n🔗 Validating and fixing critical prerequisite chains...")
58
 
59
  critical_chains = {
 
60
  ("CS1800", "CS2800", "Discrete → Logic"),
61
  ("CS2500", "CS2510", "Fundies 1 → Fundies 2"),
 
62
  ("CS2510", "CS3500", "Fundies 2 → OOD"),
63
  ("CS2510", "CS3000", "Fundies 2 → Algorithms"),
64
- ("CS3000", "CS4100", "AlgorithmsAI"), # NEW
65
- ("MATH1341", "MATH1342", "Calc 1 → Calc 2"),
 
 
 
 
 
 
 
 
66
  ("DS2000", "DS2500", "Prog w/ Data → Intermediate"),
67
  ("DS2500", "DS3500", "Intermediate → Advanced"),
68
- ("DS3500", "DS4400", "Advanced → ML1"), # NEW
 
 
 
69
  }
70
 
71
  added = 0
@@ -80,11 +123,18 @@ class CurriculumAnalyzer:
80
  print(" ✅ All critical chains present")
81
 
82
  def remove_spurious_chains(self):
83
- """Removes known incorrect prerequisite edges."""
 
 
 
84
  print("\n🗑️ Removing spurious prerequisite chains...")
85
 
 
86
  spurious_chains = {
87
- ("MATH1365", "CS2800"), # Not a real prereq
 
 
 
88
  }
89
 
90
  removed = 0
@@ -102,6 +152,7 @@ class CurriculumAnalyzer:
102
  print("\n🧮 Calculating complexity scores...")
103
 
104
  for node in self.graph.nodes():
 
105
  in_degree = self.graph.in_degree(node)
106
  out_degree = self.graph.out_degree(node)
107
 
@@ -115,13 +166,14 @@ class CurriculumAnalyzer:
115
  """Check if all critical courses exist in the graph."""
116
  print("\n🎯 Validating critical course coverage...")
117
 
 
118
  required_courses = {
119
  "foundations": {"CS1800", "CS2500", "CS2510", "CS2800"},
120
- "core": {"CS3000", "CS3500", "CS3650", "CS3700", "CS3200"},
121
  "ai_ml": {"CS4100", "DS4400", "CS4120", "DS4420", "CS4180", "DS4440"},
122
- "systems": {"CS4730", "CS4400", "CS4500"}, # Removed often-missing courses
123
  "security": {"CY2550", "CY3740", "CY4740", "CY4760"},
124
- "math": {"MATH1341", "MATH1342", "MATH2331", "MATH3081"}, # No STAT courses at NEU
125
  }
126
 
127
  missing = {}
@@ -150,14 +202,12 @@ class CurriculumAnalyzer:
150
  f.write(f"Total courses: {self.graph.number_of_nodes()}\n")
151
  f.write(f"Total prerequisites: {self.graph.number_of_edges()}\n\n")
152
 
153
- # Subject breakdown
154
- from collections import defaultdict
155
  subject_counts = defaultdict(int)
156
  for node in self.graph.nodes():
157
  subject = self.graph.nodes[node].get('subject', 'UNKNOWN')
158
  subject_counts[subject] += 1
159
 
160
- f.write("Subject breakdown:\n")
161
  for subject in sorted(subject_counts.keys()):
162
  f.write(f" {subject}: {subject_counts[subject]}\n")
163
 
@@ -183,8 +233,8 @@ def main(args):
183
 
184
  if __name__ == "__main__":
185
  parser = argparse.ArgumentParser(description="NEU Curriculum Analyzer - Cleans and validates data")
186
- parser.add_argument('--graph', required=True, help="Path to RAW curriculum graph")
187
- parser.add_argument('--courses', required=True, help="Path to RAW courses data")
188
- parser.add_argument('--output-graph', default='neu_graph_clean.pkl', help="Output path")
189
  args = parser.parse_args()
190
  main(args)
 
1
+ #!/usr/bin/env python3
2
  """
3
+ FIXED Curriculum Analyzer - Production Version
4
+
5
+ Synchronized with optimizer logic:
6
+ 1. Filters subjects to ONLY: CS, DS, CY, MATH, PHYS, ENGW.
7
+ 2. Removes IS, EECE, STAT, and other irrelevant subjects.
8
+ 3. ADDS exceptions for undergrad-accessible 5000-level courses (e.g., CS5700).
9
+ 4. FIXES bad prerequisite data (e.g., removes the spurious CS2500 -> CS2800 edge).
10
  """
11
  import pickle
12
  import argparse
13
  import networkx as nx
14
  import re
15
  from typing import Set, Dict
16
+ from collections import defaultdict
17
 
18
  def get_course_level(cid):
19
  """Extracts the numerical part of a course ID for level checking."""
 
21
  return int(match.group(0)) if match else 9999
22
 
23
  class CurriculumAnalyzer:
24
+
25
+ # --- FIX 1: DEFINE LISTS THAT MATCH THE OPTIMIZER ---
26
+
27
+ # Subjects the optimizer is programmed to understand.
28
+ # ENGW/PHYS are needed only for hardcoded Year 1.
29
+ KEEP_SUBJECTS = {"CS", "DS", "CY", "MATH", "PHYS", "ENGW"}
30
+
31
+ # 5000-level courses the optimizer explicitly allows.
32
+ UNDERGRAD_ACCESSIBLE_GRAD = {"CS5700", "CY5700", "DS5110", "CS5010"}
33
+
34
  def __init__(self, graph_path, courses_path):
35
  print("📚 Loading raw curriculum data...")
36
+ try:
37
+ with open(graph_path, 'rb') as f:
38
+ self.graph = pickle.load(f)
39
+ with open(courses_path, 'rb') as f:
40
+ self.courses = pickle.load(f)
41
+ except Exception as e:
42
+ print(f"❌ ERROR: Could not load files. {e}")
43
+ exit(1)
44
+
45
  # Merge course data into graph nodes
46
  for course_id, course_data in self.courses.items():
47
  if self.graph.has_node(course_id):
 
50
  print(f"✅ Loaded {self.graph.number_of_nodes()} courses, {self.graph.number_of_edges()} edges")
51
 
52
  def pre_filter_graph(self):
53
+ """
54
+ --- FIX 2: IMPLEMENTS STRICT FILTERING ---
55
+ Keeps only KEEP_SUBJECTS; removes labs/seminars and 5000+ level courses not in UNDERGRAD_ACCESSIBLE_GRAD.
56
+ """
57
  print("\n🧹 Pre-filtering graph...")
58
 
 
 
59
  nodes_to_remove = set()
60
  for node, data in self.graph.nodes(data=True):
61
  subject = data.get('subject', '')
62
  name = data.get('name', '').lower()
63
  level = get_course_level(node)
64
 
65
+ # Check for removal
66
+ is_irrelevant_subject = subject not in self.KEEP_SUBJECTS
67
+ is_lab_or_seminar = any(skip in name for skip in ['lab', 'recitation', 'seminar', 'practicum', 'co-op'])
68
+
69
+ # Grad-level check
70
+ is_grad_level = level >= 5000
71
+ is_allowed_grad = node in self.UNDERGRAD_ACCESSIBLE_GRAD
72
+
73
+ if (is_irrelevant_subject or
74
+ is_lab_or_seminar or
75
+ (is_grad_level and not is_allowed_grad)): # <-- Bug fix
76
  nodes_to_remove.add(node)
77
 
78
+ original_count = self.graph.number_of_nodes()
79
  self.graph.remove_nodes_from(nodes_to_remove)
80
+
81
+ print(f" Removed {len(nodes_to_remove)} irrelevant courses (IS, EECE, etc.)")
82
+ print(f" Original nodes: {original_count}")
83
+ print(f" Remaining nodes: {self.graph.number_of_nodes()}")
84
 
85
  def fix_chains(self):
86
  """Adds critical prerequisite chains that might be missing."""
87
  print("\n🔗 Validating and fixing critical prerequisite chains...")
88
 
89
  critical_chains = {
90
+ # Foundations
91
  ("CS1800", "CS2800", "Discrete → Logic"),
92
  ("CS2500", "CS2510", "Fundies 1 → Fundies 2"),
93
+ # Core CS
94
  ("CS2510", "CS3500", "Fundies 2 → OOD"),
95
  ("CS2510", "CS3000", "Fundies 2 → Algorithms"),
96
+ ("CS2800", "CS3000", "LogicAlgorithms"),
97
+
98
+ # --- Added chain: CS3000 (Algorithms) must precede CS3650 (Systems) ---
99
+ ("CS3000", "CS3650", "Algorithms -> Systems"),
100
+ # ---------------------
101
+
102
+ # Core AI/ML
103
+ ("CS3000", "CS4100", "Algorithms → AI"),
104
+ ("CS3500", "CS4100", "OOD → AI"),
105
+ # Core DS Path
106
  ("DS2000", "DS2500", "Prog w/ Data → Intermediate"),
107
  ("DS2500", "DS3500", "Intermediate → Advanced"),
108
+ ("DS3500", "DS4400", "Advanced → ML1"),
109
+ ("CS3500", "DS4400", "OOD → ML1"),
110
+ # Math
111
+ ("MATH1341", "MATH1342", "Calc 1 → Calc 2"),
112
  }
113
 
114
  added = 0
 
123
  print(" ✅ All critical chains present")
124
 
125
  def remove_spurious_chains(self):
126
+ """
127
+ --- FIX 3: REMOVE BAD DATA ---
128
+ Removes known incorrect prerequisite edges from scraper.
129
+ """
130
  print("\n🗑️ Removing spurious prerequisite chains...")
131
 
132
+ # Based on inspect_graph output and catalog knowledge
133
  spurious_chains = {
134
+ ("CS2500", "CS2800"), # Fundies 1 is NOT a prereq for Logic
135
+ ("MATH1365", "CS2800"), # Not a real prereq
136
+ ("EECE2160", "CS3000"), # Irrelevant prereq
137
+ ("EECE2560", "CS3500"), # Irrelevant prereq
138
  }
139
 
140
  removed = 0
 
152
  print("\n🧮 Calculating complexity scores...")
153
 
154
  for node in self.graph.nodes():
155
+ # In/out degrees are taken from the *cleaned* graph
156
  in_degree = self.graph.in_degree(node)
157
  out_degree = self.graph.out_degree(node)
158
 
 
166
  """Check if all critical courses exist in the graph."""
167
  print("\n🎯 Validating critical course coverage...")
168
 
169
+ # This list MUST match the optimizer's requirements
170
  required_courses = {
171
  "foundations": {"CS1800", "CS2500", "CS2510", "CS2800"},
172
+ "core": {"CS3000", "CS3500", "CS3650", "CS3200", "CS5700"}, # Added CS5700
173
  "ai_ml": {"CS4100", "DS4400", "CS4120", "DS4420", "CS4180", "DS4440"},
174
+ "systems": {"CS4730", "CS4700", "CS4400", "CS4500"},
175
  "security": {"CY2550", "CY3740", "CY4740", "CY4760"},
176
+ "math": {"MATH1341", "MATH1342", "MATH2331", "MATH3081"},
177
  }
178
 
179
  missing = {}
 
202
  f.write(f"Total courses: {self.graph.number_of_nodes()}\n")
203
  f.write(f"Total prerequisites: {self.graph.number_of_edges()}\n\n")
204
 
 
 
205
  subject_counts = defaultdict(int)
206
  for node in self.graph.nodes():
207
  subject = self.graph.nodes[node].get('subject', 'UNKNOWN')
208
  subject_counts[subject] += 1
209
 
210
+ f.write("Subject breakdown (Filtered):\n")
211
  for subject in sorted(subject_counts.keys()):
212
  f.write(f" {subject}: {subject_counts[subject]}\n")
213
 
 
233
 
234
  if __name__ == "__main__":
235
  parser = argparse.ArgumentParser(description="NEU Curriculum Analyzer - Cleans and validates data")
236
+ parser.add_argument('--graph', required=True, help="Path to RAW curriculum graph (e.g., neu_merged_graph_...pkl)")
237
+ parser.add_argument('--courses', required=True, help="Path to RAW courses data (e.g., neu_merged_courses_...pkl)")
238
+ parser.add_argument('--output-graph', default='neu_graph_clean.pkl', help="Output path for the new, clean graph")
239
  args = parser.parse_args()
240
  main(args)
src/curriculum_optimizer.py CHANGED
@@ -1,513 +1,657 @@
1
- """
2
- Fixed Hybrid Curriculum Optimizer
3
- WITH PROPER COURSE DISCOVERY, SUBJECT-AWARE SCORING, AND CONCENTRATION FOCUS
4
- """
5
- import torch
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
- from sentence_transformers import SentenceTransformer, util
8
- import networkx as nx
9
- import numpy as np
10
- from typing import Dict, List, Set, Tuple, Optional
11
- from dataclasses import dataclass
12
- import re
13
- import json
14
- import random
15
- from datetime import datetime
16
-
17
- @dataclass
18
- class StudentProfile:
19
- completed_courses: List[str]
20
- time_commitment: int
21
- preferred_difficulty: str
22
- career_goals: str
23
- interests: List[str]
24
- current_gpa: float = 3.5
25
- learning_style: str = "Visual"
26
-
27
- class HybridOptimizer:
28
- """
29
- Fixed optimizer with subject-aware scoring and concentration focus
30
- """
31
-
32
- EQUIVALENCY_GROUPS = [
33
- {"MATH1341", "MATH1241", "MATH1231"}, # Calculus 1
34
- {"MATH1342", "MATH1242"}, # Calculus 2
35
- {"PHYS1151", "PHYS1161", "PHYS1145"}, # Physics 1
36
- {"PHYS1155", "PHYS1165", "PHYS1147"}, # Physics 2
37
- ]
38
- COURSE_TRACKS = {
39
- "physics": {
40
- "engineering": ["PHYS1151", "PHYS1155"],
41
- "science": ["PHYS1161", "PHYS1165"],
42
- "life_sciences": ["PHYS1145", "PHYS1147"]
43
- },
44
- "calculus": {
45
- "standard": ["MATH1341", "MATH1342"],
46
- "computational": ["MATH156", "MATH256"]
47
- }
48
- }
49
-
50
- CONCENTRATION_REQUIREMENTS = {
51
- "ai_ml": {
52
- "foundations": {
53
- "required": ["CS1800", "CS2500", "CS2510", "CS2800"],
54
- "sequence": True
55
- },
56
- "core": {
57
- "required": ["CS3000", "CS3500"],
58
- "pick_1_from": ["CS3200", "CS3650", "CS5700"]
59
- },
60
- "concentration_specific": {
61
- "required": ["CS4100", "DS4400"],
62
- "pick_2_from": ["CS4120", "CS4180", "DS4420", "DS4440"],
63
- "pick_1_systems": ["CS4730", "CS4700"]
64
- },
65
- "math": {
66
- "required": ["MATH1341", "MATH1342"],
67
- "pick_1_from": ["MATH2331", "MATH3081"]
68
- }
69
- },
70
- "systems": {
71
- "foundations": { "required": ["CS1800", "CS2500", "CS2510", "CS2800"] },
72
- "core": { "required": ["CS3000", "CS3500", "CS3650"], "pick_1_from": ["CS5700", "CS3200"] },
73
- "concentration_specific": { "required": ["CS4700"], "pick_2_from": ["CS4730"], "pick_1_from": ["CS4400", "CS4500", "CS4520"] },
74
- "math": { "required": ["MATH1341", "MATH1342"] }
75
- },
76
- "security": {
77
- "foundations": { "required": ["CS1800", "CS2500", "CS2510", "CS2800"] },
78
- "core": { "required": ["CS3000", "CS3650", "CY2550"], "pick_1_from": ["CS5700", "CS3500"] },
79
- "concentration_specific": { "required": ["CY3740"], "pick_2_from": ["CY4740", "CY4760", "CY4770"], "pick_1_from": ["CS4700", "CS4730"] },
80
- "math": { "required": ["MATH1342"], "pick_1_from": ["MATH3527", "MATH3081"] }
81
- }
82
- }
83
-
84
- def __init__(self):
85
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
- self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
87
- self.embedding_model_name = 'BAAI/bge-large-en-v1.5'
88
- self.llm = None
89
- self.tokenizer = None
90
- self.embedding_model = None
91
- self.curriculum_graph = None
92
- self.courses = {}
93
- self.current_student = None
94
-
95
- def load_models(self):
96
- print("Loading embedding model...")
97
- self.embedding_model = SentenceTransformer(self.embedding_model_name, device=self.device)
98
-
99
- def load_llm(self):
100
- if self.device.type == 'cuda' and self.llm is None:
101
- print("Loading LLM for intelligent planning...")
102
- quant_config = BitsAndBytesConfig(
103
- load_in_4bit=True,
104
- bnb_4bit_quant_type="nf4",
105
- bnb_4bit_compute_dtype=torch.bfloat16
106
- )
107
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
108
- self.tokenizer.pad_token = self.tokenizer.eos_token
109
- self.llm = AutoModelForCausalLM.from_pretrained(
110
- self.model_name,
111
- quantization_config=quant_config,
112
- device_map="auto"
113
- )
114
-
115
- def load_data(self, graph: nx.DiGraph):
116
- self.curriculum_graph = graph
117
- self.courses = dict(graph.nodes(data=True))
118
- UNDERGRAD_ACCESSIBLE_GRAD = {"CS5700", "CY5700", "DS5110", "CS5010"}
119
- self.valid_courses = []
120
- course_texts = []
121
-
122
- concentration_courses = set()
123
- for track_reqs in self.CONCENTRATION_REQUIREMENTS.values():
124
- for category, reqs in track_reqs.items():
125
- if isinstance(reqs, dict):
126
- for key, courses in reqs.items():
127
- if isinstance(courses, list):
128
- concentration_courses.update(courses)
129
-
130
- for cid, data in self.courses.items():
131
- name = data.get('name', '')
132
- if not name or name.strip() == '' or any(skip in name.lower() for skip in ['lab', 'recitation', 'seminar', 'practicum']):
133
- continue
134
-
135
- course_level = self._get_level(cid)
136
- if course_level >= 5000 and cid not in UNDERGRAD_ACCESSIBLE_GRAD:
137
- continue
138
-
139
- self.valid_courses.append(cid)
140
- course_texts.append(f"{name} {data.get('description', '')}")
141
-
142
- missing_required = concentration_courses - set(self.valid_courses)
143
- if missing_required:
144
- print(f"\n⚠️ WARNING: {len(missing_required)} required courses missing from graph: {sorted(missing_required)}\n")
145
-
146
- print(f"Computing embeddings for {len(self.valid_courses)} courses...")
147
- self.course_embeddings = self.embedding_model.encode(course_texts, convert_to_tensor=True, show_progress_bar=True)
148
- print(f"\nTotal valid courses: {len(self.valid_courses)}")
149
-
150
- def _get_level(self, course_id: str) -> int:
151
- match = re.search(r'\d+', course_id)
152
- return int(match.group()) if match else 9999
153
-
154
- def _get_completed_with_equivalents(self, completed: Set[str]) -> Set[str]:
155
- expanded_completed = completed.copy()
156
- for course in completed:
157
- for group in self.EQUIVALENCY_GROUPS:
158
- if course in group:
159
- expanded_completed.update(group)
160
- return expanded_completed
161
-
162
- def _can_take_course(self, course_id: str, completed: Set[str]) -> bool:
163
- effective_completed = self._get_completed_with_equivalents(completed)
164
- if course_id not in self.curriculum_graph:
165
- return True
166
- prereqs = set(self.curriculum_graph.predecessors(course_id))
167
- return prereqs.issubset(effective_completed)
168
-
169
- def _validate_sequence(self, selected: List[str], candidate: str) -> bool:
170
- for track_type, tracks in self.COURSE_TRACKS.items():
171
- for track_name, sequence in tracks.items():
172
- if candidate in sequence:
173
- for other_track, other_seq in tracks.items():
174
- if other_track != track_name and any(c in selected for c in other_seq):
175
- return False
176
- return True
177
-
178
- def _score_course(self, course_id: str, semantic_scores: Dict[str, float], required_set: Set[str], picklist_set: Set[str]) -> float:
179
- """FIXED: Proper scoring with IS heavy penalty"""
180
-
181
- if course_id not in self.courses or not self.courses[course_id].get('name', '').strip():
182
- return -10000.0
183
-
184
- course_data = self.courses[course_id]
185
- subject = course_data.get('subject', '')
186
-
187
- score = 0.0
188
-
189
- # Subject bonuses/penalties
190
- if subject in ["CS", "DS", "CY"]:
191
- score += 300.0
192
- elif subject == "MATH":
193
- score += 100.0
194
- else:
195
- score -= 1000.0 # Heavy penalty for everything else (including IS)
196
-
197
- # Required courses: massive boost
198
- if course_id in required_set:
199
- score += 10000.0 # INCREASED from 1000
200
-
201
- # Pick-list courses: high boost
202
- if course_id in picklist_set:
203
- score += 5000.0 # INCREASED from 500
204
-
205
- # Unlocking factor (reduced weight)
206
- if course_id in self.curriculum_graph:
207
- unlocks = self.curriculum_graph.out_degree(course_id)
208
- score += min(unlocks, 5) * 2.0 # REDUCED
209
-
210
- # Level preference
211
- level = self._get_level(course_id)
212
- score -= (level / 100.0)
213
-
214
- # Semantic alignment (reduced weight)
215
- score += semantic_scores.get(course_id, 0.0) * 5.0 # REDUCED from 15
216
-
217
- return score
218
-
219
- def generate_simple_plan(self, student: StudentProfile) -> Dict:
220
- print("--- Generating Enhanced Rule-Based Plan ---")
221
- self.current_student = student
222
- return self.generate_enhanced_rule_plan(student)
223
-
224
- def generate_enhanced_rule_plan(self, student: StudentProfile) -> Dict:
225
- self.current_student = student
226
- track = self._identify_track(student)
227
- plan = self._build_structured_plan(student, track, None)
228
- validation = self.validate_plan(plan, student)
229
-
230
- if validation["errors"]:
231
- plan = self._fix_plan_errors(plan, validation, student)
232
- validation = self.validate_plan(plan, student)
233
-
234
- difficulty_level = self._map_difficulty(student.preferred_difficulty)
235
- courses_per_semester = self._calculate_course_load(student.time_commitment)
236
- explanation = f"Personalized {track} track ({difficulty_level} difficulty, {courses_per_semester} courses/semester)"
237
-
238
- return self._finalize_plan(plan, explanation, validation)
239
-
240
- def generate_llm_plan(self, student: StudentProfile) -> Dict:
241
- print("--- Generating AI-Optimized Plan ---")
242
- self.current_student = student
243
- self.load_llm()
244
- if not self.llm:
245
- return self.generate_enhanced_rule_plan(student)
246
-
247
- track = self._identify_track(student)
248
- llm_suggestions = self._get_llm_course_suggestions(student, track)
249
- plan = self._build_structured_plan(student, track, llm_suggestions)
250
- validation = self.validate_plan(plan, student)
251
- if validation["errors"]:
252
- plan = self._fix_plan_errors(plan, validation, student)
253
- validation = self.validate_plan(plan, student)
254
-
255
- explanation = self._generate_explanation(student, plan, track, "AI-optimized")
256
- return self._finalize_plan(plan, explanation, validation)
257
-
258
- def _build_structured_plan(self, student: StudentProfile, track: str, llm_suggestions: Optional[List[str]] = None) -> Dict:
259
- """FIXED with hardcoded Year 2 priorities"""
260
-
261
- completed = set(student.completed_courses)
262
- plan = {}
263
- requirements = self.CONCENTRATION_REQUIREMENTS.get(track, self.CONCENTRATION_REQUIREMENTS["ai_ml"])
264
-
265
- courses_per_semester = self._calculate_course_load(student.time_commitment)
266
-
267
- # Build required and pick sets
268
- required_set = set()
269
- picklist_set = set()
270
- for category, reqs in requirements.items():
271
- if "required" in reqs:
272
- required_set.update(reqs["required"])
273
- for key, courses in reqs.items():
274
- if key.startswith("pick_"):
275
- picklist_set.update(courses)
276
-
277
- semantic_scores = self._compute_semantic_scores(student)
278
-
279
- # HARDCODED FIX: Force Year 2 to prioritize core courses
280
- YEAR2_MUST_TAKE = ["CS3000", "CS3500", "DS2500", "MATH2331", "MATH3081"]
281
-
282
- for sem_num in range(1, 9):
283
- year = ((sem_num - 1) // 2) + 1
284
-
285
- available_courses = self._get_available_courses(completed, year, sem_num, track)
286
-
287
- # Filter: must be takeable
288
- schedulable = [
289
- c for c in available_courses
290
- if c not in completed and self._can_take_course(c, completed)
291
- ]
292
-
293
- # HARDCODED: In Year 2, force core courses to the top
294
- if year == 2:
295
- priority_courses = [c for c in YEAR2_MUST_TAKE if c in schedulable]
296
- other_courses = [c for c in schedulable if c not in YEAR2_MUST_TAKE]
297
-
298
- # Score priority courses separately
299
- scored_priority = sorted(
300
- priority_courses,
301
- key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set),
302
- reverse=True
303
- )
304
- scored_others = sorted(
305
- other_courses,
306
- key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set),
307
- reverse=True
308
- )
309
-
310
- scored_courses = scored_priority + scored_others
311
- else:
312
- # Normal scoring for other years
313
- scored_courses = sorted(
314
- schedulable,
315
- key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set),
316
- reverse=True
317
- )
318
-
319
- # Select top N courses
320
- selected = []
321
- for course in scored_courses:
322
- if len(selected) >= courses_per_semester:
323
- break
324
- if self._validate_sequence(selected, course):
325
- selected.append(course)
326
-
327
- # Add to plan
328
- if selected:
329
- year_key = f"year_{year}"
330
- if year_key not in plan:
331
- plan[year_key] = {}
332
-
333
- sem_type = 'fall' if (sem_num % 2) == 1 else 'spring'
334
- plan[year_key][sem_type] = selected
335
- completed.update(selected)
336
-
337
- return plan
338
-
339
- def _get_available_courses(self, completed: Set[str], year: int, sem_num: int = None, track: str = "ai_ml") -> List[str]:
340
- """FIXED: Return ALL courses that COULD be taken in this year"""
341
-
342
- # Year 1: Hardcoded foundation
343
- if year == 1:
344
- if not completed or len(completed) < 2:
345
- return [c for c in ["CS1800", "CS2500", "MATH1341", "ENGW1111"] if c in self.valid_courses]
346
- else:
347
- next_courses = []
348
- for course, prereq in [("CS2800", "CS1800"), ("CS2510", "CS2500"), ("MATH1342", "MATH1341"), ("DS2000", None)]:
349
- if course in self.valid_courses and course not in completed:
350
- if prereq is None or prereq in completed:
351
- next_courses.append(course)
352
- return next_courses
353
-
354
- # Years 2-4: Filter by subject and level
355
- available = []
356
-
357
- # ONLY CS/DS/CY/MATH allowed
358
- ALLOWED_SUBJECTS = {"CS", "DS", "CY", "MATH"}
359
-
360
- for cid in self.valid_courses:
361
- if cid in completed:
362
- continue
363
-
364
- course_data = self.courses.get(cid, {})
365
- subject = course_data.get('subject')
366
-
367
- if subject not in ALLOWED_SUBJECTS:
368
- continue
369
-
370
- course_level = self._get_level(cid)
371
-
372
- # Year-based level filtering
373
- if year == 2 and course_level > 3999:
374
- continue # No 4000+ in Year 2
375
- if year >= 3 and course_level < 2000:
376
- continue # No intro courses in Years 3-4
377
-
378
- available.append(cid)
379
-
380
- return available
381
-
382
- def _fix_plan_errors(self, plan: Dict, validation: Dict, student: StudentProfile) -> Dict:
383
- if any("Mixed" in error for error in validation["errors"]):
384
- return self._build_structured_plan(student, self._identify_track(student), None)
385
- return plan
386
-
387
- def _get_llm_course_suggestions(self, student: StudentProfile, track: str) -> List[str]:
388
- requirements = self.CONCENTRATION_REQUIREMENTS.get(track, {})
389
- all_options = set()
390
- for reqs in requirements.values():
391
- for key, courses in reqs.items():
392
- if key.startswith("pick_"): all_options.update(courses)
393
-
394
- course_options_text = [f"{cid}: {self.courses[cid].get('name', cid)} - {self.courses[cid].get('description', '')[:100].strip()}"
395
- for cid in list(all_options)[:15] if cid in self.courses]
396
-
397
- prompt = f"""You are an expert curriculum advisor. Based on the student profile, rank the top 5 most relevant courses from the list below.
398
- ### Student Profile:
399
- - **Career Goal:** {student.career_goals}
400
- - **Interests:** {', '.join(student.interests)}
401
- - **Preferred Difficulty:** {student.preferred_difficulty}
402
- ### Available Elective Courses:
403
- {chr(10).join(course_options_text)}
404
- Return ONLY the top 5 course IDs, each on a new line.
405
- """
406
- try:
407
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(self.device)
408
- with torch.no_grad():
409
- outputs = self.llm.generate(**inputs, max_new_tokens=100, temperature=0.2, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
410
- response = self.tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
411
- suggested_courses = re.findall(r'([A-Z]{2,4}\d{4})', response)
412
- return suggested_courses[:5]
413
- except Exception as e:
414
- print(f"LLM suggestion failed: {e}")
415
- return list(all_options)[:5]
416
-
417
- def _map_difficulty(self, preferred_difficulty: str) -> str:
418
- return {"easy": "easy", "moderate": "medium", "challenging": "hard"}.get(preferred_difficulty.lower(), "medium")
419
-
420
- def _calculate_course_load(self, time_commitment: int) -> int:
421
- if time_commitment <= 20: return 3
422
- if time_commitment <= 40: return 4 # Setting hours to 40 will now correctly return 4.
423
- return 5
424
-
425
- def _identify_track(self, student: StudentProfile) -> str:
426
- if not hasattr(self, 'embedding_model') or self.embedding_model is None:
427
- combined = f"{student.career_goals.lower()} {' '.join(student.interests).lower()}"
428
- if any(word in combined for word in ['ai', 'ml', 'machine learning', 'data']): return "ai_ml"
429
- if any(word in combined for word in ['systems', 'distributed', 'backend']): return "systems"
430
- if any(word in combined for word in ['security', 'cyber']): return "security"
431
- return "ai_ml"
432
- profile_text = f"{student.career_goals} {' '.join(student.interests)}"
433
- profile_emb = self.embedding_model.encode(profile_text, convert_to_tensor=True)
434
- track_descriptions = {
435
- "ai_ml": "artificial intelligence machine learning deep learning neural networks data science",
436
- "systems": "operating systems distributed systems networks compilers databases performance backend",
437
- "security": "cybersecurity cryptography network security ethical hacking vulnerabilities"
438
- }
439
- best_track, best_score = "ai_ml", -1.0
440
- for track, description in track_descriptions.items():
441
- track_emb = self.embedding_model.encode(description, convert_to_tensor=True)
442
- score = float(util.cos_sim(profile_emb, track_emb))
443
- if score > best_score:
444
- best_score, best_track = score, track
445
- return best_track
446
-
447
- def _compute_semantic_scores(self, student: StudentProfile) -> Dict[str, float]:
448
- query_text = f"{student.career_goals} {' '.join(student.interests)}"
449
- query_emb = self.embedding_model.encode(query_text, convert_to_tensor=True)
450
- similarities = util.cos_sim(query_emb, self.course_embeddings)[0]
451
- return {cid: float(similarities[idx]) for idx, cid in enumerate(self.valid_courses)}
452
-
453
- def _generate_explanation(self, student: StudentProfile, plan: Dict, track: str, plan_type: str) -> str:
454
- return f"{plan_type.title()} plan for the {track} track, tailored to your goal of becoming a {student.career_goals}."
455
-
456
- def validate_plan(self, plan: Dict, student: StudentProfile = None) -> Dict[str, List[str]]:
457
- issues = {"errors": [], "warnings": [], "info": []}
458
- all_courses = [course for year in plan.values() for sem in year.values() for course in sem if isinstance(sem, list)]
459
-
460
- for track_type, tracks in self.COURSE_TRACKS.items():
461
- tracks_used = {name for name, courses in tracks.items() if any(c in all_courses for c in courses)}
462
- if len(tracks_used) > 1:
463
- issues["errors"].append(f"Mixed {track_type} tracks: {', '.join(tracks_used)}. Choose one sequence.")
464
-
465
- completed_for_validation = set(student.completed_courses) if student else set()
466
- for year in range(1, 5):
467
- for sem in ["fall", "spring"]:
468
- year_key = f"year_{year}"
469
- sem_courses = plan.get(year_key, {}).get(sem, [])
470
- for course in sem_courses:
471
- if course in self.curriculum_graph:
472
- prereqs = set(self.curriculum_graph.predecessors(course))
473
- if not prereqs.issubset(self._get_completed_with_equivalents(completed_for_validation)):
474
- missing = prereqs - completed_for_validation
475
- issues["errors"].append(f"{course} in Year {year} {sem} is missing prereqs: {', '.join(missing)}")
476
- completed_for_validation.update(sem_courses)
477
- return issues
478
-
479
- def _finalize_plan(self, plan: Dict, explanation: str, validation: Dict = None) -> Dict:
480
- structured_plan = {"reasoning": explanation, "validation": validation or {"errors": [], "warnings": [], "info": []}}
481
- complexities = []
482
- for year in range(1, 5):
483
- year_key = f"year_{year}"
484
- structured_plan[year_key] = {
485
- "fall": plan.get(year_key, {}).get("fall", []),
486
- "spring": plan.get(year_key, {}).get("spring", []),
487
- "summer": "co-op" if year in [2, 3] else []
488
- }
489
- for sem in ["fall", "spring"]:
490
- courses = structured_plan[year_key][sem]
491
- if courses:
492
- sem_complexity = sum(self.courses.get(c, {}).get('complexity', 50) for c in courses)
493
- complexities.append(sem_complexity)
494
-
495
- structured_plan["complexity_analysis"] = {
496
- "average_semester_complexity": float(np.mean(complexities)) if complexities else 0,
497
- "peak_semester_complexity": float(np.max(complexities)) if complexities else 0,
498
- "total_complexity": float(np.sum(complexities)) if complexities else 0,
499
- "balance_score (std_dev)": float(np.std(complexities)) if complexities else 0
500
- }
501
- structured_plan["metadata"] = {
502
- "generated": datetime.now().isoformat(),
503
- "valid": len(validation.get("errors", [])) == 0 if validation else True,
504
- }
505
- return {"pathway": structured_plan}
506
-
507
- class CurriculumOptimizer(HybridOptimizer):
508
- """Wrapper to maintain compatibility with older script calls."""
509
- def __init__(self):
510
- super().__init__()
511
-
512
- def generate_plan(self, student: StudentProfile) -> Dict:
513
- return self.generate_enhanced_rule_plan(student)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Curriculum Optimizer - PRODUCTION VERSION
3
+ All redundant code removed, all critical issues fixed
4
+ """
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
7
+ from sentence_transformers import SentenceTransformer, util
8
+ import networkx as nx
9
+ import numpy as np
10
+ from typing import Dict, List, Set, Optional
11
+ from dataclasses import dataclass
12
+ import re
13
+ from datetime import datetime
14
+
15
@dataclass
class StudentProfile:
    """Per-student inputs that drive plan generation."""

    # Course IDs already finished; seeds the "completed" set of the planner.
    completed_courses: List[str]
    # Passed to _calculate_course_load to pick the courses-per-semester count.
    time_commitment: int
    # Looked up case-insensitively by _map_difficulty
    # ("easy" / "moderate" / "challenging").
    preferred_difficulty: str
    # Free text; combined with `interests` for track detection and
    # semantic course scoring.
    career_goals: str
    interests: List[str]
    current_gpa: float = 3.5
    learning_style: str = "Visual"
24
+
25
+ class HybridOptimizer:
26
+
27
+ EQUIVALENCY_GROUPS = [
28
+ {"MATH1341", "MATH1241", "MATH1231"},
29
+ {"MATH1342", "MATH1242"},
30
+ {"PHYS1151", "PHYS1161", "PHYS1145"},
31
+ {"PHYS1155", "PHYS1165", "PHYS1147"},
32
+ ]
33
+
34
+ COURSE_TRACKS = {
35
+ "physics": {
36
+ "engineering": ["PHYS1151", "PHYS1155"],
37
+ "science": ["PHYS1161", "PHYS1165"],
38
+ "life_sciences": ["PHYS1145", "PHYS1147"]
39
+ },
40
+ "calculus": {
41
+ "standard": ["MATH1341", "MATH1342"],
42
+ "computational": ["MATH156", "MATH256"]
43
+ }
44
+ }
45
+
46
+ CONCENTRATION_REQUIREMENTS = {
47
+ "ai_ml": {
48
+ "foundations": {
49
+ "required": ["CS1800", "CS2500", "CS2510", "CS2800"],
50
+ "sequence": True
51
+ },
52
+ "core": {
53
+ "required": ["CS3000", "CS3500"],
54
+ "pick_1_from": ["CS3200", "CS3650", "CS5700"]
55
+ },
56
+ "concentration_specific": {
57
+ "required": ["CS4100", "DS4400"],
58
+ "pick_2_from": ["CS4120", "CS4180", "DS4420", "DS4440"],
59
+ "pick_1_systems": ["CS4730", "CS4700"]
60
+ },
61
+ "math": {
62
+ "required": ["MATH1341", "MATH1342"],
63
+ "pick_1_from": ["MATH2331", "MATH3081"]
64
+ }
65
+ },
66
+ "systems": {
67
+ "foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
68
+ "core": {"required": ["CS3000", "CS3500", "CS3650"], "pick_1_from": ["CS5700", "CS3200"]},
69
+ "concentration_specific": {"required": ["CS4700"], "pick_2_from": ["CS4730"], "pick_1_from": ["CS4400", "CS4500", "CS4520"]},
70
+ "math": {"required": ["MATH1341", "MATH1342"]}
71
+ },
72
+ "security": {
73
+ "foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
74
+ "core": {"required": ["CS3000", "CS3650", "CY2550"], "pick_1_from": ["CS5700", "CS3500"]},
75
+ "concentration_specific": {"required": ["CY3740"], "pick_2_from": ["CY4740", "CY4760", "CY4770"], "pick_1_from": ["CS4700", "CS4730"]},
76
+ "math": {"required": ["MATH1342"], "pick_1_from": ["MATH3527", "MATH3081"]}
77
+ }
78
+ }
79
+
80
def __init__(self):
    """Record configuration defaults; no model or data is loaded here."""
    gpu_present = torch.cuda.is_available()
    self.device = torch.device("cuda" if gpu_present else "cpu")

    # Identifiers consumed later by load_llm() and load_models().
    self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
    self.embedding_model_name = 'BAAI/bge-large-en-v1.5'

    # Heavy objects start empty; the load_* methods populate them.
    self.llm = None
    self.tokenizer = None
    self.embedding_model = None

    # Curriculum state, filled in by load_data().
    self.curriculum_graph = None
    self.courses = {}
    self.current_student = None
90
+
91
def load_models(self):
    """Instantiate the sentence-embedding model on the configured device."""
    print("Loading embedding model...")
    encoder = SentenceTransformer(self.embedding_model_name, device=self.device)
    self.embedding_model = encoder
94
+
95
def load_llm(self):
    """Load the 4-bit quantised causal LM and its tokenizer, once, on CUDA only.

    Does nothing when the device is not CUDA or when ``self.llm`` is already set.
    """
    if self.device.type != 'cuda' or self.llm is not None:
        return

    print("Loading LLM for intelligent planning...")
    four_bit = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(self.model_name)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    self.tokenizer = tokenizer

    self.llm = AutoModelForCausalLM.from_pretrained(
        self.model_name,
        quantization_config=four_bit,
        device_map="auto",
    )
110
+
111
def load_data(self, graph: nx.DiGraph):
    """Attach the curriculum graph and pre-compute course embeddings.

    Populates ``self.curriculum_graph``, ``self.courses`` (node id -> attribute
    dict), ``self.valid_courses`` and ``self.course_embeddings``.

    A course is kept when it has a non-empty name, is not a lab / recitation /
    seminar / practicum section, and is below the 5000 level unless it is one
    of the graduate courses explicitly whitelisted here.

    Args:
        graph: Prerequisite graph whose nodes carry ``name`` and optionally
            ``description`` attributes.
    """
    self.curriculum_graph = graph
    self.courses = dict(graph.nodes(data=True))
    UNDERGRAD_ACCESSIBLE_GRAD = {"CS5700", "CY5700", "DS5110", "CS5010"}

    # Match whole words only: a plain substring test for "lab" also rejects
    # names such as "Scalable ..." or "Collaborative ...".
    auxiliary_section = re.compile(
        r'\b(?:labs?|laborator(?:y|ies)|recitations?|seminars?|practicums?)\b'
    )

    # Every course id mentioned by any concentration, used for the
    # missing-course warning below.
    concentration_courses = set()
    for track_reqs in self.CONCENTRATION_REQUIREMENTS.values():
        for reqs in track_reqs.values():
            if not isinstance(reqs, dict):
                continue
            for courses in reqs.values():
                if isinstance(courses, list):
                    concentration_courses.update(courses)

    self.valid_courses = []
    course_texts = []
    for cid, data in self.courses.items():
        name = data.get('name', '')
        if not name or not name.strip():
            continue
        if auxiliary_section.search(name.lower()):
            continue
        if self._get_level(cid) >= 5000 and cid not in UNDERGRAD_ACCESSIBLE_GRAD:
            continue

        self.valid_courses.append(cid)
        course_texts.append(f"{name} {data.get('description', '')}")

    missing_required = concentration_courses - set(self.valid_courses)
    if missing_required:
        print(f"\n⚠️ WARNING: {len(missing_required)} required courses missing from graph: {sorted(missing_required)}\n")

    print(f"Computing embeddings for {len(self.valid_courses)} courses...")
    self.course_embeddings = self.embedding_model.encode(course_texts, convert_to_tensor=True, show_progress_bar=True)
    print(f"\nTotal valid courses: {len(self.valid_courses)}")
145
+
146
def _get_level(self, course_id: str) -> int:
    """Return the first run of digits in *course_id* as an int (9999 if none)."""
    digits = re.search(r'\d+', course_id)
    if digits is None:
        return 9999
    return int(digits.group())
149
+
150
def _get_completed_with_equivalents(self, completed: Set[str]) -> Set[str]:
    """Return a copy of *completed* widened by every equivalency group it touches.

    The input set is left unmodified.
    """
    widened = completed.copy()
    for group in self.EQUIVALENCY_GROUPS:
        # One shared member is enough to credit the whole group.
        if any(course in group for course in completed):
            widened.update(group)
    return widened
157
+
158
def _can_take_course(self, course_id: str, completed: Set[str]) -> bool:
    """Tell whether all graph predecessors of *course_id* are already satisfied.

    Equivalent courses count as completed. A course that is not a node of the
    curriculum graph is always allowed.
    """
    satisfied = self._get_completed_with_equivalents(completed)
    if course_id in self.curriculum_graph:
        required = set(self.curriculum_graph.predecessors(course_id))
        return required.issubset(satisfied)
    return True
164
+
165
def _validate_sequence(self, selected: List[str], candidate: str) -> bool:
    """Reject *candidate* if *selected* already contains a rival sequence.

    Within each family in ``COURSE_TRACKS`` the sequences are mutually
    exclusive: a candidate from one sequence is refused as soon as any course
    of a different sequence of the same family is in *selected*.
    """
    for tracks in self.COURSE_TRACKS.values():
        for track_name, sequence in tracks.items():
            if candidate not in sequence:
                continue
            rival_taken = any(
                course in selected
                for other_name, other_seq in tracks.items()
                if other_name != track_name
                for course in other_seq
            )
            if rival_taken:
                return False
    return True
173
+
174
def _score_course(self, course_id: str, semantic_scores: Dict[str, float], required_set: Set[str], picklist_set: Set[str], year: int, track: str) -> float:
    """Compute the ranking score of one course for a given year and track.

    Higher is better. Courses that are unknown or have a blank name score
    -10000.0 outright. Otherwise the score is the sum, in this order, of:
    weighted semantic similarity, a keyword penalty, a subject adjustment, a
    track-specific boost, required / pick-list bonuses, a small bonus for the
    number of successors in the graph, a mild preference for lower course
    numbers, and a penalty for low-level courses in years 3 and 4.
    """
    info = self.courses.get(course_id)
    if course_id not in self.courses or not info.get('name', '').strip():
        return -10000.0

    subject = info.get('subject', '')
    level = self._get_level(course_id)
    title = info.get('name', '').lower()

    # Semantic similarity counts three times as much in the final year.
    weight = 15.0 if year == 4 else 5.0
    total = 0.0
    total += semantic_scores.get(course_id, 0.0) * weight

    # Keyword penalty; the game_dev track keeps its game/mobile courses.
    flagged_terms = ['society', 'ethics', 'law', 'policy', 'mobile', 'game', 'visualiz', 'web']
    flagged = any(term in title for term in flagged_terms)
    exempt = track == "game_dev" and any(term in title for term in ['game', 'mobile'])
    if flagged and not exempt:
        total -= 10000.0

    # Subject adjustment.
    if subject in ("CS", "DS"):
        total += 300.0
    elif subject == "CY":
        if level >= 3000:
            total += 300.0
        else:
            total -= 500.0
    elif subject == "MATH":
        total += 100.0
    else:
        total -= 1000.0

    # Track-specific boosts; "general" and unknown tracks have no entry.
    track_boosts = {
        "ai_ml": (("DS2500", "DS3000", "DS3500"), 7000.0),
        "security": (("CY2550", "CY3740"), 7000.0),
        "systems": (("CS3650",), 7000.0),
        "game_dev": (("CS3540",), 8000.0),
    }
    boosted_ids, bonus = track_boosts.get(track, ((), 0.0))
    if course_id in boosted_ids:
        total += bonus

    if course_id in required_set:
        total += 10000.0
    if course_id in picklist_set:
        total += 5000.0

    # Bonus per course unlocked, capped at five successors.
    if course_id in self.curriculum_graph:
        total += min(self.curriculum_graph.out_degree(course_id), 5) * 2.0

    # Lower course numbers win ties.
    total -= (level / 100.0)

    # Upper-year semesters should not be filled with lower-level courses.
    if year == 4 and level < 4000:
        total -= 3000.0
    elif year == 3 and level < 3000:
        total -= 2000.0

    return total
253
+
254
def generate_simple_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Build a plan with the rule-based planner (no LLM involved).

    Remembers *student* as ``self.current_student`` and delegates to
    ``generate_enhanced_rule_plan``.
    """
    print("--- Generating Enhanced Rule-Based Plan ---")
    self.current_student = student
    result = self.generate_enhanced_rule_plan(student, track_override)
    return result
258
+
259
def generate_enhanced_rule_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Produce, validate and package a rule-based four-year plan.

    Args:
        student: Profile to plan for; also stored as ``self.current_student``.
        track_override: Track chosen by the user. Any truthy value, including
            "general", is used as-is; otherwise the track is auto-identified.

    Returns:
        The dictionary produced by ``_finalize_plan``.
    """
    self.current_student = student

    if track_override:
        track = track_override
        origin = "Using user-selected track"
    else:
        track = self._identify_track(student)
        origin = "Auto-identified track"
    print(f"--- {origin}: {track} ---")
    if not track_override and not track:
        track = "general"

    plan = self._build_structured_plan(student, track, None)
    validation = self.validate_plan(plan, student)

    # One repair attempt, then re-validate whatever came back.
    if validation["errors"]:
        plan = self._fix_plan_errors(plan, validation, student)
        validation = self.validate_plan(plan, student)

    difficulty_level = self._map_difficulty(student.preferred_difficulty)
    courses_per_semester = self._calculate_course_load(student.time_commitment)
    track_name = track.replace("_", " ").title()
    explanation = f"Personalized {track_name} track ({difficulty_level} difficulty, {courses_per_semester} courses/semester)"

    return self._finalize_plan(plan, explanation, validation)
286
+
287
def generate_llm_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
    """Produce a plan that consults the LLM for elective ranking.

    Falls back to ``generate_enhanced_rule_plan`` when no LLM could be loaded.

    Args:
        student: Profile to plan for; also stored as ``self.current_student``.
        track_override: Track chosen by the user. Any truthy value, including
            "general", is honoured, matching the rule-based planner; otherwise
            the track is auto-identified.

    Returns:
        The dictionary produced by ``_finalize_plan``.
    """
    print("--- Generating AI-Optimized Plan ---")
    self.current_student = student
    self.load_llm()
    if not self.llm:
        return self.generate_enhanced_rule_plan(student, track_override)

    # Honour every explicit choice; "general" used to be silently replaced
    # by auto-detection here.
    if track_override:
        track = track_override
        print(f"--- Using user-selected track: {track} ---")
    else:
        track = self._identify_track(student)
        print(f"--- Auto-identified track: {track} ---")
        if not track:
            track = "general"

    # Tracks without a CONCENTRATION_REQUIREMENTS entry have no pick lists,
    # so there is nothing for the LLM to rank; skip the generation call.
    if track in self.CONCENTRATION_REQUIREMENTS:
        llm_suggestions = self._get_llm_course_suggestions(student, track)
    else:
        llm_suggestions = None

    plan = self._build_structured_plan(student, track, llm_suggestions)
    validation = self.validate_plan(plan, student)
    if validation["errors"]:
        plan = self._fix_plan_errors(plan, validation, student)
        validation = self.validate_plan(plan, student)

    track_name = track.replace("_", " ").title()
    explanation = self._generate_explanation(student, plan, track, f"AI-optimized {track_name}")
    return self._finalize_plan(plan, explanation, validation)
314
+
315
+
316
+
317
def _build_structured_plan(self, student: StudentProfile, track: str, llm_suggestions: Optional[List[str]] = None) -> Dict:
    """Greedily fill eight semesters with the best-scoring schedulable courses.

    For every semester the candidates come from ``_get_available_courses``,
    are filtered by prerequisites, ranked (track priority list first, then
    everything else, each by ``_score_course``), and the top N are taken,
    where N comes from ``_calculate_course_load``.

    Args:
        student: Supplies completed courses and time commitment.
        track: "general", "game_dev" or a key of ``CONCENTRATION_REQUIREMENTS``;
            unknown tracks use the ai_ml requirements and the general
            priority lists.
        llm_suggestions: Accepted for interface compatibility.
            NOTE(review): not consulted by the ranking below.

    Returns:
        ``{"year_N": {"fall": [...], "spring": [...]}}`` with entries only for
        semesters that received at least one course.
    """
    completed = set(student.completed_courses)
    plan = {}

    if track == "general":
        print("--- Using General CS requirements ---")
        requirements = {
            "foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
            "core": {"required": ["CS3000", "CS3500", "CS3650"]},
            "math": {"required": ["MATH1341", "MATH1342"], "pick_1_from": ["MATH2331", "MATH3081"]}
        }
    elif track == "game_dev":
        print("--- Using Game Dev (AI/ML base) requirements ---")
        # ai_ml is the base; scoring and priority lists add the game focus.
        requirements = self.CONCENTRATION_REQUIREMENTS["ai_ml"]
    else:
        requirements = self.CONCENTRATION_REQUIREMENTS.get(track, self.CONCENTRATION_REQUIREMENTS["ai_ml"])

    courses_per_semester = self._calculate_course_load(student.time_commitment)

    required_set = set()
    picklist_set = set()
    for reqs in requirements.values():
        required_set.update(reqs.get("required", []))
        for key, courses in reqs.items():
            if key.startswith("pick_"):
                picklist_set.update(courses)

    semantic_scores = self._compute_semantic_scores(student)

    # Hand-ordered per-track courses that go ahead of everything else in
    # years 2-4; year 1 has no entry and is ranked purely by score.
    TRACK_YEAR_PRIORITIES = {
        "general": {
            2: ["CS3000", "CS3500", "CS3650", "MATH2331", "MATH3081", "CS3200"],
            3: ["CS4700", "CS4400", "CS4500", "CS4100"],
            4: ["CS5700", "CS4730", "CS4530", "CS4550", "CS4410"]
        },
        "ai_ml": {
            2: ["CS3000", "CS3500", "DS2500", "DS3000", "DS3500", "MATH2331", "MATH3081", "CS3650"],
            3: ["CS4100", "DS4400", "CS4120", "DS4420", "DS4440", "CS4180"],
            4: ["CS4730", "CS4700", "CS5700", "DS4300", "CS4400", "CS4500"]
        },
        "security": {
            2: ["CS3000", "CS3650", "CY2550", "MATH2331", "MATH3081", "CS3500"],
            3: ["CY3740", "CS4700", "CS5700", "CS4730"],
            4: ["CY4740", "CY4760", "CS4400"]  # CY4770 is missing from graph
        },
        "systems": {
            2: ["CS3000", "CS3500", "CS3650", "MATH2331", "CS3200"],
            3: ["CS4700", "CS5700", "CS4730", "CS4500", "CS4400"],
            4: ["CS4520", "CS4410"]
        },
        "game_dev": {
            2: ["CS3000", "CS3500", "CS3540", "MATH2331", "MATH3081", "CS3650"],
            3: ["CS4520", "CS4300", "CS4100", "CS4700"],
            4: ["CS4550", "CS4410", "CS4180"]
        }
    }
    track_priorities = TRACK_YEAR_PRIORITIES.get(track, TRACK_YEAR_PRIORITIES["general"])

    def ranked(courses, year):
        # Best score first, for the given planning year.
        return sorted(
            courses,
            key=lambda c: self._score_course(c, semantic_scores, required_set, picklist_set, year, track),
            reverse=True,
        )

    for sem_num in range(1, 9):
        year = ((sem_num - 1) // 2) + 1

        available_courses = self._get_available_courses(completed, year, sem_num, track)
        schedulable = [
            c for c in available_courses
            if c not in completed and self._can_take_course(c, completed)
        ]

        year_priorities = track_priorities.get(year)
        if year_priorities:
            priority_courses = [c for c in year_priorities if c in schedulable]
            other_courses = [c for c in schedulable if c not in year_priorities]
            scored_courses = ranked(priority_courses, year) + ranked(other_courses, year)
        else:
            scored_courses = ranked(schedulable, year)

        # Sequence exclusivity must hold across the whole plan, not only
        # within this semester, so check against everything taken so far.
        taken = set(completed)
        selected = []
        for course in scored_courses:
            if len(selected) >= courses_per_semester:
                break
            if self._validate_sequence(taken, course):
                selected.append(course)
                taken.add(course)

        if selected:
            sem_type = 'fall' if (sem_num % 2) == 1 else 'spring'
            plan.setdefault(f"year_{year}", {})[sem_type] = selected
            completed.update(selected)

    return plan
439
+
440
def _get_available_courses(self, completed: Set[str], year: int, sem_num: int = None, track: str = "ai_ml") -> List[str]:
    """List the candidate courses for one semester of the given year.

    Year 1 uses fixed lists: the four starter courses while fewer than two
    courses are completed, afterwards a fixed set of follow-ups whose single
    listed prerequisite is done. Years 2-4 return every valid, not yet
    completed CS/DS/CY/MATH course inside the year's level window
    (year 2: 2000-3999, year 3: 3000 and up, year 4: 4000 and up).

    ``sem_num`` and ``track`` are accepted but not used by this filter.
    """
    if year == 1:
        if len(completed) < 2:
            starters = ("CS1800", "CS2500", "MATH1341", "ENGW1111")
            return [c for c in starters if c in self.valid_courses]

        # (course, prerequisite that must already be completed or None)
        follow_ups = (
            ("CS2800", "CS1800"),
            ("CS2510", "CS2500"),
            ("MATH1342", "MATH1341"),
            ("DS2000", None),
            ("DS2500", "DS2000"),
        )
        return [
            course
            for course, prereq in follow_ups
            if course in self.valid_courses
            and course not in completed
            and (prereq is None or prereq in completed)
        ]

    allowed_subjects = {"CS", "DS", "CY", "MATH"}
    # Minimum course level per year; other years have no floor.
    level_floor = {2: 2000, 3: 3000, 4: 4000}.get(year)

    candidates = []
    for cid in self.valid_courses:
        if cid in completed:
            continue
        if self.courses.get(cid, {}).get('subject') not in allowed_subjects:
            continue

        level = self._get_level(cid)
        if level_floor is not None and level < level_floor:
            continue
        # Only year 2 also has a ceiling.
        if year == 2 and level > 3999:
            continue

        candidates.append(cid)

    return candidates
494
+
495
def _fix_plan_errors(self, plan: Dict, validation: Dict, student: StudentProfile, track: Optional[str] = None) -> Dict:
    """Rebuild the plan when validation reported mixed course sequences.

    Only errors containing "Mixed" trigger a rebuild; any other error leaves
    *plan* untouched.

    Args:
        plan: The plan that failed validation.
        validation: Result of ``validate_plan``; its "errors" list is read.
        student: Profile used for the rebuild.
        track: Track to rebuild for. When omitted the track is auto-identified,
            which drops a user-selected track; callers that know the track
            should pass it.

    Returns:
        The rebuilt plan, or *plan* itself when no rebuild was needed.
    """
    if not any("Mixed" in error for error in validation["errors"]):
        return plan
    rebuild_track = track or self._identify_track(student)
    return self._build_structured_plan(student, rebuild_track, None)
499
+
500
def _get_llm_course_suggestions(self, student: StudentProfile, track: str) -> List[str]:
    """Ask the LLM to rank the track's pick-list courses for this student.

    Args:
        student: Supplies career goal, interests and preferred difficulty for
            the prompt.
        track: Key of ``CONCENTRATION_REQUIREMENTS``; its ``pick_*`` lists are
            the options offered to the model.

    Returns:
        Up to five distinct course IDs taken from the offered options, in the
        order the model mentioned them. If the track has no options the result
        is empty; if generation fails or yields no offered ID, the first five
        options in sorted order are returned.
    """
    requirements = self.CONCENTRATION_REQUIREMENTS.get(track, {})
    all_options = set()
    for reqs in requirements.values():
        for key, courses in reqs.items():
            if key.startswith("pick_"):
                all_options.update(courses)

    # Sort once so the prompt and the fallback are deterministic; slicing a
    # set gives a different selection from run to run.
    ordered_options = sorted(all_options)
    if not ordered_options:
        return []

    course_options_text = [
        f"{cid}: {self.courses[cid].get('name', cid)} - {(self.courses[cid].get('description') or '')[:100].strip()}"
        for cid in ordered_options[:15] if cid in self.courses
    ]

    prompt = "\n".join([
        "Expert curriculum advisor ranking courses for student.",
        "",
        "Student Profile:",
        f"- Career Goal: {student.career_goals}",
        f"- Interests: {', '.join(student.interests)}",
        f"- Difficulty: {student.preferred_difficulty}",
        "",
        "Available Courses:",
        "\n".join(course_options_text),
        "",
        "Return ONLY top 5 course IDs, one per line.",
    ])

    try:
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(self.device)
        with torch.no_grad():
            outputs = self.llm.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.2,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(outputs[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
        mentioned = re.findall(r'([A-Z]{2,4}\d{4})', response)
        # Keep first mentions only, and only IDs that were actually offered.
        ranked = [cid for cid in dict.fromkeys(mentioned) if cid in all_options]
        if ranked:
            return ranked[:5]
    except Exception as e:
        # Best effort: any tokenizer/model failure degrades to the fallback.
        print(f"LLM suggestion failed: {e}")

    return ordered_options[:5]
541
+
542
def _map_difficulty(self, preferred_difficulty: str) -> str:
    """Translate the UI difficulty label to easy / medium / hard.

    Matching is case-insensitive; unrecognised labels map to "medium".
    """
    labels = {"easy": "easy", "moderate": "medium", "challenging": "hard"}
    key = preferred_difficulty.lower()
    if key in labels:
        return labels[key]
    return "medium"
544
+
545
def _calculate_course_load(self, time_commitment: int) -> int:
    """Return courses per semester: 3 up to 20, 4 up to 40, otherwise 5."""
    for ceiling, load in ((20, 3), (40, 4)):
        if time_commitment <= ceiling:
            return load
    return 5
551
+
552
def _identify_track(self, student: StudentProfile) -> str:
    """Guess the concentration track from career goals and interests.

    Without an embedding model, a keyword heuristic is used; otherwise the
    track whose description is most cosine-similar to the profile text wins.

    Returns:
        "ai_ml", "systems" or "security"; "ai_ml" when nothing matches.
    """
    if getattr(self, 'embedding_model', None) is None:
        combined = f"{student.career_goals.lower()} {' '.join(student.interests).lower()}"

        # "ai" and "ml" must be whole words: as substrings they also occur
        # in "maintain", "email", "html", ... and would hijack the result.
        if (re.search(r'\b(?:ai|ml)\b', combined)
                or any(phrase in combined for phrase in ['machine learning', 'data'])):
            return "ai_ml"
        if any(word in combined for word in ['systems', 'distributed', 'backend']):
            return "systems"
        if any(word in combined for word in ['security', 'cyber']):
            return "security"
        return "ai_ml"

    profile_text = f"{student.career_goals} {' '.join(student.interests)}"
    profile_emb = self.embedding_model.encode(profile_text, convert_to_tensor=True)

    track_descriptions = {
        "ai_ml": "artificial intelligence machine learning deep learning neural networks data science",
        "systems": "operating systems distributed systems networks compilers databases performance backend",
        "security": "cybersecurity cryptography network security ethical hacking vulnerabilities"
    }

    best_track, best_score = "ai_ml", -1.0
    for track, description in track_descriptions.items():
        track_emb = self.embedding_model.encode(description, convert_to_tensor=True)
        score = float(util.cos_sim(profile_emb, track_emb))
        if score > best_score:
            best_score, best_track = score, track

    return best_track
580
+
581
def _compute_semantic_scores(self, student: StudentProfile) -> Dict[str, float]:
    """Map each valid course to its cosine similarity with the student profile.

    The query is the career goal followed by the space-joined interests;
    it is compared against ``self.course_embeddings``, whose rows are read in
    the same order as ``self.valid_courses``.
    """
    interests = ' '.join(student.interests)
    query_emb = self.embedding_model.encode(f"{student.career_goals} {interests}", convert_to_tensor=True)
    row = util.cos_sim(query_emb, self.course_embeddings)[0]

    scores = {}
    for position, cid in enumerate(self.valid_courses):
        scores[cid] = float(row[position])
    return scores
586
+
587
def _generate_explanation(self, student: StudentProfile, plan: Dict, track: str, plan_type: str) -> str:
    """Compose the one-sentence rationale shown with a plan.

    ``plan`` is accepted but not used in the sentence; ``plan_type`` is
    title-cased before insertion.
    """
    message = f"{plan_type.title()} plan for the {track} track, tailored to your goal of becoming a {student.career_goals}."
    return message
589
+
590
def validate_plan(self, plan: Dict, student: StudentProfile = None) -> Dict[str, List[str]]:
    """Check a plan for mixed course sequences and unmet prerequisites.

    Args:
        plan: ``{"year_N": {"fall": [...], "spring": [...]}}``. Entries whose
            value is not a dict, and semester values that are not lists, are
            ignored by the mixed-sequence check.
        student: Optional profile; its completed courses count as done before
            year 1.

    Returns:
        ``{"errors": [...], "warnings": [...], "info": [...]}``. Only "errors"
        is ever populated here.
    """
    issues = {"errors": [], "warnings": [], "info": []}

    all_courses = {
        course
        for year_plan in plan.values() if isinstance(year_plan, dict)
        for sem_courses in year_plan.values() if isinstance(sem_courses, list)
        for course in sem_courses
    }

    # A plan may use at most one sequence per family (physics, calculus, ...).
    for track_type, tracks in self.COURSE_TRACKS.items():
        tracks_used = sorted(
            name for name, courses in tracks.items()
            if any(c in all_courses for c in courses)
        )
        if len(tracks_used) > 1:
            issues["errors"].append(f"Mixed {track_type} tracks: {', '.join(tracks_used)}. Choose one sequence.")

    # Walk the semesters in order; a course only counts as completed for the
    # semesters after the one it is scheduled in.
    completed_for_validation = set(student.completed_courses) if student else set()
    for year in range(1, 5):
        year_plan = plan.get(f"year_{year}") or {}
        for sem in ["fall", "spring"]:
            sem_courses = year_plan.get(sem, [])
            satisfied = self._get_completed_with_equivalents(completed_for_validation)
            for course in sem_courses:
                if course not in self.curriculum_graph:
                    continue
                prereqs = set(self.curriculum_graph.predecessors(course))
                # Subtract the equivalent-expanded set so prerequisites met
                # through an equivalent course are not reported as missing.
                missing = sorted(prereqs - satisfied)
                if missing:
                    issues["errors"].append(f"{course} in Year {year} {sem} is missing prereqs: {', '.join(missing)}")
            completed_for_validation.update(sem_courses)

    return issues
615
+
616
def _finalize_plan(self, plan: Dict, explanation: str, validation: Dict = None) -> Dict:
    """Wrap a raw plan into the ``{"pathway": ...}`` structure with statistics.

    The pathway holds the explanation, the validation result (empty lists if
    none was given), four year entries with fall / spring / summer, a
    complexity analysis over the non-empty fall and spring semesters, and
    metadata with a timestamp and a validity flag.
    """
    pathway = {
        "reasoning": explanation,
        "validation": validation or {"errors": [], "warnings": [], "info": []}
    }

    semester_loads = []
    for year in range(1, 5):
        year_key = f"year_{year}"
        source = plan.get(year_key, {})
        entry = {
            "fall": source.get("fall", []),
            "spring": source.get("spring", []),
            # Summers after years 2 and 3 are co-op terms.
            "summer": "co-op" if year in [2, 3] else []
        }
        pathway[year_key] = entry

        for term in ("fall", "spring"):
            term_courses = entry[term]
            if not term_courses:
                continue
            # Courses without a 'complexity' attribute count as 50.
            semester_loads.append(
                sum(self.courses.get(c, {}).get('complexity', 50) for c in term_courses)
            )

    statistics = {
        "average_semester_complexity": np.mean,
        "peak_semester_complexity": np.max,
        "total_complexity": np.sum,
        "balance_score (std_dev)": np.std,
    }
    pathway["complexity_analysis"] = {
        label: float(reduce_fn(semester_loads)) if semester_loads else 0
        for label, reduce_fn in statistics.items()
    }

    pathway["metadata"] = {
        "generated": datetime.now().isoformat(),
        "valid": len(validation.get("errors", [])) == 0 if validation else True,
    }

    return {"pathway": pathway}
650
+
651
class CurriculumOptimizer(HybridOptimizer):
    """Backward-compatible facade exposing ``generate_plan``."""

    def __init__(self):
        super().__init__()

    def generate_plan(self, student: StudentProfile, track_override: Optional[str] = None) -> Dict:
        """Return the rule-based plan for *student*, honouring *track_override*."""
        plan = self.generate_enhanced_rule_plan(student, track_override)
        return plan
src/neu_graph_clean10.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ebd3c024667aedd28338a61e7e87bdd78c7db3dce94fb1a920e11eb7bdc985d
3
+ size 156590
src/ui.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import os
4
+ import time
5
+ import json
6
+ import yaml
7
+ from datetime import datetime
8
+ from typing import Dict, Set, Optional
9
+
10
+ # Import the optimizer and visualizer
11
+ from curriculum_optimizer import HybridOptimizer, StudentProfile
12
+ from interactive_visualizer import CurriculumVisualizer
13
+
14
+ # --- Page Configuration ---
15
+ st.set_page_config(page_title="Curriculum Optimizer", layout="wide", initial_sidebar_state="expanded")
16
+
17
+ # Initialize session state
18
# Seed each session key once; values already present survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "display_plan": None,
    "metrics": None,
    "reasoning": "",
    "graph_data_loaded": False,
    "last_profile": None,
    "visualizer": None,
    "selected_track": "general",  # Default to general
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
32
+
33
+ # Title
34
+ st.title("🧑‍🎓 Next-Gen Curriculum Optimizer")
35
+
36
+ # --- Caching and Initialization ---
37
@st.cache_resource
def get_optimizer():
    """Create the optimizer and load its embedding model, cached by Streamlit.

    On any failure an error and a hint are shown in the UI and the script run
    is stopped via ``st.stop()``.
    """
    try:
        instance = HybridOptimizer()
        instance.load_models()
    except Exception as e:
        st.error(f"Fatal error during model loading: {e}")
        st.info("Please ensure you have the required libraries installed.")
        st.stop()
        return None
    return instance
49
+
50
+ optimizer = get_optimizer()
51
+
52
+ # --- DYNAMIC HELPER FUNCTIONS ---
53
+
54
def check_requirements_satisfaction(plan: Dict, track: str) -> Dict:
    """Report, per requirement category, what the plan's courses satisfy.

    Only the fall and spring lists under ``year_*`` keys of *plan* are counted.
    Requirements come from a built-in table for "general", from the ai_ml
    entry for "game_dev", and otherwise from the optimizer's
    ``CONCENTRATION_REQUIREMENTS`` (empty for unknown tracks).

    Returns:
        ``{category: report}``. A report carries "required" / "completed" /
        "is_satisfied" when the category has required courses, plus one entry
        per ``pick_N_*`` key with its options, completed courses, an
        "x of N" count and a satisfaction flag. Empty if no optimizer exists.
    """
    if not optimizer:
        return {}

    scheduled = []
    for year_key, year_data in plan.items():
        if not year_key.startswith("year_"):
            continue
        for term in ("fall", "spring"):
            scheduled.extend(year_data.get(term, []))
    all_courses_set = set(scheduled)

    if track == "general":
        req_data = {
            "foundations": {"required": ["CS1800", "CS2500", "CS2510", "CS2800"]},
            "core": {"required": ["CS3000", "CS3500", "CS3650"]},
            "math": {"required": ["MATH1341", "MATH1342"], "pick_1_from": ["MATH2331", "MATH3081"]}
        }
    else:
        # game_dev reuses the ai_ml requirement table.
        lookup = "ai_ml" if track == "game_dev" else track
        req_data = optimizer.CONCENTRATION_REQUIREMENTS.get(lookup, {})

    satisfaction_report = {}
    for category, reqs in req_data.items():
        report = {}

        if "required" in reqs:
            req_list = reqs["required"]
            report["required"] = req_list
            report["completed"] = list(all_courses_set & set(req_list))
            report["is_satisfied"] = all_courses_set.issuperset(req_list)

        pick_rules = ((key, courses) for key, courses in reqs.items() if key.startswith("pick_"))
        for key, courses in pick_rules:
            # The number after "pick_" is how many options are needed;
            # anything unparsable counts as one.
            try:
                num_to_pick = int(key.split("_")[1])
            except Exception:
                num_to_pick = 1

            completed_in_pick = list(all_courses_set & set(courses))
            report[key] = {
                "options": courses,
                "completed": completed_in_pick,
                "count": f"{len(completed_in_pick)} of {num_to_pick}",
                "is_satisfied": len(completed_in_pick) >= num_to_pick
            }

        satisfaction_report[category] = report

    return satisfaction_report
108
+
109
def _build_semester_entry(year: int, term: str, course_ids, course_details: Dict) -> Dict:
    """Build one semester record and register each course in *course_details*."""
    semester_data = {"year": year, "term": term, "courses": []}
    for course_id in course_ids:
        course_info = optimizer.courses.get(course_id, {})
        in_graph = course_id in optimizer.curriculum_graph
        course_detail = {
            "id": course_id,
            "name": course_info.get("name", "Unknown"),
            "credits": course_info.get("maxCredits", 4),
            "complexity": course_info.get("complexity", 0),
            # Direct predecessors in the curriculum graph are the prerequisites.
            "prerequisites": list(optimizer.curriculum_graph.predecessors(course_id)) if in_graph else [],
        }
        semester_data["courses"].append(course_detail)
        course_details[course_id] = course_detail

    semester_data["semester_credits"] = sum(c["credits"] for c in semester_data["courses"])
    semester_data["semester_complexity"] = sum(c["complexity"] for c in semester_data["courses"])
    return semester_data


def export_plan_yaml(plan: Dict, profile: StudentProfile, validation: Dict = None, track: str = "general") -> str:
    """Export plan in structured YAML format for verification.

    Args:
        plan: Generated plan with ``"year_1"`` .. ``"year_4"`` entries, each
            holding optional ``"fall"`` / ``"spring"`` course-id lists.
        profile: Student profile whose fields are copied into the export.
        validation: Validation result with ``"errors"`` / ``"warnings"``.
            ``None`` (the default) is treated as "no errors, no warnings";
            previously it raised ``AttributeError``.
        track: Specialization key recorded in the metadata and used for the
            requirement-satisfaction report.

    Returns:
        The YAML document as a string, with keys in insertion order.
    """
    # Normalise first: the metadata below reads validation["errors"].
    validation = validation if validation else {"errors": [], "warnings": []}

    # Build structured plan data
    structured_plan = {
        "student_profile": {
            "name": getattr(profile, "name", "Student"),
            "gpa": profile.current_gpa,
            "career_goal": profile.career_goals,
            "interests": profile.interests,
            "completed_courses": profile.completed_courses,
            "time_commitment": profile.time_commitment,
            "preferred_difficulty": profile.preferred_difficulty,
        },
        "plan_metadata": {
            "generated": datetime.now().isoformat(),
            "track": track,
            "total_credits": 0,
            "validation_status": "valid" if not validation.get("errors") else "has_errors",
        },
        "validation": validation,
        "semesters": [],
        "course_details": {},
    }

    # Build semester list with full details
    total_credits = 0
    for year in range(1, 5):
        year_key = f"year_{year}"
        if year_key not in plan:
            continue

        for term in ("fall", "spring"):
            course_ids = plan[year_key].get(term, [])
            if not course_ids:
                continue  # empty terms are omitted from the export
            semester_data = _build_semester_entry(
                year, term, course_ids, structured_plan["course_details"]
            )
            total_credits += semester_data["semester_credits"]
            structured_plan["semesters"].append(semester_data)

        # Add summer/co-op
        # NOTE(review): nested under "year present in plan"; the original
        # nesting of this step should be confirmed.
        if year in [2, 3]:
            structured_plan["semesters"].append({
                "year": year, "term": "summer", "activity": "co-op", "courses": []
            })

    structured_plan["plan_metadata"]["total_credits"] = total_credits

    # Calculate requirement satisfaction for the selected track
    requirements_met = check_requirements_satisfaction(plan, track=track)
    structured_plan["requirements_satisfaction"] = requirements_met

    return yaml.dump(structured_plan, default_flow_style=False, sort_keys=False)
195
+
196
+
197
+ # --- UI TABS ---
198
tab1, tab2, tab3 = st.tabs(["📝 Plan Generator", "🗺️ Curriculum Map", "📊 Analytics"])

with tab1:
    # --- SIDEBAR FOR STUDENT PROFILE ---
    with st.sidebar:
        st.header("Student Profile")
        name = st.text_input("Name", "John, son of Jane")
        gpa = st.slider("GPA", 0.0, 4.0, 3.0, 0.1)
        career_goal = st.text_area("Career Goal", " ")
        interests = st.text_input("Interests (comma-separated)", " ")
        learning_style = st.selectbox("Learning Style", ["Visual", "Hands-on", "Auditory"])
        time_commit = st.number_input("Weekly Study Hours", 10, 60, 40, 5)
        difficulty = st.selectbox("Preferred Difficulty", ["easy", "moderate", "challenging"])
        completed_courses_input = st.text_area("Completed Courses (comma-separated)", " ")

        # Show how the chosen study hours and difficulty shape the plan
        st.markdown("---")
        st.markdown("**Profile Impact:**")
        if time_commit >= 40:
            load_note = "🔥 Intensive load (up to 5 courses/semester)"
        elif time_commit < 20:
            load_note = "🕒 Part-time load (3 courses/semester)"
        else:
            load_note = "📚 Standard load (4 courses/semester)"
        st.info(load_note)

        difficulty_notes = {
            "easy": "😌 Focuses on foundational courses",
            "challenging": "🚀 Includes advanced/specialized courses",
        }
        st.info(difficulty_notes.get(difficulty, "⚖️ Balanced difficulty progression"))

    # --- MAIN PAGE CONTENT ---

    # 1. LOAD DATA
    st.subheader("1. Load Curriculum Data")
    uploaded_file = st.file_uploader("Upload `.pkl` file in the files section of this project", type=["pkl"])

    data_ready = st.session_state.graph_data_loaded
    if uploaded_file and not data_ready:
        with st.spinner("Loading curriculum data and preparing embeddings..."):
            try:
                # NOTE(review): unpickling an uploaded file can execute
                # arbitrary code; only load files from a trusted source.
                graph_data = pickle.load(uploaded_file)
                optimizer.load_data(graph_data)
                st.session_state["visualizer"] = CurriculumVisualizer(graph_data)
                st.session_state["graph_data"] = graph_data
                st.session_state["graph_data_loaded"] = True
                st.success(f"Successfully loaded and processed '{uploaded_file.name}'!")
                time.sleep(1)
                st.rerun()
            except Exception as e:
                st.error(f"Error processing .pkl file: {e}")
                st.session_state["graph_data_loaded"] = False
    elif data_ready:
        st.success("Curriculum data is loaded and ready.")

    # 2. SELECT TRACK
    st.subheader("2. Select a Specialization")
    if st.session_state.graph_data_loaded:
        # Internal track key -> label shown to the user
        track_options = {
            "general": "🤖 General CS (Broadest Focus)",
            "ai_ml": "🧠 Artificial Intelligence & ML",
            "security": "🔒 Cybersecurity",
            "systems": "⚙️ Systems & Networks",
            "game_dev": "🎮 Game Design & Development"
        }

        selected_track_key = st.selectbox(
            "Choose your focus area (optional):",
            options=track_options.keys(),
            format_func=lambda key: track_options[key],  # Shows the friendly name
            index=0  # Default to "General"
        )
        st.session_state["selected_track"] = selected_track_key
    else:
        st.info("Please load a curriculum file first.")

    # 3. GENERATE PLAN
    st.subheader("3. Generate a Plan")
    if not st.session_state.graph_data_loaded:
        st.info("Please load a curriculum file above to enable plan generation.")
    else:
        # Build the student profile from the sidebar inputs
        completed_ids = [c.strip().upper() for c in completed_courses_input.split(',') if c.strip()]
        interest_list = [i.strip() for i in interests.split(',') if i.strip()]
        profile = StudentProfile(
            completed_courses=completed_ids,
            current_gpa=gpa,
            interests=interest_list,
            career_goals=career_goal,
            learning_style=learning_style,
            time_commitment=time_commit,
            preferred_difficulty=difficulty
        )

        # Get the selected track from session state
        selected_track = st.session_state.get("selected_track", "general")

        # Detect whether the profile or the track differs from the last run
        profile_changed = (
            st.session_state.last_profile != profile
            or st.session_state.last_track != selected_track
        )
        if profile_changed:
            st.session_state["last_profile"] = profile
            st.session_state["last_track"] = selected_track

        col1, col2, col3 = st.columns(3)

        if col1.button("🧠 AI-Optimized Plan", use_container_width=True, type="primary"):
            with st.spinner(f"🚀 Performing AI-optimization for '{track_options[selected_track]}' track..."):
                start_time = time.time()
                result = optimizer.generate_llm_plan(profile, selected_track)
                generation_time = time.time() - start_time

                plan_raw = result.get('pathway', {})
                st.session_state["reasoning"] = plan_raw.get("reasoning", "")
                st.session_state["metrics"] = plan_raw.get("complexity_analysis", {})
                st.session_state["display_plan"] = plan_raw
                st.session_state["plan_type"] = "AI-Optimized"
                st.session_state["generation_time"] = generation_time
                st.success(f"🎉 AI-optimized plan generated in {generation_time:.1f}s!")

        if col2.button("⚡ Smart Rule-Based Plan", use_container_width=True):
            with st.spinner(f"Generating rule-based plan for '{track_options[selected_track]}' track..."):
                start_time = time.time()
                result = optimizer.generate_simple_plan(profile, selected_track)
                generation_time = time.time() - start_time

                plan_raw = result.get('pathway', {})
                st.session_state["reasoning"] = plan_raw.get("reasoning", "")
                st.session_state["metrics"] = plan_raw.get("complexity_analysis", {})
                st.session_state["display_plan"] = plan_raw
                st.session_state["plan_type"] = "Smart Rule-Based"
                st.session_state["generation_time"] = generation_time
                st.success(f"🎉 Smart rule-based plan generated in {generation_time:.1f}s!")

        if col3.button("🔄 Clear Plan", use_container_width=True):
            st.session_state["display_plan"] = None
            st.session_state["metrics"] = None
            st.session_state["reasoning"] = ""
            st.rerun()

        # Show profile change notification
        if st.session_state.display_plan and profile_changed:
            st.warning("⚠️ Student profile or track changed! Generate a new plan to see updated recommendations.")

        # DISPLAY RESULTS
        if st.session_state.display_plan:
            st.subheader(f"📚 {st.session_state.get('plan_type', 'Optimized')} Degree Plan")

            # Display generation info
            col_info1, col_info2, col_info3 = st.columns(3)
            with col_info1:
                st.metric("Generation Time", f"{st.session_state.get('generation_time', 0):.1f}s")
            with col_info2:
                st.metric("Plan Type", st.session_state.get('plan_type', 'Unknown'))
            with col_info3:
                if time_commit >= 40:
                    load_type = "Intensive"
                elif time_commit < 20:
                    load_type = "Part-time"
                else:
                    load_type = "Standard"
                st.metric("Course Load", load_type)

            # Display reasoning and metrics
            if st.session_state.reasoning or st.session_state.metrics:
                st.markdown("##### 📊 Plan Analysis")

                if st.session_state.reasoning:
                    st.info(f"**Strategy:** {st.session_state.reasoning}")

                if st.session_state.metrics:
                    m = st.session_state.metrics
                    c1, c2, c3, c4 = st.columns(4)

                    c1.metric("Avg Complexity", f"{m.get('average_semester_complexity', 0):.1f}")
                    c2.metric("Peak Complexity", f"{m.get('peak_semester_complexity', 0):.1f}")
                    c3.metric("Total Complexity", f"{m.get('total_complexity', 0):.0f}")
                    c4.metric("Balance Score", f"{m.get('balance_score (std_dev)', 0):.2f}")

            st.divider()

            # Display the actual plan
            plan = st.session_state.display_plan
            total_courses = 0

            for year_num in range(1, 5):
                year_data = plan.get(f"year_{year_num}", {})

                st.markdown(f"### Year {year_num}")
                col_fall, col_spring, col_summer = st.columns(3)

                # Fall and spring columns render identically apart from
                # the term key and the heading.
                term_columns = (
                    (col_fall, "fall", "**🍂 Fall Semester**"),
                    (col_spring, "spring", "**🌸 Spring Semester**"),
                )
                for column, term_key, heading in term_columns:
                    with column:
                        term_courses = year_data.get(term_key, [])
                        st.markdown(heading)
                        if not term_courses:
                            st.write("*No courses scheduled*")
                            continue
                        for course_id in term_courses:
                            if course_id in optimizer.courses:
                                course_name = optimizer.courses[course_id].get("name", course_id)
                                st.write(f"• **{course_id}**: {course_name}")
                            else:
                                # Course unknown to the optimizer: show the bare id
                                st.write(f"• {course_id}")
                            total_courses += 1

                # Summer
                with col_summer:
                    summer = year_data.get("summer", [])
                    st.markdown("**☀️ Summer**")
                    if summer == "co-op":
                        st.write("🏢 *Co-op Experience*")
                    elif summer:
                        # This case isn't really used by the optimizer, but good to have
                        st.write("*Summer Classes*")
                    else:
                        st.write("*Break*")

            # Summary and export
            st.divider()
            col_export1, col_export2 = st.columns(2)

            with col_export1:
                st.metric("Total Courses", total_courses)

            with col_export2:
                col_yaml, col_json = st.columns(2)

                with col_yaml:
                    # Reuse the validation stored on the plan instead of
                    # re-running validate_plan().
                    validation = st.session_state.display_plan.get("validation", {"errors": [], "warnings": []})

                    yaml_data = export_plan_yaml(
                        st.session_state.display_plan,
                        profile,
                        validation,
                        st.session_state.get("selected_track", "general")  # Pass track
                    )
                    st.download_button(
                        label="📥 Export as YAML",
                        data=yaml_data,
                        file_name=f"curriculum_plan_{name.replace(' ', '_')}.yaml",
                        mime="text/yaml",
                        use_container_width=True
                    )

                with col_json:
                    export_data = {
                        "student_profile": {
                            "name": name,
                            "gpa": gpa,
                            "career_goals": career_goal,
                            "interests": interests,
                            "learning_style": learning_style,
                            "time_commitment": time_commit,
                            "preferred_difficulty": difficulty,
                            "completed_courses": completed_courses_input
                        },
                        "plan": st.session_state.display_plan,
                        "metrics": st.session_state.metrics,
                        "generation_info": {
                            "plan_type": st.session_state.get('plan_type', 'Unknown'),
                            "generation_time": st.session_state.get('generation_time', 0),
                            "selected_track": st.session_state.get("selected_track", "general")
                        }
                    }
                    plan_json = json.dumps(export_data, indent=2)
                    st.download_button(
                        label="📥 Export as JSON",
                        data=plan_json,
                        file_name=f"curriculum_plan_{name.replace(' ', '_')}.json",
                        mime="application/json",
                        use_container_width=True
                    )
490
+
491
+ # --- TAB 2: CURRICULUM MAP ---
492
with tab2:
    st.subheader("🗺️ Interactive Curriculum Dependency Graph")

    if not st.session_state.graph_data_loaded:
        st.info("Please load curriculum data in the Plan Generator tab first.")
    else:
        # Create visualization
        if st.session_state.visualizer:
            critical_path = st.session_state.visualizer.find_critical_path()
            if critical_path:
                # Show at most the first 7 courses; add the ellipsis only
                # when the path was actually truncated.
                preview = ' → '.join(critical_path[:7])
                ellipsis = "..." if len(critical_path) > 7 else ""
                st.info(f"Global Critical Path ({len(critical_path)} courses): {preview}{ellipsis}")

            # Create the plot
            fig = st.session_state.visualizer.create_interactive_plot(critical_path)
            st.plotly_chart(fig, use_container_width=True)

            # Legend
            with st.expander("📖 How to Read This Graph"):
                st.markdown("""
                **Node (Circle) Size**: Blocking factor - larger circles block more future courses
                **Node Color**: Complexity score - darker = more complex
                **Lines**: Prerequisite relationships
                **Red Path**: Critical path (longest chain)
                **Hover over nodes**: See detailed metrics for each course
                """)
517
+
518
+ # --- TAB 3: ANALYTICS ---
519
with tab3:
    st.subheader("📊 Curriculum Analytics Dashboard")

    if not st.session_state.graph_data_loaded:
        st.info("Please load curriculum data in the Plan Generator tab first.")
    else:
        graph = st.session_state.graph_data
        visualizer = st.session_state.visualizer
        total_courses = graph.number_of_nodes()
        total_prereqs = graph.number_of_edges()

        # Compute per-course metrics once and reuse them for the totals and
        # both rankings below.
        # NOTE(review): assumes calculate_metrics() is deterministic and
        # side-effect free -- confirm against CurriculumVisualizer.
        complexities = []
        if visualizer:
            for node in graph.nodes():
                metrics = visualizer.calculate_metrics(node)
                complexities.append({
                    'course': node,
                    'name': graph.nodes[node].get('name', ''),
                    'complexity': metrics['complexity'],
                    'blocking': metrics['blocking']
                })

        # Overall metrics
        col1, col2, col3, col4 = st.columns(4)
        col1.metric("Total Courses", total_courses)
        col2.metric("Total Prerequisites", total_prereqs)
        # An empty graph would otherwise raise ZeroDivisionError here.
        avg_prereqs = total_prereqs / total_courses if total_courses else 0.0
        col3.metric("Avg Prerequisites", f"{avg_prereqs:.1f}")

        if visualizer:
            total_complexity = sum(item['complexity'] for item in complexities)
            col4.metric("Curriculum Complexity", f"{total_complexity:,.0f}")

        st.divider()

        # Most complex courses
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Most Complex Courses")
            if visualizer:
                complexities.sort(key=lambda x: x['complexity'], reverse=True)

                for item in complexities[:10]:
                    st.write(f"**{item['course']}**: {item['name']}")
                    prog_col1, prog_col2 = st.columns([3, 1])
                    with prog_col1:
                        # Clamp the bar value into [0.0, 1.0].
                        st.progress(max(0.0, min(item['complexity'] / 100, 1.0)))
                    with prog_col2:
                        st.caption(f"Blocks: {item['blocking']}")

        with col2:
            st.subheader("Bottleneck Courses")
            st.caption("(High blocking factor)")

            if visualizer:
                bottlenecks = sorted(complexities, key=lambda x: x['blocking'], reverse=True)

                for item in bottlenecks[:10]:
                    st.write(f"**{item['course']}**: {item['name']}")
                    st.info(f"Blocks {item['blocking']} future courses")

        # Plan vs Global Comparison (needs both a generated plan and the
        # visualizer that computes the comparison)
        if st.session_state.display_plan and visualizer:
            st.divider()
            st.subheader("📊 Metric System Comparison")
            st.caption("Comparing metrics for the entire curriculum vs. metrics only within your generated plan.")

            plan_courses: Set[str] = set()
            for year_key, year_data in st.session_state.display_plan.items():
                if year_key.startswith("year_"):
                    plan_courses.update(year_data.get("fall", []))
                    plan_courses.update(year_data.get("spring", []))

            comparison = visualizer.compare_metric_systems(plan_courses)

            col1, col2 = st.columns(2)

            with col1:
                st.metric(
                    "Critical Path Match",
                    "✅ Yes" if comparison['critical_path_match'] else "❌ No"
                )
                st.caption("Global critical path (first 5):")
                st.code(' → '.join(comparison['global_critical']))

            with col2:
                st.metric(
                    "Major Metric Differences",
                    len(comparison['major_differences'])
                )
                st.caption("Plan-specific critical path (first 5):")
                st.code(' → '.join(comparison['plan_critical']))

            if comparison['major_differences']:
                with st.expander(f"View {len(comparison['major_differences'])} courses with >50% metric difference"):
                    for diff in comparison['major_differences']:
                        st.write(f"**{diff['course']}**: Global blocking={diff['global_blocking']}, Plan blocking={diff['plan_blocking']}")

# Footer
st.divider()
st.caption("🚀 Powered by Students, For Students")