gemotions / generate_stories.py
dejanseo's picture
Upload 10 files
67f0e56 verified
#!/usr/bin/env python3
"""Generate emotion-labeled stories using the Gemini API.

171 emotions x 100 topics x 10 stories = 171,000 stories.
Concurrent API calls (100 workers by default, see --workers), SQLite in WAL
mode for storage; saves both the raw API output and the parsed stories.

Run:
    python -m full_replication.generate_stories
    python -m full_replication.generate_stories --test
    python -m full_replication.generate_stories --workers 50
    python -m full_replication.generate_stories --backfill
"""
import argparse
import os
import re
import sqlite3
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from google import genai
from google.genai import types
from tqdm import tqdm
from full_replication.config import EMOTIONS, TOPICS, STORY_PROMPT
# Directory that holds this module. The .env file is looked up one level
# above it; the SQLite database lives in a "data" folder right beside it.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

load_dotenv(os.path.join(os.path.dirname(_MODULE_DIR), ".env"))

DB_PATH = os.path.join(_MODULE_DIR, "data", "stories.db")
MODEL = "gemini-2.0-flash-lite"
STORIES_PER_CALL = 10

# Suffix appended to Anthropic's story prompt to force a strict,
# marker-delimited output format.
FORMAT_SUFFIX = """
OUTPUT FORMAT: Start directly with [story 1] — no preamble, no introductions, no explanations, no commentary. Output ONLY the stories separated by [story N] markers. Nothing else."""

# Per-thread holder for SQLite connections, plus the lock that worker threads
# take while writing their results.
_local = threading.local()
_db_lock = threading.Lock()
def get_db():
    """Return the calling thread's SQLite connection, opening it on first use.

    A newly opened connection is cached on the thread-local holder, switched
    to WAL journaling and given a 10-second busy timeout.
    """
    conn = getattr(_local, "conn", None)
    if conn is None:
        conn = sqlite3.connect(DB_PATH, timeout=30)
        _local.conn = conn
        for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA busy_timeout=10000"):
            conn.execute(pragma)
    return conn
# Full schema: one row per API call, the raw parsed stories, and the cleaned
# stories (preambles removed, titles stripped). Every statement is
# IF NOT EXISTS, so running the script repeatedly is harmless.
_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS api_calls (
id INTEGER PRIMARY KEY AUTOINCREMENT,
emotion TEXT NOT NULL,
topic_idx INTEGER NOT NULL,
topic TEXT NOT NULL,
raw_response TEXT,
status TEXT DEFAULT 'pending',
error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
UNIQUE(emotion, topic_idx)
);
CREATE TABLE IF NOT EXISTS stories (
id INTEGER PRIMARY KEY AUTOINCREMENT,
api_call_id INTEGER NOT NULL,
emotion TEXT NOT NULL,
topic_idx INTEGER NOT NULL,
topic TEXT NOT NULL,
story_idx INTEGER NOT NULL,
text TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (api_call_id) REFERENCES api_calls(id),
UNIQUE(emotion, topic_idx, story_idx)
);
CREATE INDEX IF NOT EXISTS idx_stories_emotion ON stories(emotion);
CREATE INDEX IF NOT EXISTS idx_api_calls_status ON api_calls(status);
CREATE TABLE IF NOT EXISTS stories_clean (
id INTEGER PRIMARY KEY AUTOINCREMENT,
api_call_id INTEGER NOT NULL,
emotion TEXT NOT NULL,
topic_idx INTEGER NOT NULL,
topic TEXT NOT NULL,
story_idx INTEGER NOT NULL,
text TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (api_call_id) REFERENCES api_calls(id),
UNIQUE(emotion, topic_idx, story_idx)
);
CREATE INDEX IF NOT EXISTS idx_stories_clean_emotion ON stories_clean(emotion);
"""


def init_db():
    """Make sure the data directory, tables and indexes all exist."""
    data_dir = os.path.dirname(DB_PATH)
    os.makedirs(data_dir, exist_ok=True)

    setup_conn = sqlite3.connect(DB_PATH)
    setup_conn.execute("PRAGMA journal_mode=WAL")
    setup_conn.executescript(_SCHEMA_SQL)
    setup_conn.commit()
    setup_conn.close()
def get_completed():
    """Collect every (emotion, topic_idx) pair that has rows in stories_clean.

    Returns:
        A set of ``(emotion, topic_idx)`` tuples.
    """
    reader = sqlite3.connect(DB_PATH, timeout=30)
    reader.execute("PRAGMA journal_mode=WAL")
    cursor = reader.execute("SELECT DISTINCT emotion, topic_idx FROM stories_clean")
    finished = set(cursor.fetchall())
    reader.close()
    return finished
# Typical openers a model emits before the first story ("Here are ten ...",
# "Sure, ..."). The trailing \b keeps a keyword from matching as the prefix of
# a longer word, so a story that opens with "Surely ..." is not discarded.
# NOTE: this stays a heuristic; a story that genuinely opens with one of these
# phrases (e.g. "The following morning ...") is still treated as preamble.
_PREAMBLE_RE = re.compile(
    r"^(Here\s+are|Here\s+is|Below\s+are|These\s+are|The\s+following|I've\s+written|Sure|Okay)\b",
    re.IGNORECASE,
)


def is_preamble(text):
    """Return True if *text* looks like model preamble rather than a story.

    Args:
        text: One parsed chunk of model output; surrounding whitespace is
            ignored.

    Returns:
        True when the chunk starts with a known preamble phrase (matched
        case-insensitively, as a whole word or phrase), otherwise False.
    """
    return bool(_PREAMBLE_RE.match(text.strip()))
# A single leading title line: either **Bold Title** or a level 1-3 markdown
# heading, followed by optional whitespace and a newline.
_LEADING_TITLE_RE = re.compile(r'^(?:\*\*[^*]+\*\*|#{1,3}\s+.+)\s*\n')


def clean_story(text):
    """Trim surrounding whitespace and drop one leading markdown title line.

    Only a title on its own line at the very start is removed; bold text that
    continues into the sentence is left alone.
    """
    body = text.strip()
    without_title = _LEADING_TITLE_RE.sub('', body)
    return without_title.strip()
def parse_stories(text, expected_count=10):
    """Split raw model output into a list of individual story strings.

    Three splitting strategies are tried in order: ``[story N]`` markers,
    numbered headings, then blank-line separation. The first one that yields
    at least ``max(2, expected_count // 2)`` chunks longer than 50 characters
    wins. If none qualifies, the whole text is returned as a single story when
    it is longer than 100 characters; otherwise the result is empty.
    """
    threshold = max(2, expected_count // 2)

    splitters = (
        # Strategy 1: [story N] markers
        (r'\[story\s*\d+\]', re.IGNORECASE),
        # Strategy 2: numbered patterns such as "1.", "**Story 2:**"
        (r'(?:^|\n)\s*(?:\*{0,2}(?:Story\s+)?\d+[\.\):\*]{1,3}\s*\*{0,2})', re.IGNORECASE),
        # Strategy 3: blank-line separated paragraphs
        (r'\n\s*\n', 0),
    )
    for pattern, flags in splitters:
        pieces = [piece.strip() for piece in re.split(pattern, text, flags=flags)]
        candidates = [piece for piece in pieces if len(piece) > 50]
        if len(candidates) >= threshold:
            return candidates

    # Fallback: keep the whole response as one story if it is long enough.
    whole = text.strip()
    return [whole] if len(whole) > 100 else []
_UPSERT_API_CALL_SQL = """INSERT OR REPLACE INTO api_calls
(emotion, topic_idx, topic, raw_response, status, error)
VALUES (?, ?, ?, ?, ?, ?)"""

_UPSERT_STORY_SQL = """INSERT OR REPLACE INTO stories
(api_call_id, emotion, topic_idx, topic, story_idx, text)
VALUES (?, ?, ?, ?, ?, ?)"""

_UPSERT_CLEAN_STORY_SQL = """INSERT OR REPLACE INTO stories_clean
(api_call_id, emotion, topic_idx, topic, story_idx, text)
VALUES (?, ?, ?, ?, ?, ?)"""


def _save_generation(db, emotion, topic_idx, topic, raw_response, status, error, stories):
    """Write one API call and its stories; the caller commits or rolls back."""
    if stories:
        # REPLACE on api_calls gives the row a new id, and an earlier attempt
        # may have stored more stories than this one. Clear the old rows so no
        # stale, orphaned stories survive at higher indices. When this attempt
        # produced nothing, earlier rows are deliberately left untouched.
        db.execute(
            "DELETE FROM stories WHERE emotion = ? AND topic_idx = ?",
            (emotion, topic_idx),
        )
        db.execute(
            "DELETE FROM stories_clean WHERE emotion = ? AND topic_idx = ?",
            (emotion, topic_idx),
        )

    cursor = db.execute(
        _UPSERT_API_CALL_SQL,
        (emotion, topic_idx, topic, raw_response, status, error),
    )
    api_call_id = cursor.lastrowid

    db.executemany(
        _UPSERT_STORY_SQL,
        [
            (api_call_id, emotion, topic_idx, topic, i, story_text)
            for i, story_text in enumerate(stories)
        ],
    )

    # Clean versions: skip preamble chunks, strip title lines, and number the
    # survivors contiguously.
    clean_idx = 0
    for story_text in stories:
        if is_preamble(story_text):
            continue
        cleaned = clean_story(story_text)
        if len(cleaned) > 50:
            db.execute(
                _UPSERT_CLEAN_STORY_SQL,
                (api_call_id, emotion, topic_idx, topic, clean_idx, cleaned),
            )
            clean_idx += 1


def generate_one(client, emotion, topic_idx, topic):
    """Generate stories for one emotion x topic pair and persist them.

    Args:
        client: Gemini client exposing ``models.generate_content``.
        emotion: Emotion label inserted into the prompt and stored per row.
        topic_idx: Index of the topic; with ``emotion`` it forms the unique key.
        topic: Topic text inserted into the prompt.

    Returns:
        A dict with keys ``emotion``, ``topic_idx``, ``n_stories``, ``status``
        (``"done"`` or ``"error"``) and ``error`` (message or None). API and
        database failures are reported through the dict rather than raised.
    """
    prompt = STORY_PROMPT.format(
        n_stories=STORIES_PER_CALL,
        topic=topic,
        emotion=emotion,
    ) + FORMAT_SUFFIX
    db = get_db()

    raw_response = None
    error = None
    status = "error"
    stories = []
    try:
        response = client.models.generate_content(
            model=MODEL,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.9,
                top_p=0.95,
                top_k=64,
                max_output_tokens=4096,
            ),
        )
        raw_response = response.text
        if raw_response is None:
            # NOTE(review): presumably the reply carried no text part (e.g. it
            # was blocked) - confirm against the google-genai docs. Recording
            # it explicitly beats the opaque TypeError re.split would raise.
            error = "empty response"
        else:
            stories = parse_stories(raw_response, STORIES_PER_CALL)
            if stories:
                status = "done"
            else:
                error = "no stories parsed"
    except Exception as e:  # worker boundary: record the failure, never raise
        error = str(e)[:500]
        stories = []

    # Save to DB
    with _db_lock:
        try:
            _save_generation(
                db, emotion, topic_idx, topic, raw_response, status, error, stories
            )
            db.commit()
        except Exception as e:
            db.rollback()
            # Nothing was persisted, so the call must not be reported as done
            # nor its stories counted.
            status = "error"
            error = str(e)[:500]
            stories = []

    return {
        "emotion": emotion,
        "topic_idx": topic_idx,
        "n_stories": len(stories),
        "status": status,
        "error": error,
    }
_BACKFILL_UPSERT_CLEAN_SQL = """INSERT OR REPLACE INTO stories_clean
(api_call_id, emotion, topic_idx, topic, story_idx, text)
VALUES (?, ?, ?, ?, ?, ?)"""


def backfill_clean():
    """Rebuild stories_clean from the stories of every successful API call.

    For each ``api_calls`` row with status 'done', its stories are re-filtered
    (preambles skipped, title lines stripped, results of 50 characters or
    fewer dropped) and written with contiguous indices. Existing clean rows
    for that emotion/topic are removed first, so a re-run that keeps fewer
    stories does not leave stale rows behind. All changes are committed
    together at the end; if anything raises, nothing is committed.
    """
    conn = sqlite3.connect(DB_PATH, timeout=30)
    try:
        conn.execute("PRAGMA journal_mode=WAL")

        # Get all api_calls that are done
        calls = conn.execute(
            "SELECT id, emotion, topic_idx, topic FROM api_calls WHERE status = 'done'"
        ).fetchall()

        cleaned_total = 0
        skipped_total = 0
        for api_call_id, emotion, topic_idx, topic in calls:
            rows = conn.execute(
                "SELECT story_idx, text FROM stories WHERE api_call_id = ? ORDER BY story_idx",
                (api_call_id,),
            ).fetchall()
            if not rows:
                # Nothing to rebuild from; keep whatever clean rows exist.
                continue

            conn.execute(
                "DELETE FROM stories_clean WHERE emotion = ? AND topic_idx = ?",
                (emotion, topic_idx),
            )

            clean_idx = 0
            for _, story_text in rows:
                if is_preamble(story_text):
                    skipped_total += 1
                    continue
                cleaned = clean_story(story_text)
                if len(cleaned) > 50:
                    conn.execute(
                        _BACKFILL_UPSERT_CLEAN_SQL,
                        (api_call_id, emotion, topic_idx, topic, clean_idx, cleaned),
                    )
                    clean_idx += 1
                    cleaned_total += 1

        conn.commit()
        print(f"Backfill complete: {cleaned_total} clean stories, {skipped_total} preambles skipped")
    finally:
        # Closing without a commit discards any partial work.
        conn.close()
def main():
    """Command-line entry point: generate all missing stories, then summarize.

    Flags:
        --test      Run a single API call only.
        --workers   Maximum number of concurrent worker threads (default 100).
        --backfill  Rebuild stories_clean from stored stories and exit.

    Work is resumable: emotion/topic pairs that already have rows in
    stories_clean are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true", help="Single call test")
    parser.add_argument("--workers", type=int, default=100, help="Concurrent workers")
    parser.add_argument("--backfill", action="store_true", help="Backfill stories_clean from existing stories")
    args = parser.parse_args()

    init_db()

    if args.backfill:
        backfill_clean()
        return

    if args.workers < 1:
        # ThreadPoolExecutor rejects max_workers <= 0 with a bare ValueError;
        # fail early with a usage message instead.
        parser.error("--workers must be at least 1")

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("ERROR: GEMINI_API_KEY not found in .env")
        return

    completed = get_completed()

    # Build work queue
    tasks = [
        (emotion, ti, topic)
        for emotion in EMOTIONS
        for ti, topic in enumerate(TOPICS)
        if (emotion, ti) not in completed
    ]

    total = len(EMOTIONS) * len(TOPICS)
    done = total - len(tasks)

    if args.test:
        tasks = tasks[:1]
        print("TEST MODE: 1 call only")

    workers = min(args.workers, len(tasks))
    print("=== Story Generation (Gemini API) ===")
    print(f"Total: {total} calls ({STORIES_PER_CALL} stories each)")
    print(f"Done: {done}, Remaining: {len(tasks)}")
    print(f"Workers: {workers}")

    if not tasks:
        print("All stories already generated.")
        return

    client = genai.Client(api_key=api_key)
    errors = 0
    total_stories = 0

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(generate_one, client, emotion, ti, topic): (emotion, ti)
            for emotion, ti, topic in tasks
        }
        with tqdm(total=len(tasks), desc="Generating", unit="call") as pbar:
            for future in as_completed(futures):
                emotion, ti = futures[future]
                try:
                    result = future.result()
                except Exception as exc:
                    # A worker raised outside its own error handling. Count it
                    # and keep going; letting it propagate would only surface
                    # after the executor had drained every queued call.
                    errors += 1
                    tqdm.write(f"FAILED {emotion} / topic {ti}: {exc}")
                else:
                    total_stories += result["n_stories"]
                    if result["status"] == "error":
                        errors += 1
                pbar.update(1)
                pbar.set_postfix(
                    stories=total_stories,
                    errors=errors,
                    # Average stories per completed call.
                    rate=f"{total_stories / max(pbar.n, 1):.1f}/call",
                )

    # Summary
    conn = sqlite3.connect(DB_PATH, timeout=30)
    try:
        total_stories_db = conn.execute("SELECT COUNT(*) FROM stories").fetchone()[0]
        total_clean_db = conn.execute("SELECT COUNT(*) FROM stories_clean").fetchone()[0]
        total_calls_done = conn.execute("SELECT COUNT(*) FROM api_calls WHERE status='done'").fetchone()[0]
        total_errors = conn.execute("SELECT COUNT(*) FROM api_calls WHERE status='error'").fetchone()[0]
    finally:
        conn.close()

    print("\n=== COMPLETE ===")
    print(f"API calls: {total_calls_done} done, {total_errors} errors")
    print(f"Stories (raw): {total_stories_db}")
    print(f"Stories (clean): {total_clean_db}")
    print(f"DB: {DB_PATH}")


if __name__ == "__main__":
    main()