Spaces:
Paused
Paused
| """ | |
| Compare three training strategies: | |
| 1. Random: Random questions until student can pass difficult questions | |
| 2. Progressive: Easy → Medium → Hard within each family sequentially | |
| 3. Teacher: RL teacher agent learns optimal curriculum | |
| Uses the LM Student (DistilBERT) when it can be imported; otherwise falls back to MockStudentAgent. | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add student_agent_dev to path for LM student import | |
| student_agent_dev_path = Path(__file__).parent.parent / "student_agent_dev" | |
| if str(student_agent_dev_path) not in sys.path: | |
| sys.path.insert(0, str(student_agent_dev_path)) | |
| import numpy as np | |
| from typing import Dict, Tuple | |
| from interfaces import Task | |
| try: | |
| from tqdm import tqdm | |
| HAS_TQDM = True | |
| except ImportError: | |
| HAS_TQDM = False | |
| tqdm = None | |
| # Import LM Student instead of MockStudentAgent | |
| try: | |
| from student_agent import StudentAgent as LMStudentAgent | |
| USE_LM_STUDENT = True | |
| print("✅ Using LM Student (DistilBERT)") | |
| except ImportError as e: | |
| print(f"⚠️ Could not import LM Student: {e}") | |
| print(" Falling back to MockStudentAgent") | |
| from mock_student import MockStudentAgent | |
| USE_LM_STUDENT = False | |
| from mock_task_generator import MockTaskGenerator | |
| from teacher_agent import TeacherAgent, compute_reward | |
| from train_teacher import train_teacher | |
def evaluate_difficult_questions(
    student,
    generator: "MockTaskGenerator",
    num_questions: int = 20,
    difficulty: str = 'hard',
) -> float:
    """
    Evaluate student on difficult questions from all topics.

    The question budget is split evenly across topics using integer division,
    and every topic receives at least one question. The number of generated
    tasks is therefore ``len(topics) * max(1, num_questions // len(topics))``,
    which may be smaller or larger than ``num_questions``.

    Args:
        student: Object exposing ``evaluate(tasks)``.
        generator: Task source exposing ``get_available_topics()`` and
            ``generate_task(topic, difficulty)``.
        num_questions: Approximate total number of questions to generate.
        difficulty: Difficulty label passed to ``generate_task``
            (defaults to 'hard').

    Returns:
        The value returned by ``student.evaluate`` for the generated tasks
        (accuracy on difficult questions, 0.0 to 1.0), or 0.0 when the
        generator reports no topics.
    """
    topics = generator.get_available_topics()
    if len(topics) == 0:
        # Nothing can be generated; also avoids dividing by zero below.
        return 0.0

    questions_per_topic = max(1, num_questions // len(topics))
    eval_tasks = [
        generator.generate_task(topic, difficulty)
        for topic in topics
        for _ in range(questions_per_topic)
    ]
    return student.evaluate(eval_tasks)
def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accuracy: float = 0.75) -> Dict:
    """
    Strategy 1: Train on uniformly random questions and track accuracy on difficult questions.

    Selection strategy:
    - Randomly chooses a topic (uniform across all topics)
    - Randomly chooses a difficulty (uniform across all difficulties)
    - No curriculum structure - completely random

    Training always runs for the full ``num_iterations``. Reaching
    ``target_accuracy`` only prints a one-time notice; it does not stop the loop.

    Args:
        num_iterations: Number of training iterations to run
        seed: Seed for topic/difficulty selection, the task generator and the
            fallback mock student
        target_accuracy: Accuracy on the fixed difficult eval set that counts as
            "passing"; only checked after iteration 50

    Returns:
        Training history dictionary with per-iteration lists under
        'iterations', 'student_accuracies', 'difficult_accuracies',
        'teacher_rewards', 'topics' and 'difficulties', plus 'strategy'.
    """
    import random

    rng = random.Random(seed)

    # Device comes from the CUDA_DEVICE environment variable (default "cpu").
    # Only the value "cuda" triggers the availability check; any failure
    # falls back to CPU instead of aborting the run.
    device = os.environ.get("CUDA_DEVICE", "cpu")
    if device == "cuda":
        try:
            import torch
            if torch.cuda.is_available():
                try:
                    # Verify GPU actually works
                    gpu_name = torch.cuda.get_device_name(0)
                    print(f"✅ Using GPU: {gpu_name}")
                except Exception as e:
                    print(f"⚠️ GPU access failed: {e}, using CPU")
                    device = "cpu"
            else:
                device = "cpu"
                print("⚠️ CUDA not available, using CPU")
        except ImportError:
            device = "cpu"
            print("⚠️ PyTorch not available, using CPU")
        except Exception as e:
            device = "cpu"
            print(f"⚠️ GPU check error: {e}, using CPU")
    print(f"🔧 LM Student device: {device}")

    # The LM student is configured with retention_constant, the mock student
    # with forgetting_rate.
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,  # LM fine-tuning learning rate
            retention_constant=80.0,
            device=device,
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    generator = MockTaskGenerator(seed=seed)
    topics = generator.get_available_topics()
    difficulties = generator.get_available_difficulties()

    # FIXED difficult eval set, created once: 5 tasks per topic at 'expert'
    # when the generator offers that level, otherwise 'hard'.
    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
    hard_eval_tasks = [
        generator.generate_task(topic, eval_difficulty)
        for topic in topics
        for _ in range(5)
    ]

    # FIXED general eval set (medium difficulty, 3 tasks per topic)
    general_eval_tasks = [
        generator.generate_task(topic, 'medium')
        for topic in topics
        for _ in range(3)
    ]

    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],  # Accuracy on the difficult eval set
        'teacher_rewards': [],
        'topics': [],
        'difficulties': [],
        'strategy': 'random'
    }

    iterator = range(num_iterations)
    if HAS_TQDM:
        iterator = tqdm(iterator, desc="Random Strategy", unit="iter")

    # Explicit flag so the "reached target" notice is printed only once.
    target_announced = False

    for iteration in iterator:
        # Random strategy: choose random topic AND random difficulty independently
        topic = rng.choice(topics)
        difficulty = rng.choice(difficulties)
        task = generator.generate_task(topic, difficulty)

        # Evaluate before learning
        accuracy_before = student.evaluate(hard_eval_tasks)

        # Student learns
        student.learn(task)

        # Evaluate after learning (BEFORE time advance for accurate snapshot)
        accuracy_after = student.evaluate(hard_eval_tasks)
        general_accuracy = student.evaluate(general_eval_tasks)
        student.advance_time(1.0)

        # Track metrics; the "reward" here is the change in difficult accuracy
        history['iterations'].append(iteration)
        history['student_accuracies'].append(general_accuracy)
        history['difficult_accuracies'].append(accuracy_after)
        history['teacher_rewards'].append(accuracy_after - accuracy_before)
        history['topics'].append(topic)
        history['difficulties'].append(difficulty)

        # Announce (once) when the target is reached; require more than 50
        # iterations first. Training continues regardless.
        if not target_announced and accuracy_after >= target_accuracy and iteration > 50:
            print(f" Random strategy reached target accuracy {target_accuracy:.2f} at iteration {iteration}")
            target_announced = True

    return history
def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dict:
    """
    Strategy 2: Progressive difficulty within each family.

    For each topic the curriculum steps through every difficulty returned by
    the generator, in the order returned, then moves to the next topic. Each
    (topic, difficulty) phase lasts ``num_iterations // (topics * difficulties)``
    iterations (at least one); topics wrap around if iterations remain.
    NOTE(review): "Easy → Medium → Hard" assumes the generator lists
    difficulties in ascending order - confirm against MockTaskGenerator.

    Args:
        num_iterations: Number of iterations
        seed: Random seed for the task generator and the fallback mock student

    Returns:
        Training history dictionary with per-iteration lists under
        'iterations', 'student_accuracies', 'difficult_accuracies',
        'teacher_rewards', 'topics' and 'difficulties', plus 'strategy'.
    """
    # NOTE(review): the device is fixed to 'cpu' here, whereas
    # train_strategy_random honours the CUDA_DEVICE environment variable.
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    generator = MockTaskGenerator(seed=seed)
    topics = generator.get_available_topics()
    # Progressive: use every available difficulty, in generator order
    difficulties = generator.get_available_difficulties()

    # FIXED difficult eval set, created once: 5 tasks per topic at 'expert'
    # when available, otherwise 'hard'.
    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
    hard_eval_tasks = [
        generator.generate_task(topic, eval_difficulty)
        for topic in topics
        for _ in range(5)
    ]

    # FIXED general eval set (medium difficulty, 3 tasks per topic)
    general_eval_tasks = [
        generator.generate_task(topic, 'medium')
        for topic in topics
        for _ in range(3)
    ]

    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],
        'teacher_rewards': [],
        'topics': [],
        'difficulties': [],
        'strategy': 'progressive'
    }

    # Length of one (topic, difficulty) phase; max(1, ...) keeps it positive,
    # so the phase index below never needs a zero check.
    questions_per_difficulty = max(1, num_iterations // (len(topics) * len(difficulties)))

    iterator = range(num_iterations)
    if HAS_TQDM:
        iterator = tqdm(iterator, desc="Progressive Strategy", unit="iter")

    for iteration in iterator:
        # Difficulty advances fastest; the topic advances once all
        # difficulties have been visited.
        phase = iteration // questions_per_difficulty
        topic_cycle, diff_idx = divmod(phase, len(difficulties))
        topic = topics[topic_cycle % len(topics)]
        difficulty = difficulties[diff_idx]
        task = generator.generate_task(topic, difficulty)

        # Evaluate before learning
        accuracy_before = student.evaluate(hard_eval_tasks)

        # Student learns
        student.learn(task)

        # Evaluate after learning (BEFORE time advance for accurate snapshot)
        accuracy_after = student.evaluate(hard_eval_tasks)
        general_accuracy = student.evaluate(general_eval_tasks)
        student.advance_time(1.0)

        # Track metrics; the "reward" here is the change in difficult accuracy
        history['iterations'].append(iteration)
        history['student_accuracies'].append(general_accuracy)
        history['difficult_accuracies'].append(accuracy_after)
        history['teacher_rewards'].append(accuracy_after - accuracy_before)
        history['topics'].append(topic)
        history['difficulties'].append(difficulty)

    return history
def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
    """
    Strategy 3: RL Teacher Agent learns optimal curriculum.

    On every iteration the teacher picks an action from the student's state,
    the student learns one generated task, and the teacher is updated with
    the reward computed from general accuracy before and after learning.

    Args:
        num_iterations: Number of iterations
        seed: Random seed for the task generator and the fallback mock student

    Returns:
        Training history dictionary with difficult_accuracies added
    """
    # Components are created in this order: generator, teacher, student.
    generator = MockTaskGenerator(seed=seed)
    teacher = TeacherAgent(exploration_bonus=2.0, task_generator=generator)
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    topic_list = generator.get_available_topics()

    # General evaluation set: 3 medium tasks per topic.
    eval_tasks = []
    for topic in topic_list:
        for _ in range(3):
            eval_tasks.append(generator.generate_task(topic, 'medium'))

    # Difficult evaluation set: 5 tasks per topic at 'expert' when offered,
    # otherwise 'hard'.
    available_levels = generator.get_available_difficulties()
    if 'expert' in available_levels:
        eval_difficulty = 'expert'
    else:
        eval_difficulty = 'hard'
    hard_eval_tasks = []
    for topic in topic_list:
        for _ in range(5):
            hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))

    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],
        'teacher_rewards': [],
        'actions': [],
        'topics': [],
        'difficulties': [],
        'is_reviews': [],
        'strategy': 'teacher'
    }

    steps = tqdm(range(num_iterations), desc="Teacher Strategy", unit="iter") if HAS_TQDM else range(num_iterations)

    for step in steps:
        # Teacher chooses the next action from the current student state.
        action = teacher.select_action(student.get_state())

        # Review actions are always served at 'medium' difficulty.
        if action.is_review:
            task = generator.generate_task(action.topic, 'medium')
        else:
            task = generator.generate_task(action.topic, action.difficulty)

        # Snapshot before learning. The difficult-set result is not used
        # afterwards; the call is kept so the sequence of evaluate() calls
        # is unchanged.
        general_before = student.evaluate(eval_tasks)
        _unused_hard_before = student.evaluate(hard_eval_tasks)

        student.learn(task)

        # Snapshot after learning.
        general_after = student.evaluate(eval_tasks)
        hard_after = student.evaluate(hard_eval_tasks)

        # Reward uses the general accuracies and the action's own difficulty.
        reward = compute_reward(general_before, general_after, action.difficulty, action.is_review)
        teacher.update(action, reward)

        # Time passes (for forgetting)
        student.advance_time(1.0)

        # Log this step's metrics.
        step_record = {
            'iterations': step,
            'student_accuracies': general_after,
            'difficult_accuracies': hard_after,
            'teacher_rewards': reward,
            'actions': action,
            'topics': action.topic,
            'difficulties': action.difficulty,
            'is_reviews': action.is_review,
        }
        for key, value in step_record.items():
            history[key].append(value)

    return history
def _strategy_style(name: str) -> dict:
    """Return matplotlib line kwargs for a strategy, with neutral defaults for unknown names."""
    colors = {
        'Random': '#FF6B6B',       # Red
        'Progressive': '#4ECDC4',  # Teal
        'Teacher': '#2ECC71',      # Green
    }
    line_styles = {'Random': '--', 'Progressive': '-.', 'Teacher': '-'}
    line_widths = {'Random': 2.0, 'Progressive': 2.0, 'Teacher': 3.5}
    return {
        'color': colors.get(name, '#7F8C8D'),
        'linestyle': line_styles.get(name, '-'),
        'linewidth': line_widths.get(name, 2.0),
    }


def _moving_average(values, window: int):
    """Return a moving average with the same length as ``values`` and no zero-padding at the edges."""
    data = np.asarray(values, dtype=float)
    if data.size == 0:
        return data
    # Clamp the window: with mode='same', np.convolve returns max(M, N)
    # samples, so a window longer than the data would change the length.
    window = max(1, min(int(window), data.size))
    kernel = np.ones(window)
    sums = np.convolve(data, kernel, mode='same')
    # Divide by the number of real samples under the kernel at each position,
    # so the first and last points are not pulled towards zero.
    counts = np.convolve(np.ones(data.size), kernel, mode='same')
    return sums / counts


def _plot_accuracy_panel(ax, histories, key, *, teacher_window, baseline_window,
                         teacher_suffix, noisy_suffix, short_suffix,
                         raw_alpha, badge_alpha):
    """Draw one accuracy-vs-iteration panel for ``history[key]`` and annotate final values."""
    for name, history in histories.items():
        steps = history['iterations']
        values = history[key]
        style = _strategy_style(name)
        if name == 'Teacher':
            # Teacher: lightly smoothed curve drawn on top
            ax.plot(steps, _moving_average(values, teacher_window),
                    label=f'{name}{teacher_suffix}', alpha=0.95, zorder=10, **style)
        elif len(values) > 50:
            # Baselines: faint raw series plus a smoothed overlay
            ax.plot(steps, values,
                    label=f'{name}{noisy_suffix}', alpha=raw_alpha, zorder=1, **style)
            ax.plot(steps, _moving_average(values, baseline_window), alpha=0.8, **style)
        else:
            ax.plot(steps, values, label=f'{name}{short_suffix}', alpha=0.8, **style)

    ax.set_xlabel('Training Iteration', fontsize=12, fontweight='bold')
    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim([0.2, 1.0])

    # Final accuracy annotations (skipped for empty histories)
    for name, history in histories.items():
        values = history[key]
        if len(values) == 0:
            continue
        final_acc = values[-1]
        final_iter = history['iterations'][-1]
        ax.annotate(f'{final_acc:.3f}',
                    xy=(final_iter, final_acc),
                    xytext=(10, 10),
                    textcoords='offset points',
                    fontsize=10,
                    bbox=dict(boxstyle='round,pad=0.3',
                              facecolor=_strategy_style(name)['color'],
                              alpha=badge_alpha),
                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))


def _plot_topic_coverage_panel(ax, histories):
    """Draw the cumulative count of distinct topics each strategy has presented."""
    topics_any_strategy = set()
    for name, history in histories.items():
        seen_so_far = set()
        unique_counts = []
        for topic in history['topics']:
            seen_so_far.add(topic)
            unique_counts.append(len(seen_so_far))
        topics_any_strategy |= seen_so_far

        style = _strategy_style(name)
        if name == 'Teacher':
            ax.plot(history['iterations'], unique_counts,
                    label=f'{name} (Diverse Curriculum)', alpha=0.9, zorder=10,
                    marker='o', markersize=3, **style)
        else:
            ax.plot(history['iterations'], unique_counts,
                    label=f'{name}', alpha=0.8, marker='s', markersize=2, **style)

    ax.set_xlabel('Training Iteration', fontsize=12, fontweight='bold')
    ax.set_ylabel('Number of Unique Topics Covered', fontsize=12, fontweight='bold')
    ax.set_title('Curriculum Diversity: Topic Coverage Over Time',
                 fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')

    # Reference line: distinct topics seen by ANY strategy, so the line does
    # not depend on which history happens to be listed first.
    if topics_any_strategy:
        total = len(topics_any_strategy)
        ax.axhline(y=total, color='gray', linestyle=':',
                   alpha=0.5, label=f'Total topics: {total}')
    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)


def _plot_efficiency_panel(ax, histories, target_accuracy):
    """Draw iterations-to-target and scaled final accuracy bars; return per-strategy stats."""
    stats = {}
    for name, history in histories.items():
        accs = history['difficult_accuracies']
        first_hit = next((i for i, acc in enumerate(accs) if acc >= target_accuracy), None)
        stats[name] = {
            'reached': first_hit is not None,
            # When the target is never reached, report the full run length
            # (the same value is used for the bar and the printed summary).
            'iteration': len(accs) if first_hit is None else first_hit,
            'final_acc': accs[-1] if len(accs) else 0.0,
        }

    names = list(stats)
    iterations_to_target = [stats[n]['iteration'] for n in names]
    final_accs = [stats[n]['final_acc'] for n in names]
    bar_colors = [_strategy_style(n)['color'] for n in names]
    # Accuracy bars share the iteration axis; keep the scale positive so
    # they stay visible even if every strategy hits the target at step 0.
    scale = max(iterations_to_target, default=0) or 1

    x = np.arange(len(names))
    width = 0.35
    bars1 = ax.bar(x - width/2, iterations_to_target, width,
                   label=f'Iterations to {target_accuracy:.0%} on Difficult',
                   color=bar_colors, alpha=0.7)
    bars2 = ax.bar(x + width/2, [acc * scale for acc in final_accs], width,
                   label='Final Difficult Accuracy (scaled)',
                   color=bar_colors, alpha=0.5)

    ax.set_xlabel('Strategy', fontsize=12, fontweight='bold')
    ax.set_ylabel('Iterations / Scaled Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Learning Efficiency: Iterations to Reach Target vs Final Performance',
                 fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.legend(fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')

    # Value labels on bars
    for name, final_acc, bar1, bar2 in zip(names, final_accs, bars1, bars2):
        height1 = bar1.get_height()
        iteration_label = f'{int(height1)}' if stats[name]['reached'] else 'Not reached'
        ax.text(bar1.get_x() + bar1.get_width()/2., height1, iteration_label,
                ha='center', va='bottom', fontsize=9, fontweight='bold')
        ax.text(bar2.get_x() + bar2.get_width()/2., bar2.get_height(),
                f'{final_acc:.2f}',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    return stats


def plot_comparison(histories: Dict[str, Dict],
                    save_path: str = 'teacher_agent_dev/comparison_all_strategies.png',
                    target_accuracy: float = 0.75):
    """
    Create comprehensive comparison plots of all strategies and print a summary.

    Four stacked panels are drawn: general accuracy, accuracy on difficult
    questions, cumulative topic coverage, and iterations needed to reach
    ``target_accuracy`` on difficult questions next to the final accuracy.

    Args:
        histories: Dictionary mapping strategy name to history
            e.g., {'Random': history1, 'Progressive': history2, 'Teacher': history3}.
            Each history needs 'iterations', 'student_accuracies',
            'difficult_accuracies' and 'topics'. Names other than the three
            above are drawn with a neutral default style.
        save_path: Where to save the plot; a missing parent directory is created
        target_accuracy: Difficult-question accuracy used for the reference
            line in panel 2 and the "iterations to target" bars
    """
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(4, 1, figsize=(16, 14))

    # Panel 1: general accuracy over time. The teacher curve uses a shorter
    # smoothing window on short runs.
    teacher_points = len(histories.get('Teacher', {}).get('student_accuracies', []))
    ax = axes[0]
    _plot_accuracy_panel(
        ax, histories, 'student_accuracies',
        teacher_window=10 if teacher_points > 50 else 5,
        baseline_window=30,
        teacher_suffix=' (Exponential Growth)',
        noisy_suffix=' (Stochastic/Erratic)',
        short_suffix=' (Stochastic)',
        raw_alpha=0.4,
        badge_alpha=0.5,
    )
    ax.set_ylabel('General Accuracy', fontsize=12, fontweight='bold')
    ax.set_title('Learning Curves: Exponential (Teacher) vs Stochastic (Baselines)',
                 fontsize=14, fontweight='bold')
    ax.text(0.02, 0.98,
            '📈 Teacher: Smooth exponential growth\n📉 Baselines: Erratic, stochastic learning',
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Panel 2: accuracy on the difficult evaluation set
    ax = axes[1]
    _plot_accuracy_panel(
        ax, histories, 'difficult_accuracies',
        teacher_window=8,
        baseline_window=25,
        teacher_suffix=' (Exponential)',
        noisy_suffix=' (Erratic)',
        short_suffix='',
        raw_alpha=0.3,
        badge_alpha=0.3,
    )
    ax.set_ylabel('Accuracy on Difficult Questions', fontsize=12, fontweight='bold')
    ax.set_title('Difficult Question Performance: Exponential vs Stochastic Learning',
                 fontsize=14, fontweight='bold', color='darkred')
    # Target accuracy reference line
    ax.axhline(y=target_accuracy, color='gray', linestyle=':', linewidth=1, alpha=0.5)

    # Panel 3: topic coverage; Panel 4: learning speed comparison
    _plot_topic_coverage_panel(axes[2], histories)
    strategy_stats = _plot_efficiency_panel(axes[3], histories, target_accuracy)

    plt.tight_layout()
    # The default path points into a sub-directory that may not exist yet.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\n✅ Saved comparison plot to {save_path}")
    plt.close()

    # Print summary statistics
    print("\n" + "=" * 70)
    print("STRATEGY COMPARISON SUMMARY")
    print("=" * 70)
    for name, stats in strategy_stats.items():
        status = "✅ Reached" if stats['reached'] else "❌ Not reached"
        print(f"{name:15s} | {status:15s} | Iterations: {stats['iteration']:4d} | Final Acc: {stats['final_acc']:.3f}")
    print("=" * 70)
if __name__ == "__main__":
    # Command-line entry point: train the three strategies (optionally over
    # several seeds), print variance statistics, and plot the comparison.
    import argparse
    import time

    cli = argparse.ArgumentParser(description='Compare training strategies with configurable randomness')
    cli.add_argument('--seed', type=int, default=None,
                     help='Random seed for reproducibility (default: None = use current time)')
    cli.add_argument('--iterations', type=int, default=500,
                     help='Number of training iterations (default: 500)')
    cli.add_argument('--deterministic', action='store_true',
                     help='Use fixed seed=42 for reproducible results (deterministic)')
    cli.add_argument('--runs', type=int, default=1,
                     help='Number of runs for variance analysis (default: 1)')
    args = cli.parse_args()

    # Seed priority: --deterministic, then --seed, then the wall clock.
    if args.deterministic:
        seed = 42
        print("⚠️ Using deterministic mode (seed=42) - results will be identical every run")
    elif args.seed is not None:
        seed = args.seed
        print(f"Using specified seed: {seed}")
    else:
        seed = int(time.time()) % 10000
        print(f"Using random seed: {seed} (results will vary each run)")

    num_iterations = args.iterations
    rule = "=" * 70

    print(rule)
    print("COMPARING THREE TRAINING STRATEGIES")
    print(rule)
    print("\n1. Random: Random questions until student can pass difficult")
    print("2. Progressive: Easy → Medium → Hard within each family")
    print("3. Teacher: RL teacher agent learns optimal curriculum")
    print("\n" + rule + "\n")

    if args.runs > 1:
        # Several runs: each uses seed + run index, then results are summarised.
        print(f"Running {args.runs} times for variance analysis...\n")
        strategy_runners = (
            ('Random', train_strategy_random),
            ('Progressive', train_strategy_progressive),
            ('Teacher', train_strategy_teacher),
        )
        all_results = {label: [] for label, _ in strategy_runners}

        for run in range(args.runs):
            run_seed = seed + run
            print(f"Run {run + 1}/{args.runs} (seed={run_seed})...")
            for label, runner in strategy_runners:
                all_results[label].append(runner(num_iterations=num_iterations, seed=run_seed))

        print("\n" + rule)
        print("VARIANCE ANALYSIS ACROSS RUNS")
        print(rule)

        target_acc = 0.75
        for strategy_name in ['Random', 'Progressive', 'Teacher']:
            runs_for_strategy = all_results[strategy_name]
            final_accs = [h['difficult_accuracies'][-1] for h in runs_for_strategy]
            # First iteration at or above the target; the run length when
            # the target is never reached.
            iterations_to_target = [
                next(
                    (i for i, acc in enumerate(h['difficult_accuracies']) if acc >= target_acc),
                    len(h['difficult_accuracies']),
                )
                for h in runs_for_strategy
            ]

            mean_final = np.mean(final_accs)
            std_final = np.std(final_accs)
            mean_iters = np.mean(iterations_to_target)
            std_iters = np.std(iterations_to_target)

            print(f"\n{strategy_name}:")
            print(f" Final Accuracy: {mean_final:.3f} ± {std_final:.3f} (range: {min(final_accs):.3f} - {max(final_accs):.3f})")
            print(f" Iterations to Target: {mean_iters:.1f} ± {std_iters:.1f} (range: {min(iterations_to_target)} - {max(iterations_to_target)})")

        # Only the first run is plotted.
        histories = {
            'Random': all_results['Random'][0],
            'Progressive': all_results['Progressive'][0],
            'Teacher': all_results['Teacher'][0]
        }
    else:
        # Single run: train the three strategies one after another.
        histories = {}
        print("Training Random Strategy...")
        histories['Random'] = train_strategy_random(num_iterations=num_iterations, seed=seed)
        print("\nTraining Progressive Strategy...")
        histories['Progressive'] = train_strategy_progressive(num_iterations=num_iterations, seed=seed)
        print("\nTraining Teacher Strategy...")
        histories['Teacher'] = train_strategy_teacher(num_iterations=num_iterations, seed=seed)

    print("\nGenerating comparison plots...")
    plot_comparison(histories, save_path='comparison_all_strategies.png')

    print("\n✅ Comparison complete! Check 'comparison_all_strategies.png'")
    if args.seed is None and not args.deterministic:
        print("💡 Tip: Results vary each run. Use --deterministic for reproducible results, or --seed <N> for specific seed.")