Clementio commited on
Commit
04e98a7
·
verified ·
1 Parent(s): c5ef1a8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py CHANGED
@@ -7,3 +7,234 @@ import networkx as nx
7
  import numpy as np
8
  from huggingface_hub import hf_hub_download
9
  from typing import Dict, List, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import numpy as np
8
  from huggingface_hub import hf_hub_download
9
  from typing import Dict, List, Optional
10
+
11
+ st.set_page_config(page_title='Logic Engine', page_icon='🧠', layout='wide')
12
+
13
+ HF_REPO = 'Clementio/PLRS'
14
+
15
+ @st.cache_resource
16
+ def load_model():
17
+ config_path = hf_hub_download(repo_id=HF_REPO, filename='config.json')
18
+ with open(config_path) as f:
19
+ config = json.load(f)
20
+ model_path = hf_hub_download(repo_id=HF_REPO, filename='sakt_model.pt')
21
+ class SAKT(nn.Module):
22
+ def __init__(self, num_skills, embed_dim, num_heads, num_layers, max_seq_len, dropout):
23
+ super(SAKT, self).__init__()
24
+ self.num_skills = num_skills
25
+ self.interaction_embed = nn.Embedding(num_skills * 2 + 1, embed_dim, padding_idx=0)
26
+ self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
27
+ self.pos_embed = nn.Embedding(max_seq_len + 1, embed_dim)
28
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True, dim_feedforward=embed_dim * 4, norm_first=True)
29
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
30
+ self.dropout = nn.Dropout(dropout)
31
+ self.output = nn.Linear(embed_dim, 1)
32
+ def forward(self, interactions, target_skills, mask):
33
+ batch_size, seq_len = interactions.shape
34
+ positions = torch.arange(seq_len, device=interactions.device).unsqueeze(0).expand(batch_size, -1)
35
+ x = self.interaction_embed(interactions)
36
+ x = x + self.pos_embed(positions)
37
+ x = x * mask.unsqueeze(-1).float()
38
+ x = self.dropout(x)
39
+ causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
40
+ x = self.transformer(x, mask=causal_mask, is_causal=False)
41
+ x = x * mask.unsqueeze(-1).float()
42
+ x = x + self.skill_embed(target_skills)
43
+ return self.output(x).squeeze(-1)
44
+ device = torch.device('cpu')
45
+ model = SAKT(num_skills=config['num_skills'], embed_dim=config['embed_dim'], num_heads=config['num_heads'], num_layers=config['num_layers'], max_seq_len=config['max_seq_len'], dropout=config['dropout'])
46
+ model.load_state_dict(torch.load(model_path, map_location=device))
47
+ model.eval()
48
+ return model, config, device
49
+
50
+ @st.cache_resource
51
+ def load_knowledge_maps():
52
+ def load_dag(path):
53
+ with open(path) as f:
54
+ data = json.load(f)
55
+ G = nx.DiGraph()
56
+ for node in data['nodes']:
57
+ G.add_node(node['id'], label=node['label'], level=node['level'], term=node['term'])
58
+ for edge in data['edges']:
59
+ G.add_edge(edge['from'], edge['to'])
60
+ return G
61
+ return load_dag('knowledge_maps/math_dag.json'), load_dag('knowledge_maps/cs_dag.json')
62
+
63
+ @st.cache_data
64
+ def load_skill_encoder():
65
+ return pd.read_csv('data/skill_encoder.csv')
66
+
67
+ class MasteryVector:
68
+ def __init__(self, graph, threshold=0.70):
69
+ self.graph = graph
70
+ self.threshold = threshold
71
+ self.mastery = {node: 0.0 for node in graph.nodes}
72
+ def update(self, topic_id, probability):
73
+ if topic_id in self.mastery: self.mastery[topic_id] = probability
74
+ def is_mastered(self, topic_id):
75
+ return self.mastery.get(topic_id, 0.0) >= self.threshold
76
+ def get_mastery(self, topic_id):
77
+ return self.mastery.get(topic_id, 0.0)
78
+ def get_mastery_summary(self):
79
+ mastered = [t for t in self.mastery if self.is_mastered(t)]
80
+ return {'total_topics': len(self.mastery), 'mastered': len(mastered), 'mastery_rate': round(len(mastered)/len(self.mastery), 3), 'mastered_topics': mastered}
81
+
82
+ class DAGConstraintLayer:
83
+ def __init__(self, graph, threshold=0.70):
84
+ self.graph = graph
85
+ self.threshold = threshold
86
+ def validate(self, topic_id, mastery_vector):
87
+ if topic_id not in self.graph.nodes: return False, 'Topic not found.'
88
+ prerequisites = list(self.graph.predecessors(topic_id))
89
+ label = self.graph.nodes[topic_id].get('label', topic_id)
90
+ if not prerequisites: return True, f'✅ Foundational topic — no prerequisites.'
91
+ unmastered = [(self.graph.nodes[p].get('label',p), mastery_vector.get_mastery(p)) for p in prerequisites if not mastery_vector.is_mastered(p)]
92
+ if unmastered:
93
+ gaps = ', '.join([f"{lbl} ({m:.0%} mastered, need {self.threshold:.0%})" for lbl,m in unmastered])
94
+ return False, f'❌ Prerequisites not met: {gaps}'
95
+ prereq_labels = [self.graph.nodes[p].get('label',p) for p in prerequisites]
96
+ return True, f'✅ Prerequisites mastered: {", ".join(prereq_labels)}'
97
+
98
+ class RankingFunction:
99
+ def __init__(self, graph, threshold=0.70, w_gap=0.40, w_ready=0.35, w_downstream=0.25):
100
+ self.graph=graph; self.threshold=threshold; self.w_gap=w_gap; self.w_ready=w_ready; self.w_downstream=w_downstream
101
+ scores = {n: len(nx.descendants(graph, n)) for n in graph.nodes}
102
+ mx = max(scores.values()) if scores else 1
103
+ self._downstream = {n: s/mx for n,s in scores.items()}
104
+ def score(self, topic_id, mastery_vector):
105
+ current = mastery_vector.get_mastery(topic_id)
106
+ gap = min(max(0.0, self.threshold-current)/self.threshold, 1.0)
107
+ prereqs = list(self.graph.predecessors(topic_id))
108
+ readiness = 1.0 if not prereqs else sum(1 for p in prereqs if mastery_vector.is_mastered(p))/len(prereqs)
109
+ downstream = self._downstream.get(topic_id, 0.0)
110
+ return round(self.w_gap*gap + self.w_ready*readiness + self.w_downstream*downstream, 3)
111
+
112
+ class LearningRecommendationPipeline:
113
+ def __init__(self, graph, threshold=0.70, top_n=5):
114
+ self.graph=graph; self.constraint=DAGConstraintLayer(graph,threshold); self.ranker=RankingFunction(graph,threshold); self.top_n=top_n
115
+ def run(self, mastery_vector):
116
+ approved, vetoed = [], []
117
+ for topic_id in self.graph.nodes:
118
+ is_approved, reasoning = self.constraint.validate(topic_id, mastery_vector)
119
+ entry = {'topic_id': topic_id, 'topic_label': self.graph.nodes[topic_id].get('label', topic_id), 'mastery': round(mastery_vector.get_mastery(topic_id),3), 'reasoning': reasoning, 'approved': is_approved}
120
+ if is_approved and not mastery_vector.is_mastered(topic_id):
121
+ entry['score'] = self.ranker.score(topic_id, mastery_vector)
122
+ approved.append(entry)
123
+ elif not is_approved: vetoed.append(entry)
124
+ approved.sort(key=lambda x: x['score'], reverse=True)
125
+ return {'top_recommendations': approved[:self.top_n], 'total_approved': len(approved), 'total_vetoed': len(vetoed), 'vetoed_sample': vetoed[:5], 'prerequisite_violation_rate': round(len(vetoed)/max(len(list(self.graph.nodes)),1),3)}
126
+
127
+ ACTIVITY_TO_MATH = {'oucontent':'algebraic_expressions','forumng':'statistics_basic','homepage':'whole_numbers','subpage':'plane_shapes','resource':'indices','url':'number_bases','ouwiki':'proportion_variation','glossary':'algebraic_factorization','quiz':'quadratic_equations'}
128
+ ACTIVITY_TO_CS = {'oucontent':'programming_concepts','forumng':'ethics_technology','homepage':'computer_basics','subpage':'html_basics','resource':'networking_fundamentals','url':'internet_basics','ouwiki':'cloud_basics','glossary':'intro_databases','quiz':'python_basics'}
129
+
130
+ def run_sakt_inference(model, config, skill_seq, correct_seq, device):
131
+ max_len=config['max_seq_len']; n_skills=config['num_skills']
132
+ if len(skill_seq)>max_len: skill_seq=skill_seq[-max_len:]; correct_seq=correct_seq[-max_len:]
133
+ interactions=[s+c*n_skills for s,c in zip(skill_seq[:-1],correct_seq[:-1])]
134
+ target_skills=skill_seq[1:]
135
+ seq_len=len(interactions); pad_len=max_len-seq_len
136
+ interactions=[0]*pad_len+interactions; target_skills=[0]*pad_len+target_skills; mask=[False]*pad_len+[True]*seq_len
137
+ with torch.no_grad():
138
+ logits=model(torch.LongTensor([interactions]).to(device),torch.LongTensor([target_skills]).to(device),torch.BoolTensor([mask]).to(device))
139
+ probs=torch.sigmoid(logits).squeeze(0)
140
+ mastery={}; real_probs=probs[torch.BoolTensor(mask)].cpu().numpy(); real_skills=target_skills[pad_len:]
141
+ for skill_id,prob in zip(real_skills,real_probs): mastery[int(skill_id)]=float(prob)
142
+ return mastery
143
+
144
+ def build_mastery_vector(skill_probs, graph, skill_encoder_df, domain, threshold):
145
+ mv=MasteryVector(graph,threshold); mapping=ACTIVITY_TO_MATH if domain=='math' else ACTIVITY_TO_CS
146
+ topic_scores={}
147
+ for skill_id,prob in skill_probs.items():
148
+ row=skill_encoder_df[skill_encoder_df['skill_id']==skill_id]
149
+ if row.empty: continue
150
+ act=row['activity_type'].values[0] if 'activity_type' in row.columns else None
151
+ topic_id=mapping.get(act) if act else None
152
+ if topic_id: topic_scores[topic_id]=max(topic_scores.get(topic_id,0.0),prob)
153
+ for topic_id,score in topic_scores.items(): mv.update(topic_id,score)
154
+ return mv
155
+
156
+ def main():
157
+ model, config, device = load_model()
158
+ math_graph, cs_graph = load_knowledge_maps()
159
+ skill_encoder = load_skill_encoder()
160
+ st.title('🧠 Logic Engine')
161
+ st.subheader('Domain-Agnostic Constraint-Aware Learning Recommender')
162
+ st.markdown('---')
163
+ st.sidebar.title('⚙️ Configuration')
164
+ domain = st.sidebar.selectbox('Select Domain', ['Mathematics', 'CS Fundamentals'])
165
+ threshold = st.sidebar.slider('Mastery Threshold', 0.50, 0.90, 0.70, 0.05)
166
+ top_n = st.sidebar.slider('Top N Recommendations', 3, 10, 5)
167
+ graph = math_graph if domain=='Mathematics' else cs_graph
168
+ domain_key = 'math' if domain=='Mathematics' else 'cs'
169
+ pipeline = LearningRecommendationPipeline(graph, threshold, top_n)
170
+ st.sidebar.markdown('---')
171
+ st.sidebar.markdown('**About**')
172
+ st.sidebar.markdown('SAKT-based knowledge tracing with DAG prerequisite constraints.')
173
+ tab1, tab2, tab3 = st.tabs(['🎯 Get Recommendations','🗺️ Knowledge Map','���� Diagnostics'])
174
+ with tab1:
175
+ st.header('Learner Profile')
176
+ mode = st.radio('Input Mode', ['Manual Mastery Input','Simulate Student Sequence'], horizontal=True)
177
+ mastery_vector = MasteryVector(graph, threshold)
178
+ if mode=='Manual Mastery Input':
179
+ st.markdown('Set your current mastery level for each topic:')
180
+ cols=st.columns(2); nodes=list(graph.nodes)
181
+ for i,node in enumerate(nodes):
182
+ label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level','')
183
+ val=cols[i%2].slider(f'{label} ({level})',0.0,1.0,0.0,0.05,key=f'mastery_{node}')
184
+ mastery_vector.update(node,val)
185
+ else:
186
+ seq_length=st.slider('Sequence Length',10,200,50)
187
+ seed=st.number_input('Student Seed',1,1000,42,1)
188
+ np.random.seed(int(seed))
189
+ sim_skills=np.random.randint(0,config['num_skills'],seq_length).tolist()
190
+ sim_corrects=np.random.randint(0,2,seq_length).tolist()
191
+ skill_probs=run_sakt_inference(model,config,sim_skills,sim_corrects,device)
192
+ mastery_vector=build_mastery_vector(skill_probs,graph,skill_encoder,domain_key,threshold)
193
+ st.success(f'SAKT inference complete — {len(skill_probs)} skill predictions generated')
194
+ if st.button('🚀 Generate Recommendations', type='primary'):
195
+ output=pipeline.run(mastery_vector)
196
+ summary=mastery_vector.get_mastery_summary()
197
+ col1,col2,col3,col4=st.columns(4)
198
+ col1.metric('Topics Mastered',f"{summary['mastered']} / {summary['total_topics']}")
199
+ col2.metric('Mastery Rate',f"{summary['mastery_rate']:.1%}")
200
+ col3.metric('Approved Topics',output['total_approved'])
201
+ col4.metric('Violation Rate',f"{output['prerequisite_violation_rate']:.1%}")
202
+ st.markdown('---')
203
+ st.subheader(f'Top {top_n} Recommendations')
204
+ if not output['top_recommendations']: st.warning('No recommendations — adjust mastery or lower threshold.')
205
+ else:
206
+ for i,rec in enumerate(output['top_recommendations'],1):
207
+ with st.expander(f"{i}. {rec['topic_label']} — Score: {rec['score']} | Mastery: {rec['mastery']:.1%}", expanded=(i<=3)):
208
+ st.markdown(f"**Reasoning:** {rec['reasoning']}")
209
+ st.progress(rec['mastery'])
210
+ if output['vetoed_sample']:
211
+ st.markdown('---'); st.subheader('⛔ Sample Vetoed Topics')
212
+ for rec in output['vetoed_sample']:
213
+ with st.expander(f"✗ {rec['topic_label']}"):
214
+ st.markdown(f"**Reason:** {rec['reasoning']}")
215
+ with tab2:
216
+ st.header(f'{domain} Knowledge Map')
217
+ st.markdown(f"**{graph.number_of_nodes()} topics** | **{graph.number_of_edges()} prerequisite relationships**")
218
+ rows=[]
219
+ for node in graph.nodes:
220
+ label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level',''); term=graph.nodes[node].get('term','')
221
+ prereqs=[graph.nodes[p].get('label',p) for p in graph.predecessors(node)]
222
+ rows.append({'Topic':label,'Level':level,'Term':term,'Prerequisites':', '.join(prereqs) if prereqs else 'None (Foundational)'})
223
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
224
+ longest=nx.dag_longest_path(graph)
225
+ st.markdown('**Longest prerequisite chain:**')
226
+ st.markdown(' → '.join([graph.nodes[n].get('label',n) for n in longest]))
227
+ with tab3:
228
+ st.header('System Diagnostics')
229
+ col1,col2=st.columns(2)
230
+ with col1: st.subheader('Model Configuration'); st.json(config)
231
+ with col2:
232
+ st.subheader('DAG Statistics')
233
+ st.json({'domain':domain,'nodes':graph.number_of_nodes(),'edges':graph.number_of_edges(),'is_valid_dag':nx.is_directed_acyclic_graph(graph),'longest_path':len(nx.dag_longest_path(graph))})
234
+ st.subheader('Domain Switching')
235
+ dcol1,dcol2=st.columns(2)
236
+ with dcol1: st.metric('Math DAG',f'{math_graph.number_of_nodes()} topics')
237
+ with dcol2: st.metric('CS DAG',f'{cs_graph.number_of_nodes()} topics')
238
+
239
+ if __name__ == '__main__':
240
+ main()