| |
| """Generate emotion-labeled stories using Gemini API. |
| |
| 171 emotions x 100 topics x 10 stories = 171,000 stories. |
| Concurrent API calls (up to 100), SQLite WAL for storage, |
| saves both raw API output and parsed stories. |
| |
| Run: |
| python -m full_replication.generate_stories |
| python -m full_replication.generate_stories --test |
| python -m full_replication.generate_stories --workers 50 |
| """ |
|
|
| import argparse |
| import os |
| import re |
| import sqlite3 |
| import threading |
| import time |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
| from dotenv import load_dotenv |
| from google import genai |
| from google.genai import types |
| from tqdm import tqdm |
|
|
| from full_replication.config import EMOTIONS, TOPICS, STORY_PROMPT |
|
|
# Load API credentials from the repo-root .env (one directory above this package).
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"))

# SQLite database lives next to this module under data/.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "stories.db")
MODEL = "gemini-2.0-flash-lite"  # Gemini model used for every generation call
STORIES_PER_CALL = 10  # stories requested per API call (one emotion x topic pair)

# Appended to every prompt so the model emits only [story N]-delimited stories,
# which parse_stories() then splits on.
FORMAT_SUFFIX = """

OUTPUT FORMAT: Start directly with [story 1] — no preamble, no introductions, no explanations, no commentary. Output ONLY the stories separated by [story N] markers. Nothing else."""

# Per-thread SQLite connections (sqlite3 connections should not be shared
# across threads) plus a lock serializing the multi-statement write in
# generate_one().
_local = threading.local()
_db_lock = threading.Lock()
|
|
|
|
def get_db():
    """Return this thread's cached SQLite connection, creating it on first use.

    The connection is opened with a 30s connect timeout, switched to WAL
    journal mode, and given a 10s busy timeout so many worker threads can
    hit the same database file concurrently.
    """
    conn = getattr(_local, "conn", None)
    if conn is None:
        conn = sqlite3.connect(DB_PATH, timeout=30)
        for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA busy_timeout=10000"):
            conn.execute(pragma)
        _local.conn = conn
    return conn
|
|
|
|
def init_db():
    """Create the data/ directory and the SQLite schema (idempotent).

    Tables:
        api_calls     - one row per emotion x topic request: raw response,
                        status ('pending'/'done'/'error'), error text.
        stories       - individual stories parsed from each raw response.
        stories_clean - stories with preamble chunks dropped and markdown
                        stripped; story_idx is renumbered after filtering.
    """
    # Ensure data/ exists before sqlite3 tries to create the file.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    conn.execute("PRAGMA journal_mode=WAL")  # WAL: concurrent readers + one writer
    conn.executescript("""
        CREATE TABLE IF NOT EXISTS api_calls (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            emotion TEXT NOT NULL,
            topic_idx INTEGER NOT NULL,
            topic TEXT NOT NULL,
            raw_response TEXT,
            status TEXT DEFAULT 'pending',
            error TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(emotion, topic_idx)
        );
        CREATE TABLE IF NOT EXISTS stories (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            api_call_id INTEGER NOT NULL,
            emotion TEXT NOT NULL,
            topic_idx INTEGER NOT NULL,
            topic TEXT NOT NULL,
            story_idx INTEGER NOT NULL,
            text TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (api_call_id) REFERENCES api_calls(id),
            UNIQUE(emotion, topic_idx, story_idx)
        );
        CREATE INDEX IF NOT EXISTS idx_stories_emotion ON stories(emotion);
        CREATE INDEX IF NOT EXISTS idx_api_calls_status ON api_calls(status);
        CREATE TABLE IF NOT EXISTS stories_clean (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            api_call_id INTEGER NOT NULL,
            emotion TEXT NOT NULL,
            topic_idx INTEGER NOT NULL,
            topic TEXT NOT NULL,
            story_idx INTEGER NOT NULL,
            text TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (api_call_id) REFERENCES api_calls(id),
            UNIQUE(emotion, topic_idx, story_idx)
        );
        CREATE INDEX IF NOT EXISTS idx_stories_clean_emotion ON stories_clean(emotion);
    """)
    conn.commit()
    conn.close()
|
|
|
|
def get_completed():
    """Return the set of (emotion, topic_idx) pairs already present in stories_clean."""
    db = sqlite3.connect(DB_PATH, timeout=30)
    db.execute("PRAGMA journal_mode=WAL")
    cursor = db.execute(
        "SELECT DISTINCT emotion, topic_idx FROM stories_clean"
    )
    pairs = {(emotion, topic_idx) for emotion, topic_idx in cursor.fetchall()}
    db.close()
    return pairs
|
|
|
|
# Openers typical of model chatter ("Here are ten stories...") rather than a story.
_PREAMBLE_RE = re.compile(
    r'^(Here\s+are|Here\s+is|Below\s+are|These\s+are|The\s+following|I\'ve\s+written|Sure|Okay)',
    re.IGNORECASE
)


def is_preamble(text):
    """Return True when *text* opens like model commentary instead of a story."""
    stripped = text.strip()
    return _PREAMBLE_RE.match(stripped) is not None
|
|
|
|
def clean_story(text):
    """Drop a single leading markdown bold line or #-header, plus surrounding whitespace."""
    stripped = text.strip()
    without_heading = re.sub(r'^(?:\*\*[^*]+\*\*|#{1,3}\s+.+)\s*\n', '', stripped)
    return without_heading.strip()
|
|
|
|
def _usable_parts(parts):
    """Strip split fragments and keep only those long enough (>50 chars) to be real stories."""
    stripped = (p.strip() for p in parts)
    return [p for p in stripped if len(p) > 50]


def parse_stories(text, expected_count=10):
    """Parse model output into individual stories.

    Tries three split strategies in decreasing order of reliability:
      1. the "[story N]" markers the prompt requests,
      2. numbered headings ("1.", "Story 2:", "**3)**", ...),
      3. blank-line paragraph breaks.
    A strategy is accepted once it yields at least half of *expected_count*
    parts (minimum 2). If all strategies fail but the text is substantial
    (>100 chars), the whole text is returned as a single story; otherwise [].

    Args:
        text: Raw model response.
        expected_count: Number of stories the prompt asked for.

    Returns:
        List of story strings (possibly empty).
    """
    min_stories = max(2, expected_count // 2)

    # Strategy 1: explicit [story N] markers.
    parts = _usable_parts(re.split(r'\[story\s*\d+\]', text, flags=re.IGNORECASE))
    if len(parts) >= min_stories:
        return parts

    # Strategy 2: numbered headings, optionally bolded or prefixed with "Story".
    parts = _usable_parts(re.split(
        r'(?:^|\n)\s*(?:\*{0,2}(?:Story\s+)?\d+[\.\):\*]{1,3}\s*\*{0,2})',
        text,
        flags=re.IGNORECASE,
    ))
    if len(parts) >= min_stories:
        return parts

    # Strategy 3: blank-line separated paragraphs.
    parts = _usable_parts(re.split(r'\n\s*\n', text))
    if len(parts) >= min_stories:
        return parts

    # Fallback: treat the whole response as one story if it has substance.
    stripped = text.strip()
    return [stripped] if len(stripped) > 100 else []
|
|
|
|
def generate_one(client, emotion, topic_idx, topic):
    """Generate stories for one emotion x topic pair and persist them.

    Makes one Gemini API call, parses the response into stories, then writes
    the raw response (api_calls), the parsed stories (stories), and the
    preamble-filtered/cleaned stories (stories_clean) in a single locked
    transaction on this thread's connection.

    Args:
        client: genai.Client used for the generation request.
        emotion: Emotion label for this batch (part of the unique key).
        topic_idx: Index of the topic (part of the unique key).
        topic: Topic text substituted into the prompt.

    Returns:
        Summary dict: emotion, topic_idx, n_stories, status ('done'/'error'),
        and error message (None on full success).
    """
    prompt = STORY_PROMPT.format(
        n_stories=STORIES_PER_CALL,
        topic=topic,
        emotion=emotion,
    ) + FORMAT_SUFFIX

    db = get_db()
    raw_response = None
    error = None
    status = "error"  # pessimistic default; flipped to 'done' only when parsing succeeds

    try:
        response = client.models.generate_content(
            model=MODEL,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.9,
                top_p=0.95,
                top_k=64,
                max_output_tokens=4096,
            ),
        )
        raw_response = response.text
        stories = parse_stories(raw_response, STORIES_PER_CALL)

        if not stories:
            # Response arrived but no usable stories; keep raw_response for debugging.
            error = "no stories parsed"
            status = "error"
        else:
            status = "done"

    except Exception as e:
        # API failure (network, quota, safety block, ...): record truncated message.
        error = str(e)[:500]
        stories = []

    # Serialize the multi-statement write so the three inserts commit together.
    with _db_lock:
        try:
            # INSERT OR REPLACE keyed on UNIQUE(emotion, topic_idx): a retry
            # replaces the previous attempt's row (with a fresh id).
            cursor = db.execute(
                """INSERT OR REPLACE INTO api_calls
                (emotion, topic_idx, topic, raw_response, status, error)
                VALUES (?, ?, ?, ?, ?, ?)""",
                (emotion, topic_idx, topic, raw_response, status, error),
            )
            api_call_id = cursor.lastrowid

            # Raw parsed stories, indexed in parse order.
            for i, story_text in enumerate(stories):
                db.execute(
                    """INSERT OR REPLACE INTO stories
                    (api_call_id, emotion, topic_idx, topic, story_idx, text)
                    VALUES (?, ?, ?, ?, ?, ?)""",
                    (api_call_id, emotion, topic_idx, topic, i, story_text),
                )

            # Cleaned copies: skip preamble chunks, strip markdown headers,
            # renumber story_idx over the survivors only.
            clean_idx = 0
            for story_text in stories:
                if is_preamble(story_text):
                    continue
                cleaned = clean_story(story_text)
                if len(cleaned) > 50:
                    db.execute(
                        """INSERT OR REPLACE INTO stories_clean
                        (api_call_id, emotion, topic_idx, topic, story_idx, text)
                        VALUES (?, ?, ?, ?, ?, ?)""",
                        (api_call_id, emotion, topic_idx, topic, clean_idx, cleaned),
                    )
                    clean_idx += 1

            db.commit()
        except Exception as e:
            db.rollback()
            # NOTE(review): on a failed write the returned `status` can still
            # say 'done' even though the transaction rolled back — only
            # `error` reflects the DB failure. Confirm callers check `error`.
            error = str(e)[:500]

    return {
        "emotion": emotion,
        "topic_idx": topic_idx,
        "n_stories": len(stories),
        "status": status,
        "error": error,
    }
|
|
|
|
def backfill_clean():
    """Rebuild stories_clean by re-filtering every stored story.

    Walks all successful api_calls, re-applies the preamble filter and
    markdown cleanup to their raw stories, and upserts the results into
    stories_clean with renumbered story_idx values.
    """
    conn = sqlite3.connect(DB_PATH, timeout=30)
    conn.execute("PRAGMA journal_mode=WAL")

    # Only calls that produced a usable response are worth re-parsing.
    calls = conn.execute(
        "SELECT id, emotion, topic_idx, topic FROM api_calls WHERE status = 'done'"
    ).fetchall()

    cleaned_total = 0
    skipped_total = 0

    for call_id, emo, t_idx, t_text in calls:
        story_rows = conn.execute(
            "SELECT story_idx, text FROM stories WHERE api_call_id = ? ORDER BY story_idx",
            (call_id,),
        ).fetchall()

        next_idx = 0
        for _orig_idx, raw_text in story_rows:
            if is_preamble(raw_text):
                skipped_total += 1
                continue
            cleaned = clean_story(raw_text)
            if len(cleaned) > 50:
                conn.execute(
                    """INSERT OR REPLACE INTO stories_clean
                    (api_call_id, emotion, topic_idx, topic, story_idx, text)
                    VALUES (?, ?, ?, ?, ?, ?)""",
                    (call_id, emo, t_idx, t_text, next_idx, cleaned),
                )
                next_idx += 1
                cleaned_total += 1

    conn.commit()
    print(f"Backfill complete: {cleaned_total} clean stories, {skipped_total} preambles skipped")
    conn.close()
|
|
|
|
def main():
    """CLI entry point: generate every missing emotion x topic story batch.

    Flags:
        --test      run a single API call and exit.
        --workers   max concurrent API calls (default 100).
        --backfill  rebuild stories_clean from existing stories; no API calls.

    Resumable: pairs already present in stories_clean are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", action="store_true", help="Single call test")
    parser.add_argument("--workers", type=int, default=100, help="Concurrent workers")
    parser.add_argument("--backfill", action="store_true", help="Backfill stories_clean from existing stories")
    args = parser.parse_args()

    init_db()

    if args.backfill:
        backfill_clean()
        return

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("ERROR: GEMINI_API_KEY not found in .env")
        return
    completed = get_completed()

    # Remaining work: every emotion x topic pair not already in stories_clean.
    tasks = [
        (emotion, ti, topic)
        for emotion in EMOTIONS
        for ti, topic in enumerate(TOPICS)
        if (emotion, ti) not in completed
    ]

    total = len(EMOTIONS) * len(TOPICS)
    done = total - len(tasks)

    if args.test:
        tasks = tasks[:1]
        print("TEST MODE: 1 call only")
    print("=== Story Generation (Gemini API) ===")
    print(f"Total: {total} calls ({STORIES_PER_CALL} stories each)")
    print(f"Done: {done}, Remaining: {len(tasks)}")
    print(f"Workers: {min(args.workers, len(tasks))}")

    if not tasks:
        print("All stories already generated.")
        return

    client = genai.Client(api_key=api_key)

    errors = 0
    total_stories = 0
    workers = min(args.workers, len(tasks))

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {
            executor.submit(generate_one, client, emotion, ti, topic): (emotion, ti)
            for emotion, ti, topic in tasks
        }

        with tqdm(total=len(tasks), desc="Generating", unit="call") as pbar:
            for future in as_completed(futures):
                result = future.result()
                total_stories += result["n_stories"]
                if result["status"] == "error":
                    errors += 1
                pbar.update(1)
                # Fix: previous code multiplied by STORIES_PER_CALL, inflating
                # the per-call figure 10x; avg stories/call is total / calls done.
                pbar.set_postfix(
                    stories=total_stories,
                    errors=errors,
                    rate=f"{total_stories / (pbar.n or 1):.1f}/call",
                )

    # Final counts from the DB reflect all runs, not just this invocation.
    conn = sqlite3.connect(DB_PATH, timeout=30)
    total_stories_db = conn.execute("SELECT COUNT(*) FROM stories").fetchone()[0]
    total_clean_db = conn.execute("SELECT COUNT(*) FROM stories_clean").fetchone()[0]
    total_calls_done = conn.execute("SELECT COUNT(*) FROM api_calls WHERE status='done'").fetchone()[0]
    total_errors = conn.execute("SELECT COUNT(*) FROM api_calls WHERE status='error'").fetchone()[0]
    conn.close()

    print("\n=== COMPLETE ===")
    print(f"API calls: {total_calls_done} done, {total_errors} errors")
    print(f"Stories (raw): {total_stories_db}")
    print(f"Stories (clean): {total_clean_db}")
    print(f"DB: {DB_PATH}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|