# contextflow-rl / app/agents/doubt_predictor.py
# Author: namish10 — "Upload app/agents/doubt_predictor.py with huggingface_hub"
# Commit: 5e38446 (verified)
"""
Reinforcement Learning Doubt Prediction Agent
This agent predicts what doubts a user will have BEFORE they occur,
using:
- User's learning history
- Current topic complexity
- Behavioral signals (eye tracking, hesitation, scroll patterns)
- Similar users' learning patterns
- Topic dependency graphs
Based on Deep Q-Learning with attention mechanism
"""
import json
import zlib
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
@dataclass
class LearningState:
    """
    Snapshot of a learner's current session.

    Built from a raw context dict by ``DoubtPredictorAgent.get_current_state``
    and turned into a numeric feature vector by ``state_to_vector``.
    """

    topic: str                      # top-level topic name, e.g. 'python'
    subtopic: str                   # subtopic name, e.g. 'loops'
    progress_percentage: float      # progress indicator; presumably a 0-1 fraction — confirm
    time_spent_seconds: int         # time spent on the material, in seconds
    confusion_signals: float        # aggregated confusion score
    eye_tracking_confidence: float  # confidence reported by eye tracking
    scroll_reversals: int           # number of times the user scrolled back
    selection_count: int            # number of text selections
    previous_doubts_count: int      # doubts the user raised previously
    mastery_level: float            # estimated mastery of the topic
    difficulty_rating: float        # perceived difficulty of the material
    time_of_day: int                # hour of day (0-23)
    streak_days: int                # consecutive learning days
@dataclass
class DoubtPrediction:
    """
    A predicted doubt together with the metadata used to rank and present it.

    Attributes are documented inline; ``priority`` is the (float) ranking
    score produced by ``DoubtPredictorAgent._calculate_priority``.
    """

    predicted_doubt: str            # the doubt text (one of the agent's templates)
    confidence: float               # confidence in [0, 1] that the user will have this doubt
    suggested_explanation: str      # short hint to show alongside the doubt
    related_concepts: List[str]     # concepts related to the doubt's concept
    priority: float                 # ranking score; higher means more urgent
    estimated_resolution_time: int  # estimated minutes needed to resolve the doubt
    prerequisite_topics: List[str]  # concepts to understand before this one
@dataclass
class RLPolicy:
    """
    Hyperparameters and learned values of the (simplified) RL policy.

    The learned values live in ``q_table``: a mapping from a discretized
    state key to a vector of ``action_dim`` Q-values.
    """

    # Must match the length of DoubtPredictorAgent.state_to_vector's output,
    # which produces 14 features (the previous default of 12 did not).
    state_dim: int = 14
    action_dim: int = 100           # number of Q-value slots per state
    learning_rate: float = 0.001    # step size of the TD update
    gamma: float = 0.95             # discount factor for bootstrapped values
    epsilon: float = 1.0            # exploration rate (decayed after each update)
    epsilon_decay: float = 0.995    # multiplicative decay applied to epsilon
    epsilon_min: float = 0.01       # floor for epsilon
    # default_factory gives every policy its own dict instead of a shared one.
    q_table: Dict[str, np.ndarray] = field(default_factory=dict)
    # Presumably per-feature normalization statistics; unset (None) by default.
    state_mean: Optional[np.ndarray] = None
    state_std: Optional[np.ndarray] = None
class DoubtPredictorAgent:
    """
    RL-based agent that predicts user doubts before they occur.

    Uses a Deep Q-Network inspired architecture with:
    - State encoding from learning signals
    - Attention mechanism for topic relationships
    - Experience replay for learning
    - Progressive prediction confidence

    NOTE(review): as implemented, ``predict_doubts`` ranks doubts with
    hand-tuned heuristics; the Q-table updated by ``update_policy`` is not
    consulted when predicting, and ``policy.epsilon`` is decayed but not used
    for exploration.
    """

    # Ordered topic vocabulary used by ``_topic_to_feature`` to map a topic
    # name onto a scalar in [0, 1).
    _TOPIC_ORDER = (
        'variables', 'functions', 'classes', 'loops', 'conditionals', 'data_structures',
        'probability', 'distributions', 'derivatives', 'integrals', 'vectors', 'matrices',
        'linear_regression', 'classification', 'neural_networks', 'optimization',
        'convolutional_nets', 'recurrent_nets', 'transformers', 'attention'
    )

    # Concept -> concepts that should be understood first.
    _PREREQUISITES = {
        'neural_networks': ('linear_regression', 'calculus', 'linear_algebra'),
        'transformers': ('neural_networks', 'attention', 'linear_algebra'),
        'convolutional_nets': ('neural_networks', 'linear_algebra'),
        'backpropagation': ('derivatives', 'chain_rule'),
        'optimization': ('calculus', 'derivatives'),
        'classification': ('probability', 'linear_regression'),
    }

    def __init__(self, user_id: str, config: Optional[Dict] = None):
        """
        Args:
            user_id: Identifier of the learner this agent models.
            config: Optional configuration dict (stored as ``self.config``;
                not read elsewhere in this class).
        """
        self.user_id = user_id
        self.config = config or {}
        self.policy = RLPolicy()
        self.experience_buffer = []
        self.max_buffer_size = 1000
        self.topic_embeddings = {}
        self.concept_graph = {}
        self.user_preferences = {}
        self._initialize_topic_knowledge()

    @staticmethod
    def _normalize_topic(name: str) -> str:
        """Return *name* lower-cased with spaces replaced by underscores."""
        return name.lower().replace(' ', '_')

    def _initialize_topic_knowledge(self):
        """Initialize base topic relationships and per-concept doubt templates."""
        # topic -> list of subconcepts
        self.concept_graph = {
            'python': ['variables', 'functions', 'classes', 'loops', 'conditionals', 'data_structures'],
            'machine_learning': ['linear_regression', 'classification', 'neural_networks', 'optimization', 'feature_engineering'],
            'deep_learning': ['perceptrons', 'backpropagation', 'convolutional_nets', 'recurrent_nets', 'transformers', 'attention'],
            'statistics': ['probability', 'distributions', 'hypothesis_testing', 'regression', 'bayesian'],
            'calculus': ['derivatives', 'integrals', 'limits', 'series', 'multivariable'],
            'linear_algebra': ['vectors', 'matrices', 'eigenvalues', 'transformations', 'decompositions']
        }
        # concept -> candidate doubts a learner may raise about it
        self.doubt_templates = {
            'variables': [
                "What is the difference between mutable and immutable types?",
                "How does variable scope work in nested functions?",
                "When should I use global vs local variables?"
            ],
            'functions': [
                "What is the difference between arguments and parameters?",
                "How do *args and **kwargs work?",
                "When should I use lambda functions?"
            ],
            'classes': [
                "What is the difference between class and instance attributes?",
                "How does inheritance work with multiple inheritance?",
                "What are abstract base classes and when to use them?"
            ],
            'loops': [
                "When should I use for vs while loops?",
                "How do list comprehensions replace loops?",
                "What is the difference between break and continue?"
            ],
            'data_structures': [
                "When should I use lists vs dictionaries?",
                "What is the time complexity of dictionary operations?",
                "How do sets differ from lists in performance?"
            ],
            'linear_regression': [
                "What is the cost function and how is it optimized?",
                "How do I handle multicollinearity?",
                "What are the assumptions of linear regression?"
            ],
            'neural_networks': [
                "What is the role of activation functions?",
                "How does backpropagation compute gradients?",
                "What is the vanishing gradient problem?"
            ],
            'transformers': [
                "How does self-attention work?",
                "What is the difference between encoder and decoder?",
                "Why is positional encoding needed?"
            ]
        }

    def get_current_state(self, learning_context: Dict) -> LearningState:
        """
        Extract the current learning state from a context dict.

        Reads the keys 'topic', 'subtopic', 'progress', 'time_spent',
        'confusion_score', 'eye_confidence', 'scroll_reversals', 'selections',
        'prev_doubts', 'mastery', 'difficulty' and 'streak', falling back to
        defaults for missing keys. ``time_of_day`` is the current local hour.
        """
        return LearningState(
            topic=learning_context.get('topic', 'unknown'),
            subtopic=learning_context.get('subtopic', 'unknown'),
            progress_percentage=learning_context.get('progress', 0.0),
            time_spent_seconds=learning_context.get('time_spent', 0),
            confusion_signals=learning_context.get('confusion_score', 0.0),
            eye_tracking_confidence=learning_context.get('eye_confidence', 0.0),
            scroll_reversals=learning_context.get('scroll_reversals', 0),
            selection_count=learning_context.get('selections', 0),
            previous_doubts_count=learning_context.get('prev_doubts', 0),
            mastery_level=learning_context.get('mastery', 0.0),
            difficulty_rating=learning_context.get('difficulty', 0.5),
            time_of_day=datetime.now().hour,
            streak_days=learning_context.get('streak', 0)
        )

    def state_to_vector(self, state: LearningState) -> np.ndarray:
        """Convert a state to a 14-element float32 feature vector."""
        features = [
            self._topic_to_feature(state.topic),
            self._topic_to_feature(state.subtopic),
            state.progress_percentage,
            np.log1p(state.time_spent_seconds) / 10,
            state.confusion_signals,
            state.eye_tracking_confidence,
            # tanh squashes unbounded counts into [0, 1) for non-negative input
            np.tanh(state.scroll_reversals / 10),
            np.tanh(state.selection_count / 20),
            np.tanh(state.previous_doubts_count / 50),
            state.mastery_level,
            state.difficulty_rating,
            # sin/cos encoding keeps hour 23 close to hour 0
            np.sin(2 * np.pi * state.time_of_day / 24),
            np.cos(2 * np.pi * state.time_of_day / 24),
            np.tanh(state.streak_days / 30)
        ]
        return np.array(features, dtype=np.float32)

    def _topic_to_feature(self, topic: str) -> float:
        """Map a topic to its position in ``_TOPIC_ORDER`` scaled to [0, 1); 0.5 if unknown."""
        topic_key = self._normalize_topic(topic)
        if topic_key in self._TOPIC_ORDER:
            return self._TOPIC_ORDER.index(topic_key) / len(self._TOPIC_ORDER)
        return 0.5

    def predict_doubts(
        self,
        learning_context: Dict,
        top_k: int = 5,
        gesture_influence: Optional[float] = None
    ) -> List[DoubtPrediction]:
        """
        Predict likely doubts for the current learning context.

        Candidate doubts are the templates of concepts related to the current
        topic/subtopic; each is scored by ``_calculate_doubt_confidence`` and
        kept when its confidence exceeds 0.3.

        Args:
            learning_context: Current learning state (see ``get_current_state``
                for the keys that are read).
            top_k: Maximum number of predictions to return; ``<= 0`` yields an
                empty list.
            gesture_influence: Optional gesture-based signal (0-1) that
                increases doubt confidence when above 0.5.

        Returns:
            Up to ``top_k`` predictions, highest priority first.
        """
        if top_k <= 0:
            return []
        state = self.get_current_state(learning_context)
        predictions = []
        for concept in self._get_related_concepts(state.topic, state.subtopic):
            templates = self.doubt_templates.get(concept)
            if not templates:
                continue
            # These depend only on the concept, so compute them once; every
            # prediction still receives its own list copy.
            prerequisites = self._get_prerequisites(concept)
            concept_neighbours = self._get_related_concepts(concept, '')
            resolution_time = self._estimate_time(concept)
            for template in templates:
                confidence = self._calculate_doubt_confidence(
                    state, concept, template, gesture_influence
                )
                if confidence <= 0.3:
                    continue
                predictions.append(DoubtPrediction(
                    predicted_doubt=template,
                    confidence=confidence,
                    suggested_explanation=self._generate_explanation_hint(concept, template),
                    related_concepts=list(concept_neighbours),
                    priority=self._calculate_priority(state, confidence),
                    estimated_resolution_time=resolution_time,
                    prerequisite_topics=list(prerequisites)
                ))
        # list.sort is stable: equal priorities keep their generation order.
        predictions.sort(key=lambda p: p.priority, reverse=True)
        return predictions[:top_k]

    def _calculate_doubt_confidence(
        self,
        state: LearningState,
        concept: str,
        template: str,
        gesture_influence: Optional[float] = None
    ) -> float:
        """
        Heuristic confidence, clamped to [0, 1], that the user will have this doubt.

        ``template`` is accepted for interface compatibility but not used.
        """
        base_confidence = 0.5
        if state.confusion_signals > 0.7:
            base_confidence += 0.2
        if state.eye_tracking_confidence < 0.5:
            base_confidence += 0.15
        if state.scroll_reversals > 5:
            base_confidence += 0.1
        # Normalize like _get_related_concepts does, so e.g. 'Machine Learning'
        # matches the 'machine_learning' entry of the concept graph.
        if concept in self.concept_graph.get(self._normalize_topic(state.topic), []):
            base_confidence += 0.1
        if state.difficulty_rating > 0.7:
            base_confidence += 0.15
        if state.mastery_level < 0.3:
            base_confidence += 0.1
        if gesture_influence is not None and gesture_influence > 0.5:
            base_confidence += (gesture_influence - 0.5) * 0.4
        difficulty_penalty = state.difficulty_rating * 0.1
        base_confidence -= difficulty_penalty
        return min(max(base_confidence, 0.0), 1.0)

    def _get_related_concepts(self, topic: str, subtopic: str) -> List[str]:
        """
        Return up to 10 concepts related to a topic/subtopic.

        Includes the subconcepts of ``topic`` and of ``subtopic`` (when they
        are graph keys) plus every concept list that contains ``subtopic``.
        """
        topic_key = self._normalize_topic(topic)
        subtopic_key = self._normalize_topic(subtopic)
        related: List[str] = []
        related.extend(self.concept_graph.get(topic_key, []))
        related.extend(self.concept_graph.get(subtopic_key, []))
        for concepts in self.concept_graph.values():
            if subtopic_key in concepts:
                related.extend(concepts)
        # dict.fromkeys de-duplicates while keeping first-seen order; a set's
        # iteration order for str varies between processes, which made the
        # [:10] cut (and tie ordering downstream) non-deterministic.
        return list(dict.fromkeys(related))[:10]

    def _get_prerequisites(self, concept: str) -> List[str]:
        """Return a fresh list of prerequisite concepts (empty if none are known)."""
        return list(self._PREREQUISITES.get(self._normalize_topic(concept), ()))

    def _generate_explanation_hint(self, concept: str, template: str) -> str:
        """Return a short explanation hint for ``concept`` (``template`` is unused)."""
        hints = {
            'variables': 'Variables store data values that can be changed or accessed later.',
            'functions': 'Functions are reusable blocks of code that perform specific tasks.',
            'classes': 'Classes define blueprints for creating objects with attributes and methods.',
            'neural_networks': 'Neural networks learn patterns through weighted connections between neurons.',
            'transformers': 'Transformers use self-attention to process sequences in parallel.',
            'backpropagation': 'Backpropagation calculates gradients by propagating errors backwards through the network.'
        }
        return hints.get(concept, 'Review the fundamental concepts before proceeding.')

    def _calculate_priority(self, state: LearningState, confidence: float) -> float:
        """Weighted priority score mixing confidence with confusion, mastery gap, difficulty and time."""
        priority = confidence * 0.4
        priority += state.confusion_signals * 0.2
        priority += (1 - state.mastery_level) * 0.2
        priority += state.difficulty_rating * 0.1
        # time contribution saturates after 10 minutes
        priority += min(state.time_spent_seconds / 600, 1) * 0.1
        return priority

    def _estimate_time(self, concept: str) -> int:
        """Estimate time to resolve a doubt about ``concept``, in minutes (default 15)."""
        time_map = {
            'variables': 5,
            'functions': 10,
            'classes': 15,
            'loops': 8,
            'data_structures': 20,
            'linear_regression': 25,
            'neural_networks': 30,
            'transformers': 45,
            'backpropagation': 35
        }
        return time_map.get(concept, 15)

    def update_policy(
        self,
        state: LearningState,
        predicted_doubt: str,
        actual_doubt: str,
        reward: float,
        next_state: Optional[LearningState] = None
    ):
        """
        Update the Q-table from feedback on a single prediction.

        The reward is supplied by the caller; the intended convention is a
        positive reward when the predicted doubt matches the actual doubt and
        a negative reward for false positives.

        Args:
            state: State in which the prediction was made.
            predicted_doubt: Predicted doubt text (hashed to an action index).
            actual_doubt: Doubt the user actually had (recorded in the
                experience buffer only).
            reward: Scalar reward for this prediction.
            next_state: Optional successor state. When given, the TD target
                bootstraps from its best Q-value (0 if never seen); when
                omitted, the feedback is treated as a terminal transition.
        """
        q_table = self.policy.q_table
        state_key = self._state_to_key(state)
        if state_key not in q_table:
            q_table[state_key] = np.zeros(self.policy.action_dim)
        action_idx = self._doubt_to_action(predicted_doubt)
        current_q = q_table[state_key][action_idx]
        # Bootstrapping from the *same* state's max Q would let a well-rewarded
        # action inflate the target of a penalised one, so without a real
        # successor the update uses the reward alone.
        if next_state is None:
            max_next_q = 0.0
        else:
            next_q = q_table.get(self._state_to_key(next_state))
            max_next_q = float(np.max(next_q)) if next_q is not None else 0.0
        q_table[state_key][action_idx] = current_q + self.policy.learning_rate * (
            reward + self.policy.gamma * max_next_q - current_q
        )
        self.experience_buffer.append({
            'state': state,
            'predicted': predicted_doubt,
            'actual': actual_doubt,
            'reward': reward,
            'timestamp': datetime.now().isoformat()
        })
        # Drop the oldest entries so the buffer never exceeds its capacity.
        overflow = len(self.experience_buffer) - self.max_buffer_size
        if overflow > 0:
            del self.experience_buffer[:overflow]
        if self.policy.epsilon > self.policy.epsilon_min:
            self.policy.epsilon = max(
                self.policy.epsilon_min,
                self.policy.epsilon * self.policy.epsilon_decay
            )

    def _state_to_key(self, state: LearningState) -> str:
        """Discretize a state into a hashable Q-table key (topic, subtopic, binned progress/confusion)."""
        return f"{state.topic}_{state.subtopic}_{int(state.progress_percentage * 10)}_{int(state.confusion_signals * 10)}"

    def _doubt_to_action(self, doubt: str) -> int:
        """Map doubt text to an action index in ``[0, policy.action_dim)``."""
        # CRC32 is stable across interpreter runs, unlike hash() of a str
        # (salted per process), so action indices stay valid after export/import.
        digest = zlib.crc32(doubt.lower().strip().encode('utf-8'))
        return digest % self.policy.action_dim

    def get_learning_recommendations(self, learning_context: Dict) -> Dict[str, Any]:
        """Build personalized learning recommendations from the top-3 predicted doubts."""
        predictions = self.predict_doubts(learning_context, top_k=3)
        state = self.get_current_state(learning_context)
        recommendations = {
            'next_topics': [],
            'review_topics': [],
            'practice_exercises': [],
            'estimated_difficulty': state.difficulty_rating,
            'predicted_struggles': [p.predicted_doubt for p in predictions],
            'confidence_boosters': [],
            'optimal_break_time': self._suggest_break_time(learning_context)
        }
        if state.confusion_signals > 0.7:
            recommendations['next_topics'] = self._get_prerequisites(state.topic)
            recommendations['confidence_boosters'].append('Review prerequisite concepts')
        if state.mastery_level > 0.8:
            recommendations['next_topics'].append(state.topic)
            recommendations['practice_exercises'].append(f"Advanced {state.topic} project")
        if state.time_spent_seconds > 1800:
            recommendations['suggest_break'] = True
            recommendations['break_duration'] = 5
        return recommendations

    def _suggest_break_time(self, context: Dict) -> Optional[str]:
        """Suggest a break based on confusion score and time spent; None if no break is needed."""
        if context.get('confusion_score', 0) > 0.6:
            return "Take a 5-minute break to process information"
        elif context.get('time_spent', 0) > 2400:
            return "Take a longer 15-minute break"
        return None

    def export_model(self) -> Dict:
        """
        Export model state for persistence.

        Includes summary counts plus a JSON-friendly copy of the Q-table
        (``'q_table'``: state key -> list of floats).
        """
        return {
            'user_id': self.user_id,
            'q_table_size': len(self.policy.q_table),
            'experience_buffer_size': len(self.experience_buffer),
            'epsilon': self.policy.epsilon,
            'concepts': list(self.concept_graph.keys()),
            'doubt_templates': list(self.doubt_templates.keys()),
            'q_table': {key: values.tolist() for key, values in self.policy.q_table.items()}
        }

    def import_model(self, model_data: Dict):
        """
        Import model state produced by ``export_model``.

        Adds unknown concepts (with no subconcepts), restores ``epsilon`` and
        any Q-table rows whose length matches ``policy.action_dim``. Absent
        keys leave the corresponding state untouched.
        """
        for concept in model_data.get('concepts') or []:
            self.concept_graph.setdefault(concept, [])
        if model_data.get('epsilon') is not None:
            self.policy.epsilon = float(model_data['epsilon'])
        for key, values in (model_data.get('q_table') or {}).items():
            row = np.asarray(values, dtype=np.float64)
            # Skip rows whose width does not match the current action space.
            if row.shape == (self.policy.action_dim,):
                self.policy.q_table[key] = row