# JoshTest / tracking.py
# Uploaded by: LittleMonkeyLab
# Commit: cbd168e ("Upload 12 files", verified)
"""
Experiment tracking module with extended database schema.
Handles session management, decision logging, and chat interaction tracking.
"""
import json
import os
import sqlite3
import uuid
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
# SQLite database file opened by get_db_connection(). The path is relative,
# so it is resolved against the current working directory of the process.
DATABASE_PATH = "db/experiment.db"
@dataclass
class SessionData:
    """Represents a participant session.

    Field names match the columns of the ``sessions`` table created by
    ``init_database``.
    """
    participant_id: str
    # NOTE(review): presumably an ISO-8601 string, as create_session stores
    # datetime.now().isoformat() in sessions.session_start — confirm.
    session_start: str
    condition_name: str
    initial_portfolio: float
    current_portfolio: float
    # Running counters. In the ``sessions`` table the columns of the same
    # names are incremented by ExperimentTracker.record_decision and
    # ExperimentTracker.record_chat_interaction.
    scenarios_completed: int = 0
    ai_advice_followed: int = 0
    ai_advice_total: int = 0
    total_chat_queries: int = 0
    proactive_advice_accepted: int = 0
    proactive_advice_dismissed: int = 0
    # The sessions column of the same name is written by complete_session.
    session_end: Optional[str] = None
    # The ``sessions`` table stores this flag as INTEGER 0/1.
    completed: bool = False
@dataclass
class DecisionRecord:
    """Represents a single trading decision.

    Persisted as one row of the ``decisions`` table by
    ``ExperimentTracker.record_decision``; boolean fields are written as
    0/1 integers.
    """
    # Primary key of the ``decisions`` row; supplied by the caller.
    decision_id: str
    participant_id: str
    timestamp: str
    scenario_id: str
    company_symbol: str
    # AI parameters at time of decision
    explanation_depth: int
    communication_style: int
    confidence_framing: int
    risk_bias: int
    # What happened
    ai_recommendation: str
    ai_was_correct: bool
    participant_decision: str
    # When true, record_decision also increments sessions.ai_advice_followed.
    followed_ai: bool
    # Confidence and timing
    decision_confidence: int
    time_to_decision_ms: int
    time_viewing_ai_advice_ms: int
    # Outcomes
    outcome_percentage: float
    portfolio_before: float
    # record_decision also copies this into sessions.current_portfolio.
    portfolio_after: float
    trade_amount: float
    # Proactive advice. When shown, record_decision increments
    # sessions.proactive_advice_accepted if engaged, otherwise
    # sessions.proactive_advice_dismissed.
    proactive_advice_shown: bool
    proactive_advice_engaged: bool
@dataclass
class ChatInteraction:
    """Represents a chat interaction with the AI.

    Persisted as one row of the ``chat_interactions`` table by
    ``ExperimentTracker.record_chat_interaction``; boolean fields are
    written as 0/1 integers.
    """
    # Primary key of the ``chat_interactions`` row; supplied by the caller.
    interaction_id: str
    participant_id: str
    timestamp: str
    # Nullable column in ``chat_interactions``.
    scenario_id: Optional[str]
    # Interaction details
    # Only "reactive_query" increments sessions.total_chat_queries.
    interaction_type: str # "proactive", "reactive_query", "follow_up"
    # NOTE(review): presumably None for proactive messages — confirm.
    user_query: Optional[str]
    ai_response: str
    # Parameters at time of interaction
    explanation_depth: int
    communication_style: int
    confidence_framing: int
    risk_bias: int
    # Engagement metrics
    response_time_ms: int
    user_engaged: bool # Did user respond/act on advice
    dismissed: bool # For proactive advice
@contextmanager
def get_db_connection(db_path: Optional[str] = None):
    """Context manager for database connections.

    Opens a SQLite connection whose rows are ``sqlite3.Row`` objects, so
    callers can turn them into dicts with ``dict(row)``. The transaction is
    committed when the ``with`` block exits normally and rolled back if it
    raises; the connection is always closed.

    Args:
        db_path: Database file to open. Defaults to ``DATABASE_PATH``.

    Yields:
        An open ``sqlite3.Connection``.
    """
    path = DATABASE_PATH if db_path is None else db_path
    # sqlite3.connect() does not create missing directories and fails with
    # "unable to open database file", so make sure e.g. ``db/`` exists.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
        conn.commit()
    except BaseException:
        # Undo partial writes explicitly before propagating the error.
        conn.rollback()
        raise
    finally:
        conn.close()
def init_database():
    """Create every experiment table that does not exist yet.

    Each statement is ``CREATE TABLE IF NOT EXISTS``, so calling this on a
    database that already holds the tables leaves them as they are. The
    statements run in the order listed, all through one cursor on a single
    connection obtained from ``get_db_connection``.
    """
    ddl_statements = (
        # Sessions table
        """
CREATE TABLE IF NOT EXISTS sessions (
participant_id TEXT PRIMARY KEY,
session_start TEXT NOT NULL,
session_end TEXT,
condition_name TEXT NOT NULL,
initial_portfolio REAL NOT NULL,
current_portfolio REAL NOT NULL,
scenarios_completed INTEGER DEFAULT 0,
ai_advice_followed INTEGER DEFAULT 0,
ai_advice_total INTEGER DEFAULT 0,
total_chat_queries INTEGER DEFAULT 0,
proactive_advice_accepted INTEGER DEFAULT 0,
proactive_advice_dismissed INTEGER DEFAULT 0,
completed INTEGER DEFAULT 0
)
""",
        # Decisions table
        """
CREATE TABLE IF NOT EXISTS decisions (
decision_id TEXT PRIMARY KEY,
participant_id TEXT NOT NULL,
timestamp TEXT NOT NULL,
scenario_id TEXT NOT NULL,
company_symbol TEXT NOT NULL,
-- AI parameters
explanation_depth INTEGER,
communication_style INTEGER,
confidence_framing INTEGER,
risk_bias INTEGER,
-- Decision details
ai_recommendation TEXT,
ai_was_correct INTEGER,
participant_decision TEXT,
followed_ai INTEGER,
-- Confidence and timing
decision_confidence INTEGER,
time_to_decision_ms INTEGER,
time_viewing_ai_advice_ms INTEGER,
-- Outcomes
outcome_percentage REAL,
portfolio_before REAL,
portfolio_after REAL,
trade_amount REAL,
-- Proactive advice
proactive_advice_shown INTEGER,
proactive_advice_engaged INTEGER,
FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
)
""",
        # Chat interactions table
        """
CREATE TABLE IF NOT EXISTS chat_interactions (
interaction_id TEXT PRIMARY KEY,
participant_id TEXT NOT NULL,
timestamp TEXT NOT NULL,
scenario_id TEXT,
-- Interaction details
interaction_type TEXT NOT NULL,
user_query TEXT,
ai_response TEXT NOT NULL,
-- AI parameters
explanation_depth INTEGER,
communication_style INTEGER,
confidence_framing INTEGER,
risk_bias INTEGER,
-- Engagement metrics
response_time_ms INTEGER,
user_engaged INTEGER,
dismissed INTEGER,
FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
)
""",
        # Trust metrics table (computed per scenario)
        """
CREATE TABLE IF NOT EXISTS trust_metrics (
metric_id TEXT PRIMARY KEY,
participant_id TEXT NOT NULL,
scenario_id TEXT NOT NULL,
timestamp TEXT NOT NULL,
-- Pre/post confidence
pre_advice_confidence INTEGER,
post_advice_confidence INTEGER,
confidence_change INTEGER,
-- Behavior indicators
advice_followed INTEGER,
time_deliberating_ms INTEGER,
queries_before_decision INTEGER,
-- Outcome
outcome_positive INTEGER,
FOREIGN KEY (participant_id) REFERENCES sessions(participant_id)
)
""",
        # Experiment conditions table (for researcher reference)
        """
CREATE TABLE IF NOT EXISTS experiment_conditions (
condition_name TEXT PRIMARY KEY,
accuracy_rate REAL,
proactivity_level INTEGER,
confidence_framing INTEGER,
risk_bias INTEGER,
description TEXT,
created_at TEXT
)
""",
    )
    with get_db_connection() as db:
        cur = db.cursor()
        for ddl in ddl_statements:
            cur.execute(ddl)
class ExperimentTracker:
    """Main class for tracking experiment data.

    Every method opens its own short-lived connection through
    ``get_db_connection``; writes are committed when that ``with`` block
    exits without an exception. Constructing an instance calls
    ``init_database`` so the tables exist before they are read or written.
    """

    # Number of short participant IDs create_session tries before giving up
    # when a generated ID collides with an existing ``sessions`` row.
    _ID_ATTEMPTS = 5

    def __init__(self):
        init_database()

    def create_session(
        self,
        condition_name: str,
        initial_portfolio: float
    ) -> str:
        """Create a new participant session and return the participant ID.

        The ID is the first 8 characters of a random UUID4 so it is short
        enough to display. Because it is also the ``sessions`` primary key,
        a collision with an existing row raises ``sqlite3.IntegrityError``;
        in that case a fresh ID is drawn, up to ``_ID_ATTEMPTS`` times.

        Args:
            condition_name: Experiment condition assigned to the participant.
            initial_portfolio: Starting portfolio value, stored as both the
                initial and the current portfolio.

        Returns:
            The new participant ID.

        Raises:
            sqlite3.IntegrityError: If the insert still fails on the last
                attempt.
        """
        session_start = datetime.now().isoformat()
        for attempt in range(1, self._ID_ATTEMPTS + 1):
            participant_id = str(uuid.uuid4())[:8]  # Short ID for display
            try:
                with get_db_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute("""
INSERT INTO sessions (
participant_id, session_start, condition_name,
initial_portfolio, current_portfolio
) VALUES (?, ?, ?, ?, ?)
""", (
                        participant_id, session_start, condition_name,
                        initial_portfolio, initial_portfolio
                    ))
            except sqlite3.IntegrityError:
                # 8 hex characters carry only 32 random bits, so a clash
                # with an existing primary key is possible; retry with a
                # new ID unless this was the last attempt.
                if attempt == self._ID_ATTEMPTS:
                    raise
            else:
                return participant_id

    def get_session(self, participant_id: str) -> Optional[Dict]:
        """Return the ``sessions`` row for *participant_id* as a dict, or None."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM sessions WHERE participant_id = ?",
                (participant_id,)
            )
            row = cursor.fetchone()
            if row:
                return dict(row)
            return None

    def update_session(self, participant_id: str, **kwargs):
        """Set the given ``sessions`` columns for *participant_id*.

        Keyword names are interpolated directly into the SQL as column names
        (only the values are bound parameters), so they must come from
        trusted code, never from user input. Does nothing when no keywords
        are given.
        """
        if not kwargs:
            return
        set_clause = ", ".join(f"{k} = ?" for k in kwargs.keys())
        values = list(kwargs.values()) + [participant_id]
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                f"UPDATE sessions SET {set_clause} WHERE participant_id = ?",
                values
            )

    def complete_session(self, participant_id: str, final_portfolio: float):
        """Mark a session as completed, stamping its end time and final portfolio."""
        self.update_session(
            participant_id,
            session_end=datetime.now().isoformat(),
            current_portfolio=final_portfolio,
            completed=1
        )

    def record_decision(self, record: "DecisionRecord"):
        """Record a trading decision and update the session's counters.

        The ``decisions`` insert and the ``sessions`` update run on the same
        connection and are committed together. Counters are incremented by
        the UPDATE statement itself, instead of being read, modified in
        Python and written back, so concurrent writes for the same
        participant cannot overwrite each other's increments.

        Session changes (none if the participant has no session row):
            * ``scenarios_completed`` and ``ai_advice_total`` += 1
            * ``current_portfolio`` = ``record.portfolio_after``
            * ``ai_advice_followed`` += 1 if ``record.followed_ai``
            * if ``record.proactive_advice_shown``:
              ``proactive_advice_accepted`` += 1 when engaged, otherwise
              ``proactive_advice_dismissed`` += 1

        Args:
            record: The decision to store; boolean fields are written as 0/1.
        """
        shown = bool(record.proactive_advice_shown)
        engaged = bool(record.proactive_advice_engaged)
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
INSERT INTO decisions (
decision_id, participant_id, timestamp, scenario_id, company_symbol,
explanation_depth, communication_style, confidence_framing, risk_bias,
ai_recommendation, ai_was_correct, participant_decision, followed_ai,
decision_confidence, time_to_decision_ms, time_viewing_ai_advice_ms,
outcome_percentage, portfolio_before, portfolio_after, trade_amount,
proactive_advice_shown, proactive_advice_engaged
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
                record.decision_id, record.participant_id, record.timestamp,
                record.scenario_id, record.company_symbol,
                record.explanation_depth, record.communication_style,
                record.confidence_framing, record.risk_bias,
                record.ai_recommendation, int(record.ai_was_correct),
                record.participant_decision, int(record.followed_ai),
                record.decision_confidence, record.time_to_decision_ms,
                record.time_viewing_ai_advice_ms,
                record.outcome_percentage, record.portfolio_before,
                record.portfolio_after, record.trade_amount,
                int(record.proactive_advice_shown), int(record.proactive_advice_engaged)
            ))
            # Update session counters atomically, in the same transaction.
            cursor.execute(
                "UPDATE sessions SET "
                "scenarios_completed = scenarios_completed + 1, "
                "ai_advice_total = ai_advice_total + 1, "
                "current_portfolio = ?, "
                "ai_advice_followed = ai_advice_followed + ?, "
                "proactive_advice_accepted = proactive_advice_accepted + ?, "
                "proactive_advice_dismissed = proactive_advice_dismissed + ? "
                "WHERE participant_id = ?",
                (
                    record.portfolio_after,
                    int(bool(record.followed_ai)),
                    int(shown and engaged),
                    int(shown and not engaged),
                    record.participant_id,
                ),
            )

    def record_chat_interaction(self, interaction: "ChatInteraction"):
        """Record a chat interaction.

        When ``interaction.interaction_type == "reactive_query"``, the
        session's ``total_chat_queries`` is incremented by the UPDATE
        statement itself, in the same transaction as the insert (no change
        if the session row is missing).
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
INSERT INTO chat_interactions (
interaction_id, participant_id, timestamp, scenario_id,
interaction_type, user_query, ai_response,
explanation_depth, communication_style, confidence_framing, risk_bias,
response_time_ms, user_engaged, dismissed
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
                interaction.interaction_id, interaction.participant_id,
                interaction.timestamp, interaction.scenario_id,
                interaction.interaction_type, interaction.user_query,
                interaction.ai_response,
                interaction.explanation_depth, interaction.communication_style,
                interaction.confidence_framing, interaction.risk_bias,
                interaction.response_time_ms, int(interaction.user_engaged),
                int(interaction.dismissed)
            ))
            # Update query count for reactive queries
            if interaction.interaction_type == "reactive_query":
                cursor.execute(
                    "UPDATE sessions SET total_chat_queries = total_chat_queries + 1 "
                    "WHERE participant_id = ?",
                    (interaction.participant_id,)
                )

    def record_trust_metric(
        self,
        participant_id: str,
        scenario_id: str,
        pre_confidence: int,
        post_confidence: int,
        advice_followed: bool,
        time_deliberating_ms: int,
        queries_before_decision: int,
        outcome_positive: bool
    ):
        """Record trust-related metrics for a scenario.

        ``confidence_change`` is stored as ``post_confidence - pre_confidence``;
        booleans are stored as 0/1. The row ID is the first 12 characters of a
        random UUID4 string.
        """
        metric_id = str(uuid.uuid4())[:12]
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
INSERT INTO trust_metrics (
metric_id, participant_id, scenario_id, timestamp,
pre_advice_confidence, post_advice_confidence, confidence_change,
advice_followed, time_deliberating_ms, queries_before_decision,
outcome_positive
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
                metric_id, participant_id, scenario_id,
                datetime.now().isoformat(),
                pre_confidence, post_confidence, post_confidence - pre_confidence,
                int(advice_followed), time_deliberating_ms, queries_before_decision,
                int(outcome_positive)
            ))

    def get_participant_decisions(self, participant_id: str) -> List[Dict]:
        """Get all decisions for a participant, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM decisions WHERE participant_id = ? ORDER BY timestamp",
                (participant_id,)
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_participant_interactions(self, participant_id: str) -> List[Dict]:
        """Get all chat interactions for a participant, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM chat_interactions WHERE participant_id = ? ORDER BY timestamp",
                (participant_id,)
            )
            return [dict(row) for row in cursor.fetchall()]

    def get_session_summary(self, participant_id: str) -> Dict[str, Any]:
        """Get a summary of a participant's session.

        Ratios whose denominator is zero are reported as 0 instead of
        raising, including ``portfolio_return`` when the initial portfolio
        is 0.

        Returns:
            A dict of summary metrics, or ``{}`` if the session does not exist.
        """
        session = self.get_session(participant_id)
        if not session:
            return {}
        decisions = self.get_participant_decisions(participant_id)
        interactions = self.get_participant_interactions(participant_id)
        # Calculate metrics
        ai_follow_rate = (
            session["ai_advice_followed"] / session["ai_advice_total"]
            if session["ai_advice_total"] > 0 else 0
        )
        proactive_seen = (
            session["proactive_advice_accepted"] + session["proactive_advice_dismissed"]
        )
        proactive_engage_rate = (
            session["proactive_advice_accepted"] / proactive_seen
            if proactive_seen > 0 else 0
        )
        initial = session["initial_portfolio"]
        # Guard against a zero starting portfolio (previously ZeroDivisionError).
        portfolio_return = (
            (session["current_portfolio"] - initial) / initial
            if initial else 0.0
        )
        # Calculate average decision time
        avg_decision_time = (
            sum(d["time_to_decision_ms"] for d in decisions) / len(decisions)
            if decisions else 0
        )
        return {
            "participant_id": participant_id,
            "condition": session["condition_name"],
            "completed": bool(session["completed"]),
            "scenarios_completed": session["scenarios_completed"],
            "initial_portfolio": session["initial_portfolio"],
            "final_portfolio": session["current_portfolio"],
            "portfolio_return": portfolio_return,
            "portfolio_return_pct": f"{portfolio_return * 100:.1f}%",
            "ai_follow_rate": ai_follow_rate,
            "ai_follow_rate_pct": f"{ai_follow_rate * 100:.1f}%",
            "proactive_engage_rate": proactive_engage_rate,
            "total_chat_queries": session["total_chat_queries"],
            "avg_decision_time_ms": avg_decision_time,
            "total_decisions": len(decisions),
            "total_interactions": len(interactions)
        }

    def get_all_sessions(self) -> List[Dict]:
        """Get all sessions for export/analysis, ordered by start time."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM sessions ORDER BY session_start")
            return [dict(row) for row in cursor.fetchall()]

    def get_all_decisions(self) -> List[Dict]:
        """Get all decisions for export/analysis, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM decisions ORDER BY timestamp")
            return [dict(row) for row in cursor.fetchall()]

    def get_all_interactions(self) -> List[Dict]:
        """Get all chat interactions for export/analysis, ordered by timestamp."""
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM chat_interactions ORDER BY timestamp")
            return [dict(row) for row in cursor.fetchall()]
# Singleton tracker instance. Creating it calls init_database(), so importing
# this module opens DATABASE_PATH and creates any missing tables.
tracker = ExperimentTracker()