Claude Claude commited on
Commit
b85786b
·
unverified ·
1 Parent(s): b4cdaab

Add persona_weights support to PopulationNetwork

Browse files

Fixes:
- Added persona_weights parameter to PopulationNetwork.__init__
- Updated _generate_population_variants to use custom weights when provided
- This allows Opinion Equilibria page to use custom persona distributions

Files modified:
- src/influence/population_network.py: Added persona_weights support

Note: Sidebar 'Web App' label issue persists - pages.toml may need Streamlit >1.40
to be fully functional. This is a cosmetic issue only.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. src/influence/population_network.py +30 -14
src/influence/population_network.py CHANGED
@@ -24,6 +24,7 @@ class PopulationNetwork:
24
  homophily: float = 0.5,
25
  variation_level: VariationLevel = VariationLevel.MODERATE,
26
  random_seed: int = None,
 
27
  ):
28
  """
29
  Initialize population network.
@@ -35,12 +36,14 @@ class PopulationNetwork:
35
  homophily: Homophily parameter (0-1, higher = more clustering)
36
  variation_level: How much to vary persona characteristics
37
  random_seed: Random seed for reproducibility
 
38
  """
39
  self.base_personas = base_personas
40
  self.population_size = population_size
41
  self.network_type = network_type
42
  self.homophily = homophily
43
  self.variation_level = variation_level
 
44
 
45
  if random_seed is not None:
46
  random.seed(random_seed)
@@ -65,20 +68,33 @@ class PopulationNetwork:
65
  """Generate population variants from base personas"""
66
  variants = []
67
 
68
- # Distribute population across base personas
69
- variants_per_base = self.population_size // len(self.base_personas)
70
- remainder = self.population_size % len(self.base_personas)
71
-
72
- for i, base_persona in enumerate(self.base_personas):
73
- # Generate variants for this base persona
74
- count = variants_per_base + (1 if i < remainder else 0)
75
-
76
- generator = VariantGenerator(base_persona, self.variation_level)
77
- persona_variants = [
78
- generator.generate_variant(f"_v{len(variants) + j}")
79
- for j in range(count)
80
- ]
81
- variants.extend(persona_variants)
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  return variants
84
 
 
24
  homophily: float = 0.5,
25
  variation_level: VariationLevel = VariationLevel.MODERATE,
26
  random_seed: int = None,
27
+ persona_weights: Dict[str, float] = None,
28
  ):
29
  """
30
  Initialize population network.
 
36
  homophily: Homophily parameter (0-1, higher = more clustering)
37
  variation_level: How much to vary persona characteristics
38
  random_seed: Random seed for reproducibility
39
+ persona_weights: Optional custom distribution weights for personas
40
  """
41
  self.base_personas = base_personas
42
  self.population_size = population_size
43
  self.network_type = network_type
44
  self.homophily = homophily
45
  self.variation_level = variation_level
46
+ self.persona_weights = persona_weights
47
 
48
  if random_seed is not None:
49
  random.seed(random_seed)
 
68
  """Generate population variants from base personas"""
69
  variants = []
70
 
71
+ if self.persona_weights:
72
+ # Use custom weights to distribute population
73
+ for base_persona in self.base_personas:
74
+ weight = self.persona_weights.get(base_persona.persona_id, 0)
75
+ count = int(round(weight * self.population_size))
76
+
77
+ generator = VariantGenerator(base_persona, self.variation_level)
78
+ persona_variants = [
79
+ generator.generate_variant(f"_v{len(variants) + j}")
80
+ for j in range(count)
81
+ ]
82
+ variants.extend(persona_variants)
83
+ else:
84
+ # Distribute population evenly across base personas
85
+ variants_per_base = self.population_size // len(self.base_personas)
86
+ remainder = self.population_size % len(self.base_personas)
87
+
88
+ for i, base_persona in enumerate(self.base_personas):
89
+ # Generate variants for this base persona
90
+ count = variants_per_base + (1 if i < remainder else 0)
91
+
92
+ generator = VariantGenerator(base_persona, self.variation_level)
93
+ persona_variants = [
94
+ generator.generate_variant(f"_v{len(variants) + j}")
95
+ for j in range(count)
96
+ ]
97
+ variants.extend(persona_variants)
98
 
99
  return variants
100