# Source: hallucination-guard-env/server/environment.py
# (Hugging Face Space upload "Upload environment.py" by SamSankar, commit f0ad00b)
"""Professional-grade HallucinationGuard RL Environment.
This module implements a sophisticated, production-ready RL environment with:
- Curriculum learning with adaptive difficulty
- Multi-turn conversation support
- Context retrieval challenges
- Comprehensive episode management
- Model-agnostic design (works with any LLM)
- Real-time metrics and logging
- Session management for concurrent users
"""
import uuid
import time
import logging
from typing import Optional, Dict, Any, List, Tuple
from dataclasses import dataclass, field
from enum import Enum
# Add directories to path for imports to work in both local and HF Spaces
import sys
import os
# Make this directory and its parent importable so the flat imports that
# follow resolve whether the module runs locally or inside an HF Space.
_dir = os.path.dirname(os.path.abspath(__file__))
_parent = os.path.dirname(_dir)
for _candidate in (_parent, _dir):
    # Parent first, then this directory: each insert goes to the front, so
    # this directory ends up ahead of its parent when both are added.
    if _candidate not in sys.path:
        sys.path.insert(0, _candidate)
from openenv.core.env_server import Environment
from models import (
HallucinationAction,
HallucinationObservation,
HallucinationState,
EpisodeStatistics,
AgentSkillProfile,
RewardBreakdown,
SemanticAnalysis,
CitationAnalysis,
HallucinationSeverity,
HallucinationType,
DifficultyLevel,
EnvironmentConfig,
MultiTurnDialogue,
)
# Import from same directory for HF Spaces deployment compatibility
from grader import (
calculate_reward,
generate_feedback,
detect_hallucination_advanced,
HallucinationType as GraderHallucinationType,
HallucinationSeverity as GraderHallucinationSeverity,
)
from dataset_loader import DatasetLoader, QAExample, DifficultyLevel as DatasetDifficulty
# Root logging is configured at import time (INFO level, timestamped
# records); everything in this module logs through the module-level logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class EpisodePhase(Enum):
    """Lifecycle phase of an episode.

    The string value of the current phase is what the environment reports
    under the ``"phase"`` key of observation and state metadata.
    """

    # Value assigned by the constructor, before any reset().
    INITIALIZATION = "initialization"
    # Regular question answering.
    ACTIVE = "active"
    # Steps taken in this phase are routed to the clarification handler.
    MULTI_TURN_CLARIFICATION = "multi_turn_clarification"
    # Steps taken in this phase reveal additional context fragments.
    CONTEXT_RETRIEVAL = "context_retrieval"
    # The episode has finished.
    COMPLETION = "completion"
class HallucinationEnvironment(Environment[HallucinationAction, HallucinationObservation, HallucinationState]):
"""
Professional-grade OpenEnv environment for training AI to avoid hallucinations.
Features:
- Curriculum learning with progressive difficulty
- Adaptive difficulty based on performance
- Multi-turn conversation support
- Context retrieval challenges
- Comprehensive metrics tracking
- Model-agnostic design
- Session management
"""
SUPPORTS_CONCURRENT_SESSIONS = True
VERSION = "2.0.0"
def __init__(
    self,
    transform=None,
    config: Optional[EnvironmentConfig] = None,
    session_id: Optional[str] = None
):
    """Create an environment instance and load its question pool.

    Args:
        transform: Forwarded unchanged to the base ``Environment`` constructor.
        config: Environment settings; a default ``EnvironmentConfig()`` is
            created when omitted.
        session_id: Identifier used in log lines and in reported state;
            defaults to the first 8 hex characters of a fresh UUID4.
    """
    super().__init__(transform=transform)

    # --- configuration ---------------------------------------------------
    self.config = config if config else EnvironmentConfig()
    self.session_id = session_id if session_id else uuid.uuid4().hex[:8]

    # --- datasets: built-in baseline first, then optional real data ------
    self.dataset_loader = loader = DatasetLoader()
    loader.load_builtin_datasets()
    logger.info(f"Synthetic dataset: {loader.get_total_examples()} examples")
    # Loading real HuggingFace datasets is best-effort: any failure is logged
    # and the environment carries on with the built-in examples only.
    try:
        added = loader.load_real_datasets(max_per_dataset=500, cache=True)
        if added > 0:
            logger.info(f"Added {added} real examples — total: {loader.get_total_examples()}")
        else:
            logger.info("HuggingFace datasets unavailable; using synthetic data only")
    except Exception as exc:
        logger.warning(f"Dataset loading failed ({exc}); continuing with synthetic data only")

    # --- per-episode counters --------------------------------------------
    self.episode_id: Optional[str] = None
    self.episode_phase: EpisodePhase = EpisodePhase.INITIALIZATION
    self.step_count: int = 0
    self.total_hallucinations: int = 0
    self.total_correct: int = 0
    self.total_partial: int = 0

    # --- question currently in play and episode timing -------------------
    self.current_example: Optional[QAExample] = None
    self.episode_examples: List[QAExample] = []
    self.episode_start_time: Optional[float] = None
    self.last_step_time: Optional[float] = None

    # --- per-step histories and streaks -----------------------------------
    self.reward_history: List[float] = []
    self.confidence_history: List[float] = []
    self.hallucination_history: List[bool] = []
    self.calibration_history: List[float] = []
    self.current_streak: int = 0
    self.best_streak: int = 0

    # --- early-stopping bookkeeping ---------------------------------------
    self.consecutive_failures: int = 0
    self.consecutive_hallucinations: int = 0
    self.consecutive_perfect: int = 0
    self.early_stop_reason: Optional[str] = None

    # --- curriculum --------------------------------------------------------
    self.curriculum_stage: int = 0
    self.curriculum_performance: List[float] = []
    self.skill_rating: float = 0.5  # ELO-style rating, kept within [0, 1]

    # --- multi-turn dialogue ------------------------------------------------
    self.dialogue: Optional[MultiTurnDialogue] = None
    self.pending_clarifications: List[str] = []

    # --- long-lived agent profile (not cleared by reset()) ------------------
    self.agent_profile: Optional[AgentSkillProfile] = None

    # --- context-retrieval challenge ---------------------------------------
    self.revealed_context_fragments: List[str] = []
    self.context_retrieval_turns: int = 0

    # Model adapter used for auto-play; installed by reset(model=...).
    self.active_adapter = None

    logger.info(f"Initialized HallucinationEnvironment (session={self.session_id})")
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    difficulty: Optional[str] = None,
    enable_multi_turn: bool = False,
    enable_context_retrieval: bool = False,
    model: Optional[str] = None,
    model_config: Optional[Dict[str, Any]] = None,
    **kwargs
) -> HallucinationObservation:
    """
    Reset the environment for a new episode.

    Args:
        seed: Random seed for reproducibility. When given, the global
            ``random`` module is seeded and the dataset loader's usage
            tracking is reset.
        episode_id: Custom episode ID (default: ``ep_`` + 8 random hex chars).
        difficulty: Starting difficulty level, as a string or a
            ``DifficultyLevel``. Unrecognised strings fall back to
            ``config.initial_difficulty``.
        enable_multi_turn: Enable multi-turn clarification (creates a fresh
            ``MultiTurnDialogue`` for the episode).
        enable_context_retrieval: Accepted for API compatibility.
            NOTE(review): this method does not currently act on the flag.
        model: Model provider to use for auto-play mode.
            Supported: "openai", "anthropic", "huggingface", "ollama", "generic".
            When set, the environment calls the model automatically on each
            step so you only need to call reset() + step() in a loop.
        model_config: Optional dict passed to create_adapter(). Keys:
            model_name, api_key, api_base, temperature, max_tokens, etc.

    Returns:
        Initial observation, or an error observation (``done=True``) when no
        questions could be loaded.
    """
    import random

    if seed is not None:
        random.seed(seed)
        # Reset used indices for reproducibility
        self.dataset_loader.reset_usage()

    # ── Model adapter setup ───────────────────────────────────────────────
    # When model= is supplied, the environment auto-generates answers by
    # calling the adapter inside step(), so callers just loop reset/step.
    if model is not None:
        try:
            adapters_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # Only add the directory once; inserting it on every reset made
            # sys.path grow without bound in long-running sessions.
            if adapters_dir not in sys.path:
                sys.path.insert(0, adapters_dir)
            from model_adapters import create_adapter
            self.active_adapter = create_adapter(model, **(model_config or {}))
            logger.info(f"Active adapter: {model} ({self.active_adapter.__class__.__name__})")
        except Exception as e:
            # Best effort: fall back to manual mode rather than failing reset.
            logger.warning(f"Could not create adapter for '{model}': {e}. Manual action mode.")
            self.active_adapter = None
    else:
        self.active_adapter = None

    # Episode identity and timing
    now = time.time()
    self.episode_id = episode_id or f"ep_{uuid.uuid4().hex[:8]}"
    self.episode_start_time = now
    self.last_step_time = now

    # Reset counters (best_streak is intentionally left as-is, matching the
    # existing behaviour of carrying it across episodes)
    self.step_count = 0
    self.total_hallucinations = 0
    self.total_correct = 0
    self.total_partial = 0
    self.reward_history = []
    self.confidence_history = []
    self.hallucination_history = []
    self.current_streak = 0

    # Reset early stopping counters
    self.consecutive_failures = 0
    self.consecutive_hallucinations = 0
    self.consecutive_perfect = 0
    self.early_stop_reason = None
    self.calibration_history = []

    # Reset multi-turn state
    self.dialogue = MultiTurnDialogue() if enable_multi_turn else None
    self.pending_clarifications = []

    # Reset context retrieval state
    self.revealed_context_fragments = []
    self.context_retrieval_turns = 0

    # Determine starting difficulty
    if difficulty:
        if isinstance(difficulty, DifficultyLevel):
            start_difficulty = difficulty
        else:
            try:
                start_difficulty = DifficultyLevel(str(difficulty).lower())
            except ValueError:
                start_difficulty = self.config.initial_difficulty
    elif self.config.adaptive_difficulty and self.agent_profile:
        # Use agent's skill level
        start_difficulty = self.agent_profile.difficulty_ceiling
    else:
        start_difficulty = self.config.initial_difficulty

    # Load questions for this episode. With the curriculum enabled, an
    # INTERMEDIATE start draws a mix of difficulties instead of a single level.
    mix_difficulties = self.config.curriculum_enabled and start_difficulty == DifficultyLevel.INTERMEDIATE
    self.episode_examples = self.dataset_loader.start_new_episode(
        num_questions=self.config.max_questions_per_episode,
        difficulty=start_difficulty if not mix_difficulties else None,
        mix_difficulties=mix_difficulties
    )

    if not self.episode_examples:
        logger.error("No examples loaded for episode")
        # Do not leave the previous episode's question live after a failed reset.
        self.current_example = None
        self.episode_phase = EpisodePhase.COMPLETION
        return self._create_error_observation("No questions available")

    self.current_example = self.episode_examples[0]
    self.episode_phase = EpisodePhase.ACTIVE
    logger.info(f"Reset episode {self.episode_id} with {len(self.episode_examples)} questions")

    return self._create_observation(
        question=self.current_example.question,
        context=self._get_context_for_observation(self.current_example),
        feedback="Episode started. Answer using only the provided context.",
        metadata={"phase": self.episode_phase.value}
    )
def step(
    self,
    action: Optional[HallucinationAction] = None,
    timeout_s: Optional[float] = None,
    **kwargs
) -> HallucinationObservation:
    """
    Process one agent action and return the next observation.

    Auto-play mode: when reset(model=...) installed an adapter, ``action``
    may be None (or carry an empty answer) and the adapter is asked to
    answer the current question from the current context.

    Manual mode: pass a HallucinationAction with answer, confidence and
    source_quote filled in (the normal RL training loop).

    Dispatch order after the action is resolved:
    - clarification / context-retrieval phases go to their handlers
    - no current question -> the episode is ended
    - empty answer without a clarification request -> error observation
    - clarification request with multi-turn enabled -> clarification handler
    - otherwise the answer is graded

    ``timeout_s`` and extra keyword arguments are accepted but not used here.
    """
    now = time.time()
    step_duration = now - (self.last_step_time or now)
    self.last_step_time = now

    # ── Auto-play: generate action via active adapter ─────────────────────
    has_adapter = self.active_adapter is not None
    needs_generated_action = action is None or (has_adapter and not action.answer)
    if needs_generated_action:
        if has_adapter and self.current_example is not None:
            try:
                generated = self.active_adapter.generate_response(
                    question=self.current_example.question,
                    context=self.current_example.context,
                    require_citation=True,
                    require_confidence=True,
                )
                action = HallucinationAction(
                    answer=generated.answer,
                    confidence=generated.confidence,
                    source_quote=generated.source_quote or "",
                    reasoning=generated.reasoning or "",
                )
                logger.debug(f"Auto-play answer: {generated.answer[:80]}...")
            except Exception as e:
                # Adapter failures degrade to an empty action instead of raising.
                logger.warning(f"Adapter generate_response failed: {e}")
                action = HallucinationAction(answer="", confidence=0.5)
        elif action is None:
            action = HallucinationAction(answer="", confidence=0.5)

    # Phase-specific handlers take precedence over normal grading.
    phase = self.episode_phase
    if phase == EpisodePhase.MULTI_TURN_CLARIFICATION:
        return self._handle_clarification_step(action)
    if phase == EpisodePhase.CONTEXT_RETRIEVAL:
        return self._handle_context_retrieval_step(action)

    # Standard Q&A step
    if self.current_example is None:
        return self._end_episode()

    # An action must carry an answer or ask for clarification.
    if not (action.answer or action.requires_clarification):
        return self._create_error_observation("No answer provided")

    # Clarification requests are only honoured when multi-turn is enabled.
    if action.requires_clarification and self.dialogue:
        return self._handle_clarification_request(action)

    return self._process_answer(action, step_duration)
def state(self) -> HallucinationState:
    """Return a snapshot of the environment's episode and agent metrics.

    Ratios are computed against ``max(1, count)`` so a fresh episode reports
    zeros instead of dividing by zero. ``calibration_error`` is the mean
    absolute gap between each step's confidence and its reward.
    """
    steps = max(1, self.step_count)
    accuracy = self.total_correct / steps
    hallucination_rate = self.total_hallucinations / steps
    avg_confidence = sum(self.confidence_history) / max(1, len(self.confidence_history))
    avg_reward = sum(self.reward_history) / max(1, len(self.reward_history))

    # Average over the (confidence, reward) pairs actually compared, so the
    # divisor always matches the number of summed terms.
    pairs = list(zip(self.confidence_history, self.reward_history))
    calibration_error = (
        sum(abs(c - r) for c, r in pairs) / len(pairs) if pairs else 0.0
    )

    episode_stats = EpisodeStatistics(
        episode_id=self.episode_id or "",
        total_questions=len(self.episode_examples),
        questions_answered=self.step_count,
        correct_answers=self.total_correct,
        hallucinated_answers=self.total_hallucinations,
        partially_correct=self.total_partial,
        average_confidence=avg_confidence,
        average_reward=avg_reward,
        calibration_error=calibration_error,
        reward_history=self.reward_history.copy(),
    )

    # Previously this key simply mirrored config.enable_multi_turn. Prefer a
    # dedicated flag when the config provides one; otherwise keep the old value.
    # NOTE(review): confirm the intended config field name.
    context_retrieval_enabled = getattr(
        self.config, "enable_context_retrieval", self.config.enable_multi_turn
    )

    return HallucinationState(
        episode_id=self.episode_id,
        session_id=self.session_id,
        step_count=self.step_count,
        max_questions=self.config.max_questions_per_episode,
        total_hallucinations=self.total_hallucinations,
        hallucination_rate=hallucination_rate,
        total_correct=self.total_correct,
        total_partial=self.total_partial,
        accuracy=accuracy,
        average_reward=avg_reward,
        average_confidence=avg_confidence,
        calibration_error=calibration_error,
        current_difficulty=self._get_current_difficulty(),
        curriculum_stage=self.curriculum_stage,
        skill_rating=self.skill_rating,
        current_streak=self.current_streak,
        best_streak=self.best_streak,
        episode_stats=episode_stats.model_dump(),
        agent_profile=self.agent_profile.model_dump() if self.agent_profile else None,
        config={
            "multi_turn_enabled": self.dialogue is not None,
            "context_retrieval_enabled": context_retrieval_enabled,
            "adaptive_difficulty": self.config.adaptive_difficulty,
        },
        episode_start_time=self.episode_start_time,
        last_step_time=self.last_step_time,
        metadata={
            "phase": self.episode_phase.value,
            "version": self.VERSION,
        }
    )
def close(self) -> None:
    """Fold the current episode into the agent profile, if one exists, and log the shutdown."""
    profile = self.agent_profile
    if profile:
        # Only an already-attached profile is updated; none is created here.
        self._update_agent_profile()
    logger.info(f"Closed environment (session={self.session_id})")
def _process_answer(
    self,
    action: HallucinationAction,
    step_duration: float
) -> HallucinationObservation:
    """Grade one answer, update episode statistics, and build the next observation.

    Args:
        action: The agent's answer, confidence and supporting quote.
        step_duration: Seconds since the previous step. Currently unused by
            the grading logic; kept for interface compatibility.

    Returns:
        The observation for the next question, or a final observation with
        ``done=True`` (and the ground truth revealed) when the question
        budget is exhausted or an early-stopping condition fires. If there
        is no current question, the episode is ended instead.
    """
    example = self.current_example
    if example is None:
        # Nothing to grade: end the episode, as step() does in the same situation.
        return self._end_episode()

    # Capture everything about the question being answered *before* advancing.
    ground_truth = example.answer
    context = example.context
    answered_question = example.question

    # The example's difficulty may be an enum member or a plain string
    # (the difficulty lookup elsewhere in this class handles both).
    difficulty_str = getattr(example.difficulty, "value", example.difficulty) or "intermediate"
    prev_performance = self.skill_rating

    # Calculate reward using advanced grader
    reward, info = calculate_reward(
        answer=action.answer,
        confidence=action.confidence,
        source_quote=action.source_quote,
        context=context,
        ground_truth=ground_truth,
        difficulty_level=difficulty_str,
        previous_performance=prev_performance,
        reward_weights=self.config.reward_weights
    )

    # Extract metrics from info
    is_hallucination = info.get("is_hallucination", False)
    hallucination_type_str = info.get("hallucination_type", "none")
    hallucination_severity_str = info.get("hallucination_severity", "NONE")
    correctness = info.get("correctness", 0.0)
    grounding_score = info.get("grounding", 0.0)
    calibration_score = info.get("calibration", 0.0)

    # Hallucination type is looked up by value, severity by member name;
    # anything unrecognised maps to NONE.
    try:
        hallucination_type = HallucinationType(hallucination_type_str)
    except ValueError:
        hallucination_type = HallucinationType.NONE
    try:
        severity = HallucinationSeverity[hallucination_severity_str]
    except KeyError:
        severity = HallucinationSeverity.NONE

    # Update statistics
    if is_hallucination:
        self.total_hallucinations += 1
        self.current_streak = 0
        self.consecutive_hallucinations += 1
        self.consecutive_perfect = 0
    elif correctness > 0.7:
        self.total_correct += 1
        self.current_streak += 1
        self.best_streak = max(self.best_streak, self.current_streak)
        self.consecutive_perfect += 1
        self.consecutive_hallucinations = 0
        self.consecutive_failures = 0
    else:
        self.total_partial += 1
        self.current_streak = 0
        self.consecutive_perfect = 0
        self.consecutive_hallucinations = 0
        if reward < self.config.early_stopping_min_reward:
            self.consecutive_failures += 1

    # Track calibration history (gap between stated confidence and correctness)
    calibration_error = abs(action.confidence - correctness)
    self.calibration_history.append(calibration_error)

    # Track history
    self.reward_history.append(reward)
    self.confidence_history.append(action.confidence)
    self.hallucination_history.append(is_hallucination)

    # Update skill rating (ELO-style), clamped to [0, 1]
    expected_score = 1 / (1 + 10 ** ((0.5 - self.skill_rating) * 4))
    actual_score = 1.0 if correctness > 0.7 else (0.5 if correctness > 0.4 else 0.0)
    self.skill_rating += 0.05 * (actual_score - expected_score)
    self.skill_rating = max(0.0, min(1.0, self.skill_rating))

    # Generate feedback
    feedback = generate_feedback(
        answer=action.answer,
        ground_truth=ground_truth,
        is_hallucination=is_hallucination,
        hallucination_type=hallucination_type,
        hallucination_severity=severity,
        grounding_score=grounding_score,
        correctness=correctness,
        calibration_score=calibration_score,
        total_reward=reward
    )

    # Move to next question
    self.step_count += 1

    # Check for early stopping conditions
    early_stop = self._check_early_stopping(is_hallucination, correctness, calibration_error)

    # Determine if episode is done
    done = self.step_count >= self.config.max_questions_per_episode
    if early_stop:
        done = True
        self.early_stop_reason = early_stop
        self.episode_phase = EpisodePhase.COMPLETION
        feedback += f" [Early stop: {early_stop}]"

    if not done:
        self.current_example = self.dataset_loader.get_example_for_step(self.step_count)
    else:
        self.current_example = None
        self.episode_phase = EpisodePhase.COMPLETION

    # Build observation (describes the NEXT question, if any)
    next_example = self.current_example
    observation = self._create_observation(
        question=next_example.question if next_example else "",
        context=self._get_context_for_observation(next_example) if next_example else "",
        ground_truth=ground_truth if done else "",  # Only reveal at end
        feedback=feedback,
        reward=reward,
        is_hallucination=is_hallucination,
        hallucination_type=hallucination_type,
        hallucination_severity=severity,
        grounding_score=grounding_score,
        done=done,
        metadata={
            "step": self.step_count,
            "correctness": correctness,
            "calibration": calibration_score,
            "hallucination_score": info.get("hallucination_score", 0.0),
            "reward_breakdown": self._extract_reward_breakdown(info),
            "semantic_analysis": info.get("semantic_analysis", {}),
            "citation_analysis": info.get("citation_analysis", {}),
        }
    )

    # Update dialogue history if enabled. The entry records the question that
    # was actually answered; observation.question already holds the next one.
    if self.dialogue:
        self.dialogue.turn_number += 1
        self.dialogue.conversation_history.append({
            "question": answered_question,
            "answer": action.answer,
            "feedback": feedback
        })

    return observation
def _handle_clarification_request(
    self,
    action: HallucinationAction
) -> HallucinationObservation:
    """Answer the agent's clarification questions and re-present the current question.

    Each question is recorded as pending, answered with a keyword-based
    clarification, and then removed from the dialogue's unresolved list.
    Returns an error observation when multi-turn dialogue is not enabled.
    """
    dialogue = self.dialogue
    if not dialogue:
        return self._create_error_observation("Multi-turn not enabled")

    example = self.current_example
    questions = action.clarification_questions

    # Register the questions before answering them.
    self.pending_clarifications.extend(questions)
    dialogue.unresolved_queries.extend(questions)

    # Simulated clarifications (simple keyword lookup against the context).
    clarifications = []
    for query in questions:
        clarifications.append(self._generate_clarification(query, example))
        try:
            dialogue.unresolved_queries.remove(query)
        except ValueError:
            # Already resolved (e.g. the same question was asked twice).
            pass

    # Back to normal answering.
    self.episode_phase = EpisodePhase.ACTIVE

    joined = '; '.join(clarifications)
    return self._create_observation(
        question=example.question if example else "",
        context=example.context if example else "",
        feedback=f"Clarifications provided: {joined}",
        metadata={
            "clarifications": clarifications,
            "phase": self.episode_phase.value
        }
    )
def _handle_clarification_step(
    self,
    action: HallucinationAction
) -> HallucinationObservation:
    """Leave the clarification phase and grade ``action`` as a normal answer.

    The answer is graded with a step duration of zero.
    """
    self.episode_phase = EpisodePhase.ACTIVE
    return self._process_answer(action, step_duration=0.0)
def _handle_context_retrieval_step(
    self,
    action: HallucinationAction
) -> HallucinationObservation:
    """Reveal one more fragment of the current question's context.

    Each call exposes one additional fragment. The phase switches back to
    ACTIVE once every fragment is revealed or the per-question turn limit
    (``config.max_turns_per_question``) is reached. ``action`` itself is
    not inspected.
    """
    example = self.current_example
    fragments = self._split_context_into_fragments(example.context if example else "")
    total = len(fragments)

    # Reveal one more fragment than before, capped at what exists.
    revealed_count = min(len(self.revealed_context_fragments) + 1, total)
    self.revealed_context_fragments = fragments[:revealed_count]
    revealed_context = " ".join(self.revealed_context_fragments)
    self.context_retrieval_turns += 1

    out_of_turns = self.context_retrieval_turns >= self.config.max_turns_per_question
    if out_of_turns or revealed_count >= total:
        self.episode_phase = EpisodePhase.ACTIVE
        if example:
            # Remember what the agent has seen by the time retrieval ended.
            example.metadata["revealed_context"] = revealed_context
    # Otherwise the environment stays in the retrieval phase.

    return self._create_observation(
        question=example.question if example else "",
        context=revealed_context,
        feedback=f"Context revealed: {revealed_count}/{total} fragments",
        metadata={
            "fragments_revealed": revealed_count,
            "total_fragments": total,
            "phase": self.episode_phase.value
        }
    )
def _create_observation(
    self,
    question: str = "",
    context: str = "",
    ground_truth: str = "",
    feedback: str = "",
    reward: Optional[float] = None,
    done: bool = False,
    is_hallucination: bool = False,
    hallucination_type: HallucinationType = HallucinationType.NONE,
    hallucination_severity: HallucinationSeverity = HallucinationSeverity.NONE,
    grounding_score: float = 0.0,
    metadata: Optional[Dict[str, Any]] = None
) -> HallucinationObservation:
    """Assemble an observation from the given fields plus current episode state.

    Progress fields (accuracy, streaks, remaining attempts, difficulty,
    skill rating) are read from the environment. The reward breakdown and
    the semantic/citation analyses are taken from ``metadata`` when present.
    """
    meta = metadata or {}
    example = self.current_example
    max_questions = self.config.max_questions_per_episode

    accuracy_so_far = self.total_correct / self.step_count if self.step_count > 0 else 0.0

    # Report the difficulty as a plain string whether or not it is an enum.
    level = self._get_current_difficulty()
    difficulty_level = level.value if hasattr(level, 'value') else str(level)

    return HallucinationObservation(
        question=question,
        context=context,
        ground_truth=ground_truth,
        question_id=example.id if example else "",
        source_dataset=example.source if example else "",
        done=done,
        reward=reward,
        feedback=feedback,
        is_hallucination=is_hallucination,
        hallucination_type=hallucination_type,
        hallucination_severity=hallucination_severity,
        grounding_score=grounding_score,
        accuracy_so_far=accuracy_so_far,
        attempts_remaining=max(0, max_questions - self.step_count),
        current_streak=self.current_streak,
        best_streak=self.best_streak,
        difficulty_level=difficulty_level,
        curriculum_progress=self.step_count / max(1, max_questions),
        skill_rating=self.skill_rating,
        dialogue=self.dialogue,
        reward_breakdown=meta.get("reward_breakdown"),
        semantic_analysis=meta.get("semantic_analysis"),
        citation_analysis=meta.get("citation_analysis"),
        metadata=meta
    )
def _create_error_observation(self, error_message: str) -> HallucinationObservation:
    """Build a terminal observation (``done=True``, zero reward) describing an error.

    The message appears in ``feedback`` prefixed with ``"Error: "`` and
    verbatim under ``metadata["error"]``.
    """
    feedback = f"Error: {error_message}"
    return HallucinationObservation(
        question="",
        context="",
        feedback=feedback,
        reward=0.0,
        done=True,
        is_hallucination=False,
        grounding_score=0.0,
        accuracy_so_far=0.0,
        attempts_remaining=0,
        reward_breakdown=None,
        semantic_analysis=None,
        citation_analysis=None,
        metadata={"error": error_message}
    )
def _end_episode(self) -> HallucinationObservation:
    """Mark the episode complete, update the curriculum, and return the final summary.

    The final observation's reward is the mean of the episode's step rewards.
    """
    self.episode_phase = EpisodePhase.COMPLETION
    self._update_curriculum()

    mean_reward = sum(self.reward_history) / max(1, len(self.reward_history))
    final_accuracy = self.total_correct / max(1, self.step_count)

    return HallucinationObservation(
        done=True,
        reward=mean_reward,
        question="",
        context="",
        feedback=self._generate_episode_summary(),
        is_hallucination=False,
        grounding_score=0.0,
        accuracy_so_far=final_accuracy,
        attempts_remaining=0,
        metadata={
            "episode_complete": True,
            "final_reward": mean_reward,
            "total_hallucinations": self.total_hallucinations,
            "total_correct": self.total_correct,
        }
    )
def _check_early_stopping(self, is_hallucination: bool, correctness: float, calibration_error: float) -> Optional[str]:
"""
Check if episode should stop early based on performance conditions.
Returns:
str describing early stop reason, or None if should continue.
"""
if not self.config.early_stopping_enabled:
return None
# Require minimum steps before early stopping
if self.step_count < 3:
return None
# 1. Hallucination cascade: too many consecutive hallucinations
if self.consecutive_hallucinations >= self.config.early_stopping_hallucination_cascade:
return f"hallucination_cascade ({self.consecutive_hallucinations} consecutive)"
# 2. Consecutive failures: poor performance
if self.consecutive_failures >= self.config.early_stopping_patience:
return f"consecutive_failures ({self.consecutive_failures} below {self.config.early_stopping_min_reward})"
# 3. Calibration failure: confidence systematically misaligned
if len(self.calibration_history) >= 5:
avg_calibration_error = sum(self.calibration_history[-5:]) / 5
if avg_calibration_error > self.config.early_stopping_calibration_failure:
return f"calibration_failure (avg error: {avg_calibration_error:.2f})"
# 4. Perfect run: early completion after consistent high performance
if self.consecutive_perfect >= self.config.early_stopping_perfect_run:
if self.step_count >= self.config.min_questions_for_completion:
return f"perfect_run ({self.consecutive_perfect} consecutive correct)"
return None
def _get_context_for_observation(self, example: Optional[QAExample]) -> str:
"""Get context, potentially with partial revelation for challenges."""
if not example:
return ""
# Check if context retrieval is enabled
if self.config.enable_multi_turn and self.revealed_context_fragments:
return " ".join(self.revealed_context_fragments)
return example.context
def _get_current_difficulty(self) -> DifficultyLevel:
"""
Determine current difficulty based on performance with hysteresis.
Uses smooth difficulty scaling with:
- Stage-specific thresholds
- Minimum steps at each level (hysteresis)
- EXPERT level progression
"""
if not self.config.adaptive_difficulty:
return self.config.initial_difficulty
# Need enough history for reliable assessment
if len(self.reward_history) < 3:
return self.config.initial_difficulty
# Calculate recent performance with exponential weighting
recent_rewards = self.reward_history[-10:] if len(self.reward_history) >= 10 else self.reward_history
avg_recent_reward = sum(recent_rewards) / len(recent_rewards)
# Get current difficulty from example
current_difficulty = self.config.initial_difficulty
if self.current_example:
# Convert string to DifficultyLevel enum if needed
example_diff = self.current_example.difficulty
if isinstance(example_diff, str):
try:
current_difficulty = DifficultyLevel(example_diff.lower())
except ValueError:
current_difficulty = self.config.initial_difficulty
else:
current_difficulty = example_diff
# Stage-specific mastery thresholds
mastery_thresholds = {
DifficultyLevel.BEGINNER: 0.60,
DifficultyLevel.INTERMEDIATE: 0.65,
DifficultyLevel.ADVANCED: 0.75,
DifficultyLevel.EXPERT: 0.85,
}
# Regression thresholds (lower than mastery to avoid oscillation)
regression_thresholds = {
DifficultyLevel.BEGINNER: 0.30,
DifficultyLevel.INTERMEDIATE: 0.40,
DifficultyLevel.ADVANCED: 0.50,
DifficultyLevel.EXPERT: 0.60,
}
# Difficulty progression order
difficulty_order = [
DifficultyLevel.BEGINNER,
DifficultyLevel.INTERMEDIATE,
DifficultyLevel.ADVANCED,
DifficultyLevel.EXPERT,
]
current_idx = difficulty_order.index(current_difficulty) if current_difficulty in difficulty_order else 0
# Check for promotion
if avg_recent_reward > mastery_thresholds.get(current_difficulty, 0.7):
# Promote if not at EXPERT
if current_idx < len(difficulty_order) - 1:
return difficulty_order[current_idx + 1]
# Check for demotion
elif avg_recent_reward < regression_thresholds.get(current_difficulty, 0.4):
# Demote if not at BEGINNER
if current_idx > 0:
return difficulty_order[current_idx - 1]
return current_difficulty
def _update_curriculum(self) -> None:
"""
Update curriculum stage based on episode performance.
Supports:
- Advancement on sustained high performance
- Regression on sustained poor performance
- Stage-specific thresholds
"""
if not self.config.curriculum_enabled:
return
episode_reward = sum(self.reward_history) / max(1, len(self.reward_history))
self.curriculum_performance.append(episode_reward)
# Calculate statistics
avg_reward = sum(self.curriculum_performance) / len(self.curriculum_performance)
recent_rewards = self.curriculum_performance[-10:] if len(self.curriculum_performance) >= 10 else self.curriculum_performance
recent_avg = sum(recent_rewards) / len(recent_rewards)
# Stage-specific thresholds
advancement_threshold = self.config.curriculum_mastery_threshold
regression_threshold = self.config.curriculum_regression_threshold
# Check for curriculum advancement (sustained high performance)
if len(self.curriculum_performance) >= self.config.min_steps_per_curriculum_stage:
if recent_avg > advancement_threshold:
self.curriculum_stage += 1
self.curriculum_performance = [] # Reset for next stage
logger.info(f"Advanced to curriculum stage {self.curriculum_stage} (avg: {recent_avg:.2f})")
# Check for curriculum regression (sustained poor performance)
elif recent_avg < regression_threshold and self.curriculum_stage > 0:
self.curriculum_stage = max(0, self.curriculum_stage - 1)
self.curriculum_performance = []
logger.info(f"Regressed to curriculum stage {self.curriculum_stage} (avg: {recent_avg:.2f})")
def _update_agent_profile(self) -> None:
    """Blend the current episode's results into the long-term agent profile.

    Each profile metric becomes a step-weighted average of its previous
    value and this episode's value, where the episode's weight is its share
    of the total steps seen so far. The difficulty ceiling is then derived
    from the blended overall accuracy. A profile is created if none exists.
    """
    if not self.agent_profile:
        self.agent_profile = AgentSkillProfile()
    profile = self.agent_profile

    episode_steps = self.step_count
    total_steps = profile.total_steps + episode_steps
    weight = episode_steps / max(1, total_steps)
    keep = 1 - weight
    per_step = max(1, episode_steps)

    profile.overall_accuracy = (
        keep * profile.overall_accuracy +
        weight * (self.total_correct / per_step)
    )
    # The grounding skill is tracked through the episode's mean reward.
    profile.grounding_skill = (
        keep * profile.grounding_skill +
        weight * sum(self.reward_history) / max(1, len(self.reward_history))
    )
    profile.hallucination_rate = (
        keep * profile.hallucination_rate +
        weight * (self.total_hallucinations / per_step)
    )
    profile.total_episodes += 1
    profile.total_steps = total_steps

    # Highest level whose accuracy bar is cleared; BEGINNER otherwise.
    ceiling = DifficultyLevel.BEGINNER
    for minimum_accuracy, level in (
        (0.8, DifficultyLevel.EXPERT),
        (0.6, DifficultyLevel.ADVANCED),
        (0.4, DifficultyLevel.INTERMEDIATE),
    ):
        if profile.overall_accuracy > minimum_accuracy:
            ceiling = level
            break
    profile.difficulty_ceiling = ceiling
def _generate_episode_summary(self) -> str:
"""Generate a summary of the completed episode."""
total_reward = sum(self.reward_history) / max(1, len(self.reward_history))
accuracy = self.total_correct / max(1, self.step_count)
summary_parts = [
f"Episode completed!",
f"Total reward: {total_reward:.2f}",
f"Accuracy: {accuracy:.1%}",
f"Hallucinations: {self.total_hallucinations}/{self.step_count}",
f"Best streak: {self.best_streak}",
]
if total_reward > 0.8:
summary_parts.append("Performance: OUTSTANDING!")
elif total_reward > 0.6:
summary_parts.append("Performance: Good")
elif total_reward > 0.4:
summary_parts.append("Performance: Needs improvement")
else:
summary_parts.append("Performance: Poor - review and recalibrate")
return " ".join(summary_parts)
def _extract_reward_breakdown(self, info: Dict[str, Any]) -> Dict[str, Any]:
"""Extract reward breakdown from grader info."""
components = info.get("components", {})
return {
"factual_correctness": info.get("correctness", 0.0),
"source_grounding": info.get("grounding", 0.0),
"citation_accuracy": info.get("citation_analysis", {}).get("best_match_score", 0.0),
"confidence_calibration": info.get("calibration", 0.0),
"semantic_consistency": info.get("semantic_consistency", 0.0),
"hallucination_penalty": info.get("hallucination_penalty", 0.0),
"total": info.get("total_reward", 0.0),
"difficulty_adjustment": info.get("difficulty_multiplier", 1.0),
"consistency_bonus": info.get("consistency_bonus", 0.0),
}
def _split_context_into_fragments(self, context: str, num_fragments: int = 5) -> List[str]:
"""Split context into fragments for retrieval challenges."""
if not context:
return []
sentences = context.split('.')
fragments = []
chunk_size = max(1, len(sentences) // num_fragments)
for i in range(0, len(sentences), chunk_size):
fragment = '.'.join(sentences[i:i + chunk_size]).strip()
if fragment:
fragments.append(fragment + '.')
return fragments or [context]
def _generate_clarification(self, question: str, example: Optional[QAExample]) -> str:
"""Generate a clarification response."""
if not example:
return "No context available for clarification."
# Simple keyword-based clarification
context_lower = example.context.lower()
question_lower = question.lower()
# Extract key terms from question
key_terms = [w for w in question_lower.split() if len(w) > 3 and w not in {'what', 'when', 'where', 'who', 'why', 'how', 'does', 'have', 'has', 'with', 'from'}]
clarifications = []
for term in key_terms[:3]:
if term in context_lower:
clarifications.append(f"Context mentions '{term}'")
return "; ".join(clarifications) if clarifications else "Review the provided context for relevant information."