pratik-250620's picture
Upload folder using huggingface_hub
6835659 verified
"""
Prompt Sampler Module
Provides stratified sampling of prompts for controlled experiments.
Ensures diversity across semantic categories while maintaining balance.
"""
from __future__ import annotations
import json
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional
from collections import defaultdict
@dataclass
class PromptCategory:
"""A semantic category for prompt stratification."""
name: str
description: str
keywords: List[str] = field(default_factory=list)
examples: List[str] = field(default_factory=list)
# Default categories for multimodal scene prompts
DEFAULT_CATEGORIES = [
PromptCategory(
name="nature",
description="Natural environments, landscapes, wildlife",
keywords=["forest", "ocean", "mountain", "river", "wildlife", "sunset", "rain", "storm", "garden", "meadow"],
examples=[
"A calm foggy forest at dawn with distant birds and soft wind",
"Ocean waves crashing against rocky cliffs under a stormy sky",
"A peaceful meadow with wildflowers swaying in gentle breeze",
],
),
PromptCategory(
name="urban",
description="City scenes, street life, architecture",
keywords=["city", "street", "building", "traffic", "neon", "market", "crowd", "subway", "cafe", "downtown"],
examples=[
"A rainy neon-lit city street at night with reflections on wet pavement",
"Busy morning market with vendors and the sounds of commerce",
"Empty subway platform with echoing announcements",
],
),
PromptCategory(
name="abstract",
description="Abstract concepts, emotions, moods",
keywords=["peaceful", "chaotic", "melancholy", "joyful", "mysterious", "ethereal", "surreal", "dreamy"],
examples=[
"An ethereal dreamscape with floating lights and distant echoes",
"The feeling of nostalgia on a quiet autumn afternoon",
"A surreal landscape where gravity seems optional",
],
),
PromptCategory(
name="action",
description="Dynamic scenes with movement and activity",
keywords=["running", "flying", "dancing", "fighting", "racing", "jumping", "working", "playing"],
examples=[
"A horse galloping across an open prairie",
"Children playing in a park with laughter and shouts",
"Fireworks exploding over a celebration crowd",
],
),
PromptCategory(
name="domestic",
description="Indoor scenes, everyday life, home settings",
keywords=["kitchen", "bedroom", "office", "library", "home", "cozy", "fireplace", "window"],
examples=[
"A cozy reading nook with rain pattering on the window",
"Morning kitchen scene with sizzling breakfast sounds",
"A quiet home office with soft keyboard clicks",
],
),
]
@dataclass
class SampledPrompt:
"""A prompt with its category and metadata."""
text: str
category: str
source: str = "generated"
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"text": self.text,
"category": self.category,
"source": self.source,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SampledPrompt":
"""Create from dictionary."""
return cls(**data)
class PromptSampler:
"""
Stratified prompt sampler for controlled experiments.
Ensures balanced representation across semantic categories
while allowing for custom prompts and external datasets.
"""
def __init__(
self,
categories: Optional[List[PromptCategory]] = None,
seed: int = 42,
):
"""
Initialize sampler.
Args:
categories: List of semantic categories (uses defaults if None)
seed: Random seed for reproducibility
"""
self.categories = categories or DEFAULT_CATEGORIES
self.category_names = [c.name for c in self.categories]
self.rng = random.Random(seed)
# Build keyword index for categorization
self._keyword_to_category: Dict[str, str] = {}
for cat in self.categories:
for keyword in cat.keywords:
self._keyword_to_category[keyword.lower()] = cat.name
def categorize_prompt(self, prompt: str) -> str:
"""
Assign a category to a prompt based on keyword matching.
Args:
prompt: The prompt text
Returns:
Category name (or "other" if no match)
"""
prompt_lower = prompt.lower()
# Count keyword matches per category
category_scores: Dict[str, int] = defaultdict(int)
for keyword, category in self._keyword_to_category.items():
if keyword in prompt_lower:
category_scores[category] += 1
if category_scores:
return max(category_scores, key=lambda k: category_scores[k])
return "other"
def sample_stratified(
self,
n_total: int,
prompts_per_category: Optional[int] = None,
custom_prompts: Optional[List[str]] = None,
include_examples: bool = True,
) -> List[SampledPrompt]:
"""
Sample prompts with stratified distribution across categories.
Args:
n_total: Total number of prompts to sample
prompts_per_category: Override for prompts per category
custom_prompts: Additional prompts to include (will be categorized)
include_examples: Whether to include category example prompts
Returns:
List of SampledPrompt objects
"""
n_categories = len(self.categories)
per_category = prompts_per_category or (n_total // n_categories)
sampled: List[SampledPrompt] = []
# Organize custom prompts by category
custom_by_category: Dict[str, List[str]] = defaultdict(list)
if custom_prompts:
for prompt in custom_prompts:
category = self.categorize_prompt(prompt)
custom_by_category[category].append(prompt)
# Sample from each category
for cat in self.categories:
category_prompts: List[str] = []
# Add examples if enabled
if include_examples:
category_prompts.extend(cat.examples)
# Add custom prompts for this category
category_prompts.extend(custom_by_category.get(cat.name, []))
# Shuffle and select
self.rng.shuffle(category_prompts)
selected = category_prompts[:per_category]
for text in selected:
sampled.append(SampledPrompt(
text=text,
category=cat.name,
source="example" if text in cat.examples else "custom",
))
# Handle "other" category from custom prompts
other_prompts = custom_by_category.get("other", [])
if other_prompts:
self.rng.shuffle(other_prompts)
for text in other_prompts[:per_category]:
sampled.append(SampledPrompt(
text=text,
category="other",
source="custom",
))
# Shuffle final list and trim to exact n_total
self.rng.shuffle(sampled)
return sampled[:n_total]
def load_from_file(
self,
path: Path,
n_samples: Optional[int] = None,
) -> List[SampledPrompt]:
"""
Load and categorize prompts from a file.
Supports JSON (list or object with "prompts" key) and plain text (one per line).
Args:
path: Path to prompts file
n_samples: Optional limit on number of prompts
Returns:
List of SampledPrompt objects
"""
path = Path(path)
if path.suffix == ".json":
with path.open("r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
prompts = data
elif isinstance(data, dict) and "prompts" in data:
prompts = data["prompts"]
else:
raise ValueError(f"Unexpected JSON structure in {path}")
else:
# Plain text, one prompt per line
with path.open("r", encoding="utf-8") as f:
prompts = [line.strip() for line in f if line.strip()]
# Categorize and create SampledPrompt objects
sampled = []
for prompt in prompts:
if isinstance(prompt, dict):
text = prompt.get("text") or prompt.get("prompt", "")
category = prompt.get("category") or self.categorize_prompt(text)
else:
text = str(prompt)
category = self.categorize_prompt(text)
sampled.append(SampledPrompt(
text=text,
category=category,
source=str(path.name),
))
# Shuffle and limit
self.rng.shuffle(sampled)
if n_samples:
sampled = sampled[:n_samples]
return sampled
def load_from_laion(
self,
laion_dir: Path = Path("data/laion"),
n_samples: int = 50,
) -> List[SampledPrompt]:
"""
Load prompts from LAION subset.
Args:
laion_dir: Path to LAION data directory
n_samples: Number of prompts to sample
Returns:
List of SampledPrompt objects
"""
prompts_file = laion_dir / "prompts_500.json"
if not prompts_file.exists():
# Try alternative locations
for alt in [laion_dir / "prompts.json", laion_dir / "captions.json"]:
if alt.exists():
prompts_file = alt
break
if not prompts_file.exists():
raise FileNotFoundError(f"No prompts file found in {laion_dir}")
return self.load_from_file(prompts_file, n_samples)
def load_from_audiocaps(
self,
audiocaps_dir: Path = Path("data/audiocaps"),
n_samples: int = 50,
) -> List[SampledPrompt]:
"""
Load prompts from AudioCaps dataset.
Args:
audiocaps_dir: Path to AudioCaps data directory
n_samples: Number of prompts to sample
Returns:
List of SampledPrompt objects
"""
# Look for processed captions
for filename in ["captions.json", "audiocaps_subset.json", "prompts.json"]:
captions_file = audiocaps_dir / filename
if captions_file.exists():
return self.load_from_file(captions_file, n_samples)
raise FileNotFoundError(f"No captions file found in {audiocaps_dir}")
def get_category_distribution(
self,
prompts: List[SampledPrompt],
) -> Dict[str, int]:
"""
Get distribution of prompts across categories.
Args:
prompts: List of sampled prompts
Returns:
Dictionary mapping category names to counts
"""
distribution: Dict[str, int] = defaultdict(int)
for prompt in prompts:
distribution[prompt.category] += 1
return dict(distribution)
def save_sample_set(
self,
prompts: List[SampledPrompt],
path: Path,
):
"""
Save a sample set to JSON file.
Args:
prompts: List of sampled prompts
path: Output path
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
data = {
"n_prompts": len(prompts),
"distribution": self.get_category_distribution(prompts),
"prompts": [p.to_dict() for p in prompts],
}
with path.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
@classmethod
def load_sample_set(cls, path: Path) -> List[SampledPrompt]:
"""
Load a previously saved sample set.
Args:
path: Path to JSON file
Returns:
List of SampledPrompt objects
"""
with Path(path).open("r", encoding="utf-8") as f:
data = json.load(f)
return [SampledPrompt.from_dict(p) for p in data["prompts"]]
def create_experiment_prompts(
n_prompts: int = 50,
output_path: Optional[Path] = None,
seed: int = 42,
custom_prompts: Optional[List[str]] = None,
) -> List[SampledPrompt]:
"""
Convenience function to create a balanced prompt set for experiments.
Args:
n_prompts: Total number of prompts
output_path: Optional path to save the prompt set
seed: Random seed
custom_prompts: Optional custom prompts to include
Returns:
List of SampledPrompt objects
"""
sampler = PromptSampler(seed=seed)
prompts = sampler.sample_stratified(
n_total=n_prompts,
custom_prompts=custom_prompts,
include_examples=True,
)
if output_path:
sampler.save_sample_set(prompts, output_path)
return prompts