Spaces:
Sleeping
Sleeping
| """ | |
| Experiment tracking module with extended database schema. | |
| Handles session management, decision logging, and chat interaction tracking. | |
| """ | |
| import sqlite3 | |
| import uuid | |
| import json | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Any | |
| from dataclasses import dataclass, asdict | |
| from contextlib import contextmanager | |
# SQLite database file. The path is relative, so it is resolved against the
# process's current working directory when the connection is opened.
DATABASE_PATH: str = "db/experiment.db"
@dataclass
class SessionData:
    """Represents a participant session.

    Fields mirror the columns of the ``sessions`` table. ``completed`` is a
    bool here but is stored as an INTEGER (0/1) in SQLite.
    """
    participant_id: str
    session_start: str  # ISO-8601 string produced by datetime.isoformat()
    condition_name: str
    initial_portfolio: float
    current_portfolio: float
    # Running counters; all start at zero for a new session.
    scenarios_completed: int = 0
    ai_advice_followed: int = 0
    ai_advice_total: int = 0
    total_chat_queries: int = 0
    proactive_advice_accepted: int = 0
    proactive_advice_dismissed: int = 0
    session_end: Optional[str] = None  # None until the session is completed
    completed: bool = False
@dataclass
class DecisionRecord:
    """Represents a single trading decision (one row of the ``decisions`` table).

    Boolean fields are converted to 0/1 integers when written to SQLite.
    """
    decision_id: str
    participant_id: str
    timestamp: str
    scenario_id: str
    company_symbol: str
    # AI parameters at time of decision
    explanation_depth: int
    communication_style: int
    confidence_framing: int
    risk_bias: int
    # What happened
    ai_recommendation: str
    ai_was_correct: bool
    participant_decision: str
    followed_ai: bool
    # Confidence and timing
    decision_confidence: int
    time_to_decision_ms: int
    time_viewing_ai_advice_ms: int
    # Outcomes
    outcome_percentage: float
    portfolio_before: float
    portfolio_after: float
    trade_amount: float
    # Proactive advice
    proactive_advice_shown: bool
    proactive_advice_engaged: bool
@dataclass
class ChatInteraction:
    """Represents a chat interaction with the AI (one row of ``chat_interactions``).

    Boolean fields are converted to 0/1 integers when written to SQLite.
    """
    interaction_id: str
    participant_id: str
    timestamp: str
    scenario_id: Optional[str]  # None when not tied to a scenario
    # Interaction details
    interaction_type: str  # "proactive", "reactive_query", "follow_up"
    user_query: Optional[str]
    ai_response: str
    # Parameters at time of interaction
    explanation_depth: int
    communication_style: int
    confidence_framing: int
    risk_bias: int
    # Engagement metrics
    response_time_ms: int
    user_engaged: bool  # Did user respond/act on advice
    dismissed: bool  # For proactive advice
@contextmanager
def get_db_connection(db_path: Optional[str] = None):
    """Context manager yielding a SQLite connection to the experiment database.

    The connection uses ``sqlite3.Row`` as its row factory, so rows support
    access by column name and ``dict(row)``. Changes are committed when the
    ``with`` block exits normally and rolled back if it raises; the
    connection is always closed.

    Args:
        db_path: Database file to open. Defaults to ``DATABASE_PATH`` when None.

    Yields:
        An open ``sqlite3.Connection``.
    """
    import os

    path = DATABASE_PATH if db_path is None else db_path
    # sqlite3.connect cannot create missing directories, so ensure the parent
    # folder (e.g. "db/") exists before opening the file.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
        conn.commit()
    except Exception:
        # Discard the partial transaction explicitly before re-raising.
        conn.rollback()
        raise
    finally:
        conn.close()
def init_database():
    """Initialize the database with all required tables and lookup indexes.

    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Sessions table: one row per participant plus running counters.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS sessions (
                participant_id TEXT PRIMARY KEY,
                session_start TEXT NOT NULL,
                session_end TEXT,
                condition_name TEXT NOT NULL,
                initial_portfolio REAL NOT NULL,
                current_portfolio REAL NOT NULL,
                scenarios_completed INTEGER DEFAULT 0,
                ai_advice_followed INTEGER DEFAULT 0,
                ai_advice_total INTEGER DEFAULT 0,
                total_chat_queries INTEGER DEFAULT 0,
                proactive_advice_accepted INTEGER DEFAULT 0,
                proactive_advice_dismissed INTEGER DEFAULT 0,
                completed INTEGER DEFAULT 0
            )
        """)

        # Decisions table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS decisions (
                decision_id TEXT PRIMARY KEY,
                participant_id TEXT NOT NULL,
                timestamp TEXT NOT NULL,
                scenario_id TEXT NOT NULL,
                company_symbol TEXT NOT NULL,
                -- AI parameters
                explanation_depth INTEGER,
                communication_style INTEGER,
                confidence_framing INTEGER,
                risk_bias INTEGER,
                -- Decision details
                ai_recommendation TEXT,
                ai_was_correct INTEGER,
                participant_decision TEXT,
                followed_ai INTEGER,
                -- Confidence and timing
                decision_confidence INTEGER,
                time_to_decision_ms INTEGER,
                time_viewing_ai_advice_ms INTEGER,
                -- Outcomes
                outcome_percentage REAL,
                portfolio_before REAL,
                portfolio_after REAL,
                trade_amount REAL,
                -- Proactive advice
                proactive_advice_shown INTEGER,
                proactive_advice_engaged INTEGER,
                FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
            )
        """)

        # Chat interactions table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS chat_interactions (
                interaction_id TEXT PRIMARY KEY,
                participant_id TEXT NOT NULL,
                timestamp TEXT NOT NULL,
                scenario_id TEXT,
                -- Interaction details
                interaction_type TEXT NOT NULL,
                user_query TEXT,
                ai_response TEXT NOT NULL,
                -- AI parameters
                explanation_depth INTEGER,
                communication_style INTEGER,
                confidence_framing INTEGER,
                risk_bias INTEGER,
                -- Engagement metrics
                response_time_ms INTEGER,
                user_engaged INTEGER,
                dismissed INTEGER,
                FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
            )
        """)

        # Trust metrics table (computed per scenario)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS trust_metrics (
                metric_id TEXT PRIMARY KEY,
                participant_id TEXT NOT NULL,
                scenario_id TEXT NOT NULL,
                timestamp TEXT NOT NULL,
                -- Pre/post confidence
                pre_advice_confidence INTEGER,
                post_advice_confidence INTEGER,
                confidence_change INTEGER,
                -- Behavior indicators
                advice_followed INTEGER,
                time_deliberating_ms INTEGER,
                queries_before_decision INTEGER,
                -- Outcome
                outcome_positive INTEGER,
                FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
            )
        """)

        # Experiment conditions table (for researcher reference)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS experiment_conditions (
                condition_name TEXT PRIMARY KEY,
                accuracy_rate REAL,
                proactivity_level INTEGER,
                confidence_framing INTEGER,
                risk_bias INTEGER,
                description TEXT,
                created_at TEXT
            )
        """)

        # SQLite does not index foreign-key columns automatically. Per-participant
        # lookups filter on participant_id and sort by timestamp, so a composite
        # index serves both without a full table scan and separate sort.
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_decisions_participant
            ON decisions (participant_id, timestamp)
        """)
        cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_interactions_participant
            ON chat_interactions (participant_id, timestamp)
        """)
class ExperimentTracker:
    """Main class for tracking experiment data.

    Persists sessions, decisions, chat interactions and trust metrics. Each
    public method opens its own connection via ``get_db_connection()``;
    changes are committed when that ``with`` block exits without an error.
    """

    # Columns of ``sessions`` that update_session() may set. Column names
    # cannot be bound as SQL parameters, so they are checked against this
    # whitelist before being interpolated into the UPDATE statement.
    # ``participant_id`` is excluded because it identifies the row.
    _SESSION_COLUMNS = frozenset({
        "session_start",
        "session_end",
        "condition_name",
        "initial_portfolio",
        "current_portfolio",
        "scenarios_completed",
        "ai_advice_followed",
        "ai_advice_total",
        "total_chat_queries",
        "proactive_advice_accepted",
        "proactive_advice_dismissed",
        "completed",
    })

    def __init__(self):
        """Ensure the database schema exists."""
        init_database()

    def create_session(
        self,
        condition_name: str,
        initial_portfolio: float
    ) -> str:
        """Create a new participant session and return the participant ID.

        Args:
            condition_name: Experimental condition assigned to the participant.
            initial_portfolio: Starting portfolio value; also stored as the
                initial ``current_portfolio``.

        Returns:
            An 8-character ID taken from the start of a random UUID4.
        """
        participant_id = str(uuid.uuid4())[:8]  # Short ID for display
        session_start = datetime.now().isoformat()
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO sessions (
                    participant_id, session_start, condition_name,
                    initial_portfolio, current_portfolio
                ) VALUES (?, ?, ?, ?, ?)
                """,
                (
                    participant_id, session_start, condition_name,
                    initial_portfolio, initial_portfolio,
                ),
            )
        return participant_id

    def get_session(self, participant_id: str) -> Optional[Dict]:
        """Return the session row for a participant as a dict, or None if absent."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM sessions WHERE participant_id = ?",
                (participant_id,),
            )
            row = cursor.fetchone()
        return dict(row) if row else None

    def update_session(self, participant_id: str, **kwargs):
        """Update the given columns of a participant's session row.

        Args:
            participant_id: Session to update.
            **kwargs: Column name -> new value. Each name must be a column of
                the ``sessions`` table other than ``participant_id``.

        Raises:
            ValueError: If any keyword is not an updatable session column.
        """
        if not kwargs:
            return
        unknown = set(kwargs) - self._SESSION_COLUMNS
        if unknown:
            raise ValueError(
                f"Unknown session column(s): {', '.join(sorted(unknown))}"
            )
        set_clause = ", ".join(f"{k} = ?" for k in kwargs)
        values = [*kwargs.values(), participant_id]
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                f"UPDATE sessions SET {set_clause} WHERE participant_id = ?",
                values,
            )

    def complete_session(self, participant_id: str, final_portfolio: float):
        """Mark a session as completed and record its final portfolio value."""
        self.update_session(
            participant_id,
            session_end=datetime.now().isoformat(),
            current_portfolio=final_portfolio,
            completed=1,
        )

    def record_decision(self, record: "DecisionRecord"):
        """Insert a trading decision and update the owning session's counters.

        The insert and the counter update run in one transaction, and the
        counters are incremented in SQL. Concurrent calls therefore cannot
        lose increments, and a failure leaves neither change behind. If no
        session row exists for the participant, the counter update is a no-op.
        """
        followed = int(bool(record.followed_ai))
        shown = bool(record.proactive_advice_shown)
        engaged = bool(record.proactive_advice_engaged)
        # Proactive advice that was shown counts as accepted if engaged,
        # otherwise as dismissed; unshown advice counts toward neither.
        accepted = int(shown and engaged)
        dismissed = int(shown and not engaged)

        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO decisions (
                    decision_id, participant_id, timestamp, scenario_id, company_symbol,
                    explanation_depth, communication_style, confidence_framing, risk_bias,
                    ai_recommendation, ai_was_correct, participant_decision, followed_ai,
                    decision_confidence, time_to_decision_ms, time_viewing_ai_advice_ms,
                    outcome_percentage, portfolio_before, portfolio_after, trade_amount,
                    proactive_advice_shown, proactive_advice_engaged
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    record.decision_id, record.participant_id, record.timestamp,
                    record.scenario_id, record.company_symbol,
                    record.explanation_depth, record.communication_style,
                    record.confidence_framing, record.risk_bias,
                    record.ai_recommendation, int(record.ai_was_correct),
                    record.participant_decision, int(record.followed_ai),
                    record.decision_confidence, record.time_to_decision_ms,
                    record.time_viewing_ai_advice_ms,
                    record.outcome_percentage, record.portfolio_before,
                    record.portfolio_after, record.trade_amount,
                    int(record.proactive_advice_shown),
                    int(record.proactive_advice_engaged),
                ),
            )
            cursor.execute(
                """
                UPDATE sessions SET
                    scenarios_completed = scenarios_completed + 1,
                    ai_advice_total = ai_advice_total + 1,
                    ai_advice_followed = ai_advice_followed + ?,
                    proactive_advice_accepted = proactive_advice_accepted + ?,
                    proactive_advice_dismissed = proactive_advice_dismissed + ?,
                    current_portfolio = ?
                WHERE participant_id = ?
                """,
                (
                    followed, accepted, dismissed,
                    record.portfolio_after, record.participant_id,
                ),
            )

    def record_chat_interaction(self, interaction: "ChatInteraction"):
        """Insert a chat interaction; reactive queries also bump the session's query count.

        Both writes share one transaction, and the count is incremented in SQL
        so concurrent calls cannot lose updates.
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO chat_interactions (
                    interaction_id, participant_id, timestamp, scenario_id,
                    interaction_type, user_query, ai_response,
                    explanation_depth, communication_style, confidence_framing, risk_bias,
                    response_time_ms, user_engaged, dismissed
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    interaction.interaction_id, interaction.participant_id,
                    interaction.timestamp, interaction.scenario_id,
                    interaction.interaction_type, interaction.user_query,
                    interaction.ai_response,
                    interaction.explanation_depth, interaction.communication_style,
                    interaction.confidence_framing, interaction.risk_bias,
                    interaction.response_time_ms, int(interaction.user_engaged),
                    int(interaction.dismissed),
                ),
            )
            # Only user-initiated queries count toward total_chat_queries.
            if interaction.interaction_type == "reactive_query":
                cursor.execute(
                    "UPDATE sessions SET total_chat_queries = total_chat_queries + 1 "
                    "WHERE participant_id = ?",
                    (interaction.participant_id,),
                )

    def record_trust_metric(
        self,
        participant_id: str,
        scenario_id: str,
        pre_confidence: int,
        post_confidence: int,
        advice_followed: bool,
        time_deliberating_ms: int,
        queries_before_decision: int,
        outcome_positive: bool
    ):
        """Record trust-related metrics for one scenario.

        ``confidence_change`` is stored as ``post_confidence - pre_confidence``;
        booleans are stored as 0/1.
        """
        metric_id = str(uuid.uuid4())[:12]
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT INTO trust_metrics (
                    metric_id, participant_id, scenario_id, timestamp,
                    pre_advice_confidence, post_advice_confidence, confidence_change,
                    advice_followed, time_deliberating_ms, queries_before_decision,
                    outcome_positive
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    metric_id, participant_id, scenario_id,
                    datetime.now().isoformat(),
                    pre_confidence, post_confidence, post_confidence - pre_confidence,
                    int(advice_followed), time_deliberating_ms, queries_before_decision,
                    int(outcome_positive),
                ),
            )

    def get_participant_decisions(self, participant_id: str) -> List[Dict]:
        """Return all decisions for a participant, oldest first."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM decisions WHERE participant_id = ? ORDER BY timestamp",
                (participant_id,),
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_participant_interactions(self, participant_id: str) -> List[Dict]:
        """Return all chat interactions for a participant, oldest first."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM chat_interactions WHERE participant_id = ? ORDER BY timestamp",
                (participant_id,),
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_session_summary(self, participant_id: str) -> Dict[str, Any]:
        """Return summary metrics for a participant's session, or {} if it does not exist."""
        session = self.get_session(participant_id)
        if not session:
            return {}
        decisions = self.get_participant_decisions(participant_id)
        interactions = self.get_participant_interactions(participant_id)
        return self._build_summary(participant_id, session, decisions, interactions)

    @staticmethod
    def _build_summary(
        participant_id: str,
        session: Dict[str, Any],
        decisions: List[Dict],
        interactions: List[Dict],
    ) -> Dict[str, Any]:
        """Compute the summary dict from already-fetched rows (no database access).

        Rates are 0 when their denominator is 0, and the portfolio return is
        0.0 when the initial portfolio is 0. Decisions with a NULL
        ``time_to_decision_ms`` are excluded from the average decision time.
        """
        advice_total = session["ai_advice_total"]
        ai_follow_rate = (
            session["ai_advice_followed"] / advice_total if advice_total > 0 else 0
        )

        accepted = session["proactive_advice_accepted"]
        proactive_seen = accepted + session["proactive_advice_dismissed"]
        proactive_engage_rate = accepted / proactive_seen if proactive_seen > 0 else 0

        initial = session["initial_portfolio"]
        final = session["current_portfolio"]
        # Guard a zero starting portfolio, which would otherwise divide by zero.
        portfolio_return = (final - initial) / initial if initial else 0.0

        # time_to_decision_ms is a nullable column; average only recorded values.
        times = [
            d["time_to_decision_ms"]
            for d in decisions
            if d["time_to_decision_ms"] is not None
        ]
        avg_decision_time = sum(times) / len(times) if times else 0

        return {
            "participant_id": participant_id,
            "condition": session["condition_name"],
            "completed": bool(session["completed"]),
            "scenarios_completed": session["scenarios_completed"],
            "initial_portfolio": initial,
            "final_portfolio": final,
            "portfolio_return": portfolio_return,
            "portfolio_return_pct": f"{portfolio_return * 100:.1f}%",
            "ai_follow_rate": ai_follow_rate,
            "ai_follow_rate_pct": f"{ai_follow_rate * 100:.1f}%",
            "proactive_engage_rate": proactive_engage_rate,
            "total_chat_queries": session["total_chat_queries"],
            "avg_decision_time_ms": avg_decision_time,
            "total_decisions": len(decisions),
            "total_interactions": len(interactions),
        }

    def get_all_sessions(self) -> List[Dict]:
        """Get all sessions for export/analysis, ordered by start time."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM sessions ORDER BY session_start")
            return [dict(row) for row in cursor.fetchall()]

    def get_all_decisions(self) -> List[Dict]:
        """Get all decisions for export/analysis, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM decisions ORDER BY timestamp")
            return [dict(row) for row in cursor.fetchall()]

    def get_all_interactions(self) -> List[Dict]:
        """Get all chat interactions for export/analysis, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM chat_interactions ORDER BY timestamp")
            return [dict(row) for row in cursor.fetchall()]
# Module-wide tracker shared by every importer. Constructing it calls
# init_database(), so importing this module creates any missing tables.
tracker: ExperimentTracker = ExperimentTracker()