# AI_Personas/src/influence/dynamics.py
# Author: Claude
# Commit bde936d (unverified): Add peft package for Be.FM PEFT adapter loading
"""Opinion dynamics engine for multi-persona interactions"""
from typing import List, Dict, Optional, Callable
import statistics
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from ..personas.models import Persona
from ..personas.database import PersonaDatabase
from ..llm.anthropic_client import AnthropicClient
from ..llm.prompt_builder import PromptBuilder
from .models import PersonaOpinion, RoundResult, OpinionPosition
from .network import InfluenceNetwork
class OpinionDynamicsEngine:
    """
    Runs multi-round opinion dynamics where personas influence each other.

    Process:
    1. Round 1: Each persona responds independently
    2. Round 2+: Personas see others' responses and can update their position
    3. Continue until equilibrium or max rounds reached

    LLM replies are expected to carry ``POSITION:``, ``CONFIDENCE:``,
    ``KEY_ARGUMENTS:`` and (from round 2 on) ``INFLUENCED_BY:`` fields, one
    per line. Parsing is tolerant of leading whitespace and simple markdown
    decoration around the field markers.
    """

    # Decoration an LLM may put in front of a field marker, e.g.
    # "**POSITION:** support" or "> POSITION: support".
    _LINE_DECORATION = "*#>` \t"
    # Characters trimmed from both ends of an extracted field value.
    _VALUE_TRIM = " \t\r\n[]*`"
    # Confidence reported when no CONFIDENCE field can be parsed.
    _DEFAULT_CONFIDENCE = 0.5
    # How many characters of an influencer's response are quoted in a prompt.
    _EXCERPT_LENGTH = 300
    # How many entries of the influencer list are considered for a prompt.
    _MAX_INFLUENCERS = 5

    def __init__(
        self,
        persona_db: "PersonaDatabase",
        llm_client: "AnthropicClient",
        max_workers: int = 3,
        requests_per_minute: int = 50,
    ):
        """
        Create an engine.

        Args:
            persona_db: Source of personas when only IDs are supplied
            llm_client: Client used to generate persona responses
            max_workers: Size of the thread pool used to query personas
            requests_per_minute: Request budget; determines the pause
                inserted between submissions

        Raises:
            ValueError: If requests_per_minute is not positive
        """
        # Validate first: a zero budget would otherwise surface as a bare
        # ZeroDivisionError, and a negative one as a failure inside sleep().
        if requests_per_minute <= 0:
            raise ValueError("requests_per_minute must be a positive number")
        self.persona_db = persona_db
        self.llm_client = llm_client
        self.prompt_builder = PromptBuilder()
        self.max_workers = max_workers
        self.min_delay = 60.0 / requests_per_minute

    def run_dynamics(
        self,
        question: str,
        max_rounds: int = 5,
        convergence_threshold: float = 0.1,
        network_type: str = "scale_free",
        persona_ids: Optional[List[str]] = None,
        personas: Optional[List["Persona"]] = None,
        progress_callback: Optional[Callable[[str], None]] = None,
    ) -> List["RoundResult"]:
        """
        Run opinion dynamics simulation.

        Args:
            question: The proposal/question to discuss
            max_rounds: Maximum number of rounds
            convergence_threshold: Stop if total change < threshold
            network_type: Network topology ("scale_free", "small_world", "fully_connected")
            persona_ids: List of persona IDs to include (for small group mode)
            personas: Pre-loaded personas list (for population mode)
            progress_callback: Function to call with progress updates

        Returns:
            List of RoundResult for each round

        Raises:
            ValueError: If neither persona_ids nor personas is given, or if
                a round produces no opinions at all
        """
        # Load personas from IDs if not provided directly
        if personas is None:
            if persona_ids is None:
                raise ValueError("Either persona_ids or personas must be provided")
            personas = [self.persona_db.get_persona(pid) for pid in persona_ids]

        # Build influence network
        influence_network = InfluenceNetwork(personas, network_type=network_type)

        results: List["RoundResult"] = []
        previous_opinions: Dict[str, "PersonaOpinion"] = {}

        for round_num in range(1, max_rounds + 1):
            if progress_callback:
                progress_callback(
                    f"Round {round_num}/{max_rounds}: Gathering opinions..."
                )

            # Query all personas for this round
            round_opinions = self._query_round(
                personas=personas,
                question=question,
                round_number=round_num,
                previous_opinions=previous_opinions,
                influence_network=influence_network,
                progress_callback=progress_callback,
            )

            # Calculate round metrics
            round_result = self._analyze_round(
                round_num, round_opinions, previous_opinions
            )
            results.append(round_result)

            if progress_callback:
                progress_callback(
                    f"Round {round_num} complete: "
                    f"Avg position: {round_result.average_position:.2f}, "
                    f"Total change: {round_result.total_change:.2f}"
                )

            # Check for convergence (round 1 has nothing to compare against)
            if round_num > 1 and round_result.total_change < convergence_threshold:
                if progress_callback:
                    progress_callback(
                        f"✓ Equilibrium reached at round {round_num}!"
                    )
                break

            # Update for next round. Opinions are carried forward so that a
            # persona whose query failed this round keeps its last known
            # opinion: its next change is still measured and its peers can
            # still see what it last said.
            previous_opinions = {
                **previous_opinions,
                **{op.persona_id: op for op in round_opinions},
            }

        return results

    def _query_round(
        self,
        personas: List["Persona"],
        question: str,
        round_number: int,
        previous_opinions: Dict[str, "PersonaOpinion"],
        influence_network: "InfluenceNetwork",
        progress_callback: Optional[Callable[[str], None]] = None,
    ) -> List["PersonaOpinion"]:
        """
        Query all personas for one round.

        Queries run on a thread pool; submissions are spaced by
        ``self.min_delay`` seconds. A persona whose query raises is reported
        (stdout and progress_callback) and left out of the result.

        Returns:
            The successfully obtained opinions, in the order of ``personas``
        """
        import traceback

        collected = []  # (index in personas, opinion)

        # Parallel querying with rate limiting
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_slot = {}
            for index, persona in enumerate(personas):
                future = executor.submit(
                    self._query_persona,
                    persona=persona,
                    question=question,
                    round_number=round_number,
                    previous_opinions=previous_opinions,
                    influence_network=influence_network,
                )
                future_to_slot[future] = (index, persona)
                time.sleep(self.min_delay)  # Rate limiting

            for future in as_completed(future_to_slot):
                index, persona = future_to_slot[future]
                try:
                    opinion = future.result()
                    collected.append((index, opinion))
                    if progress_callback:
                        progress_callback(
                            f" ✓ {persona.name}: {opinion.position.value}"
                        )
                except Exception as e:
                    error_msg = f"Error querying {persona.name}: {e}"
                    full_traceback = traceback.format_exc()
                    print(f"\n{'='*60}")
                    print(f"ERROR querying {persona.name}:")
                    print(full_traceback)
                    print('='*60)
                    if progress_callback:
                        progress_callback(f" ✗ {error_msg}")

        # as_completed yields in completion order, which varies from run to
        # run; restore the caller's persona order so results are reproducible.
        collected.sort(key=lambda slot: slot[0])
        return [opinion for _, opinion in collected]

    def _query_persona(
        self,
        persona: "Persona",
        question: str,
        round_number: int,
        previous_opinions: Dict[str, "PersonaOpinion"],
        influence_network: "InfluenceNetwork",
    ) -> "PersonaOpinion":
        """
        Query one persona for their opinion.

        Round 1 asks for an independent position; later rounds use a system
        prompt that quotes influencers' previous opinions and additionally
        ask for an INFLUENCED_BY field.
        """
        # Build prompt based on round
        if round_number == 1:
            # Round 1: Independent response
            system_prompt = self._build_initial_prompt(persona)
            user_message = f"""Please share your position on the following proposal:
{question}
Provide your response in this format:
POSITION: [strongly_support/support/lean_support/neutral/lean_oppose/oppose/strongly_oppose]
CONFIDENCE: [0.0-1.0]
RESPONSE: [Your detailed response explaining your position]
KEY_ARGUMENTS: [Bullet points of your main arguments]
"""
        else:
            # Round 2+: Consider others' opinions
            system_prompt = self._build_influenced_prompt(
                persona, previous_opinions, influence_network
            )
            user_message = f"""After hearing other stakeholders' perspectives, please share your updated position on:
{question}
Consider the arguments you've heard, but stay true to your values and expertise.
Provide your response in this format:
POSITION: [strongly_support/support/lean_support/neutral/lean_oppose/oppose/strongly_oppose]
CONFIDENCE: [0.0-1.0]
RESPONSE: [Your detailed response]
KEY_ARGUMENTS: [Bullet points of your main arguments]
INFLUENCED_BY: [Names of stakeholders whose arguments influenced you, if any]
"""

        # Get LLM response
        response_text = self.llm_client.generate_response(
            system_prompt=system_prompt,
            user_message=user_message,
            temperature=0.7,
        )

        # Debug logging for problematic responses. Uses the same marker
        # detection as the parser, so the warning fires exactly when the
        # position will fall back to its default.
        has_position = any(
            self._field_value(line, "POSITION") is not None
            for line in response_text.split("\n")
        )
        if not has_position:
            print(f"\n⚠️ WARNING: Response missing POSITION marker for {persona.name}")
            print(f"Response preview: {response_text[:200]}...")

        # Parse response
        return self._parse_response(
            persona_id=persona.persona_id,
            persona_name=persona.name,
            round_number=round_number,
            response_text=response_text,
            previous_opinions=previous_opinions,
        )

    def _build_initial_prompt(self, persona: "Persona") -> str:
        """Build system prompt for initial round"""
        return f"""{persona.get_context_summary()}
You are participating in a stakeholder discussion about urban planning.
Share your honest perspective based on your values, expertise, and background.
"""

    def _build_influenced_prompt(
        self,
        persona: "Persona",
        previous_opinions: Dict[str, "PersonaOpinion"],
        influence_network: "InfluenceNetwork",
    ) -> str:
        """
        Build system prompt showing other opinions.

        Returns the initial prompt unchanged when there is nothing to show:
        no influencers, no previous opinions, or none of the considered
        influencers has a previous opinion.
        """
        base_prompt = self._build_initial_prompt(persona)

        # Get influencers (people who might sway this persona)
        influencers = influence_network.get_influencers(
            persona.persona_id, min_weight=0.4
        )
        if not influencers or not previous_opinions:
            return base_prompt

        # Show relevant opinions.
        # NOTE(review): taking the head of the list as the "top" influencers
        # assumes get_influencers returns them strongest first - confirm.
        sections = []
        for influence_weight in influencers[: self._MAX_INFLUENCERS]:
            opinion = previous_opinions.get(influence_weight.influencer_id)
            if opinion is None:
                continue
            excerpt = opinion.response_text[: self._EXCERPT_LENGTH]
            # Only mark the excerpt as cut off when it really was.
            if len(opinion.response_text) > self._EXCERPT_LENGTH:
                excerpt += "..."
            sections.append(
                f"\n{opinion.persona_name} ({opinion.position.value}):\n{excerpt}\n"
            )

        # Avoid a dangling section header with no perspectives under it.
        if not sections:
            return base_prompt

        return (
            base_prompt
            + "\n\n---\n\nOTHER STAKEHOLDERS' PERSPECTIVES:\n\n"
            + "".join(sections)
        )

    @classmethod
    def _field_value(cls, line: str, *labels: str) -> Optional[str]:
        """
        Return the value of a ``LABEL: value`` line, or None if the line is
        not such a field.

        Matching is case-insensitive and ignores leading whitespace and
        markdown decoration, both before the label and between the label and
        its colon. ``labels`` must be given in upper case.
        """
        cleaned = line.strip().lstrip(cls._LINE_DECORATION)
        upper = cleaned.upper()
        for label in labels:
            if not upper.startswith(label):
                continue
            remainder = cleaned[len(label):].lstrip("*` \t")
            # Requiring the colon keeps prose such as "Positioning ..." out.
            if remainder.startswith(":"):
                return remainder[1:].strip(cls._VALUE_TRIM)
        return None

    @staticmethod
    def _fuzzy_position_name(position_str: str) -> Optional[str]:
        """
        Map free-form position text to an OpinionPosition member name.

        Used when the text is not an exact enum value. Returns None when no
        known keyword is present.
        """
        text = position_str.lower()
        supports = "support" in text
        opposes = "oppos" in text
        if "strong" in text and supports:
            return "STRONGLY_SUPPORT"
        if "strong" in text and opposes:
            return "STRONGLY_OPPOSE"
        if supports:
            return "LEAN_SUPPORT" if "lean" in text else "SUPPORT"
        if opposes:
            return "LEAN_OPPOSE" if "lean" in text else "OPPOSE"
        if "neutral" in text:
            return "NEUTRAL"
        return None

    @classmethod
    def _parse_position(cls, response_text: str) -> "OpinionPosition":
        """
        Extract the position; NEUTRAL when no POSITION field can be parsed.

        If several POSITION lines are present, the last parseable one wins.
        """
        position = OpinionPosition.NEUTRAL
        for line in response_text.split("\n"):
            raw = cls._field_value(line, "POSITION")
            if raw is None:
                continue
            position_str = raw.lower()
            try:
                position = OpinionPosition(position_str)
            except (ValueError, KeyError):
                # Not an exact enum value: try matching partial strings
                member_name = cls._fuzzy_position_name(position_str)
                if member_name is not None:
                    position = getattr(OpinionPosition, member_name)
        return position

    @classmethod
    def _parse_confidence(cls, response_text: str) -> float:
        """
        Extract the confidence, clamped to [0.0, 1.0].

        Accepts a plain number, optionally bracketed, or a percentage
        ("85%" -> 0.85). Returns the default (0.5) when nothing parses.
        """
        confidence = cls._DEFAULT_CONFIDENCE
        for line in response_text.split("\n"):
            raw = cls._field_value(line, "CONFIDENCE")
            if not raw:
                continue
            # First token only, so trailing remarks ("0.8 (high)") are ignored.
            token = raw.split()[0].strip("()[]{},;").rstrip(".")
            is_percent = token.endswith("%")
            try:
                value = float(token.rstrip("%"))
            except ValueError:
                continue
            if value != value:  # NaN compares unequal to itself
                continue
            if is_percent:
                value /= 100.0
            # Clamp to valid range
            confidence = max(0.0, min(1.0, value))
        return confidence

    @classmethod
    def _parse_key_arguments(cls, response_text: str) -> List[str]:
        """
        Collect the bullet lines that follow the KEY_ARGUMENTS header.

        Bullets may start with "-", "* " or "•". Collection stops at the
        first line starting with INFLUENCED; empty bullets are skipped.
        """
        key_arguments: List[str] = []
        in_arguments = False
        for line in response_text.split("\n"):
            if cls._field_value(line, "KEY_ARGUMENTS", "KEY ARGUMENTS") is not None:
                in_arguments = True
                continue
            if not in_arguments:
                continue
            cleaned = line.strip()
            if cleaned.startswith("-"):
                argument = cleaned.strip("- ").strip()
            elif cleaned.startswith(("* ", "•")):
                argument = cleaned[1:].strip()
            elif cleaned.lstrip(cls._LINE_DECORATION).upper().startswith("INFLUENCED"):
                break
            else:
                continue
            if argument:
                key_arguments.append(argument)
        return key_arguments

    @classmethod
    def _parse_influenced_by(cls, response_text: str) -> List[str]:
        """
        Extract the comma-separated names from the INFLUENCED_BY field.

        "none"/"n/a" and empty values yield no names; blank entries caused
        by stray commas are dropped.
        """
        influenced_by: List[str] = []
        for line in response_text.split("\n"):
            cleaned = line.strip().lstrip(cls._LINE_DECORATION)
            if not cleaned.upper().startswith("INFLUENCED"):
                continue
            _, _, raw = cleaned.partition(":")
            raw = raw.strip(cls._VALUE_TRIM)
            if not raw or raw.lower().rstrip(".") in ("none", "n/a"):
                continue
            # Parse names (could be comma-separated)
            names = [name.strip(cls._VALUE_TRIM) for name in raw.split(",")]
            influenced_by = [name for name in names if name]
        return influenced_by

    def _parse_response(
        self,
        persona_id: str,
        persona_name: str,
        round_number: int,
        response_text: str,
        previous_opinions: Dict[str, "PersonaOpinion"],
    ) -> "PersonaOpinion":
        """
        Parse LLM response into PersonaOpinion.

        Missing or unparseable fields fall back to defaults (NEUTRAL
        position, confidence 0.5, empty lists) rather than raising.
        ``position_change`` is the score difference to this persona's entry
        in ``previous_opinions``, or None if there is no such entry.
        """
        position = self._parse_position(response_text)
        confidence = self._parse_confidence(response_text)
        key_arguments = self._parse_key_arguments(response_text)
        influenced_by = self._parse_influenced_by(response_text)

        # Calculate position change
        position_change = None
        if persona_id in previous_opinions:
            prev_score = previous_opinions[persona_id].position_score
            position_change = position.score - prev_score

        return PersonaOpinion(
            persona_id=persona_id,
            persona_name=persona_name,
            round_number=round_number,
            position=position,
            position_score=position.score,
            response_text=response_text,
            key_arguments=key_arguments,
            confidence=confidence,
            influenced_by=influenced_by,
            position_change=position_change,
        )

    def _analyze_round(
        self,
        round_number: int,
        opinions: List["PersonaOpinion"],
        previous_opinions: Dict[str, "PersonaOpinion"],
    ) -> "RoundResult":
        """
        Analyze results from one round.

        Computes the mean and sample variance of position scores, the summed
        absolute position change, a convergence metric and position clusters.

        Raises:
            ValueError: If ``opinions`` is empty
        """
        # Check if we have any opinions
        if not opinions:
            raise ValueError(
                f"Round {round_number} failed: No opinions were successfully generated. "
                "This may indicate that the LLM responses are not in the expected format. "
                "Please check that responses include POSITION, CONFIDENCE, and other required fields."
            )

        scores = [op.position_score for op in opinions]

        # Calculate average position
        avg_position = statistics.mean(scores)

        # Calculate variance (sample variance needs at least two points)
        variance = statistics.variance(scores) if len(scores) > 1 else 0.0

        # Calculate total change from previous round
        total_change = 0.0
        if previous_opinions:
            total_change += sum(
                abs(op.position_change)
                for op in opinions
                if op.position_change is not None
            )

        # Convergence metric (1 = no change, 0 = maximum change).
        # Max change is 6 per person; opinions is non-empty here, so the
        # denominator is always positive.
        max_possible_change = len(opinions) * 6
        convergence = 1.0 - (total_change / max_possible_change)

        # Simple clustering by position
        clusters = self._cluster_by_position(opinions)

        return RoundResult(
            round_number=round_number,
            opinions=opinions,
            average_position=avg_position,
            position_variance=variance,
            total_change=total_change,
            convergence_metric=convergence,
            clusters=clusters,
        )

    def _cluster_by_position(
        self, opinions: List["PersonaOpinion"]
    ) -> List[List[str]]:
        """
        Group personas with similar positions.

        Returns one list of persona IDs per distinct position value, in
        order of first appearance.
        """
        position_groups: Dict[str, List[str]] = {}
        for opinion in opinions:
            position_groups.setdefault(opinion.position.value, []).append(
                opinion.persona_id
            )
        return list(position_groups.values())