# Source: emotion-vector-replication / generate_stories.py
# Uploaded by rain1955 in commit 6afd58f ("Add generate_stories.py", verified).
#!/usr/bin/env python3
"""Generate emotion-labeled stories for emotion vector extraction.

Prompts a local Gemma 4 model through the Ollama CLI (model tag
``gemma4:e4b``) for short stories, one per (emotion, topic, story index)
combination, and appends them as JSON lines to ``emotion_stories.jsonl`` in
the same directory as this script. Re-running the script resumes from the
records already present in that file.
"""
import json
import os
import subprocess
import random
import time
# Output is written next to this script, independent of the working directory.
OUT_DIR = os.path.dirname(os.path.abspath(__file__))
OUT_FILE = os.path.join(OUT_DIR, "emotion_stories.jsonl")

# Emotion words interpolated into the prompt ("The character ... is feeling X").
EMOTIONS = """
    happy sad angry afraid calm
    desperate loving guilty surprised nervous
    proud inspired spiteful brooding playful
    anxious confused disgusted lonely hopeful
""".split()

# Scene descriptions interpolated into the prompt ("... about {topic}").
# Order matters: saved records identify a topic by its index (topic_idx), so
# reordering this list would mislabel stories already on disk.
TOPICS = [
    "a student preparing for an exam",
    "a chef cooking a meal for guests",
    "a parent watching their child play",
    "a soldier returning home",
    "an artist finishing a painting",
    "a driver stuck in traffic",
    "a doctor delivering news to a patient",
    "a traveler arriving in a new city",
    "a musician performing on stage",
    "a shopkeeper closing for the day",
]
def generate_story(emotion, topic, model="gemma4:e4b", timeout=60):
    """Ask a local Ollama model for a short story conveying *emotion*.

    Args:
        emotion: Emotion word the story's character should be feeling.
        topic: Short scene description the story should be about.
        model: Ollama model tag passed to ``ollama run``.
        timeout: Seconds to wait for the ``ollama`` process before giving up.

    Returns:
        The model's stdout with surrounding whitespace stripped, or ``""`` if
        the call timed out or ``ollama`` exited with a non-zero status, so the
        caller can treat a failed generation like an empty one.

    Raises:
        FileNotFoundError: if the ``ollama`` executable cannot be found. This
            is a setup problem, so it is deliberately not masked as "".
    """
    prompt = (
        f"Write a short paragraph (4-6 sentences) about {topic}.\n"
        f"The character in the story is feeling {emotion}.\n"
        "Make the emotion clear through their actions, thoughts, and reactions.\n"
        "Write in English. Only output the story paragraph, nothing else."
    )
    try:
        result = subprocess.run(
            ["ollama", "run", model, prompt],
            capture_output=True,
            text=True,
            # Decode explicitly: the platform default encoding is not always
            # UTF-8 (e.g. on Windows) and could garble or reject model output.
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # One hung generation should cost one story, not abort the whole run.
        print(f"[WARN] ollama timed out after {timeout}s ({emotion}/{topic})")
        return ""
    if result.returncode != 0:
        # Never keep (possibly partial) stdout from a failed run as a story.
        detail = result.stderr.strip()[-200:] or "no stderr"
        print(f"[WARN] ollama exited with status {result.returncode}: {detail}")
        return ""
    return result.stdout.strip()
def _load_existing(path):
    """Return the set of (emotion, topic_idx, story_idx) keys stored in *path*.

    Returns an empty set when *path* does not exist. Blank lines are ignored,
    and lines that cannot be parsed as a record (for example one cut short by
    an interrupted write) are reported and skipped instead of aborting.
    """
    keys = set()
    if not os.path.exists(path):
        return keys
    # errors="replace": a multi-byte character truncated by a crash must not
    # make the whole file unreadable; the damaged line just fails to parse.
    with open(path, "r", encoding="utf-8", errors="replace") as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
                keys.add((rec["emotion"], rec["topic_idx"], rec["story_idx"]))
            except (ValueError, KeyError, TypeError) as err:
                print(f"[WARN] {path}:{lineno}: skipping unreadable record ({err!r})")
    return keys


def _needs_trailing_newline(path):
    """Return True if *path* is a non-empty file whose last byte is not a newline."""
    try:
        with open(path, "rb") as fh:
            fh.seek(0, os.SEEK_END)
            if fh.tell() == 0:
                return False
            fh.seek(-1, os.SEEK_END)
            return fh.read(1) != b"\n"
    except FileNotFoundError:
        return False


def main(stories_per_pair=5):
    """Generate every missing story and append it to OUT_FILE as JSON lines.

    For each emotion in EMOTIONS, each topic in TOPICS and each story index in
    ``range(stories_per_pair)``, asks ``generate_story`` for a paragraph and
    writes one record with keys ``emotion``, ``topic_idx``, ``topic``,
    ``story_idx`` and ``text``. Keys already present in OUT_FILE are skipped,
    so an interrupted run can simply be restarted. Generations that are empty,
    shorter than 20 characters, or time out are not written, and are therefore
    retried on the next run.

    Args:
        stories_per_pair: Number of stories per (emotion, topic) pair.
    """
    existing = _load_existing(OUT_FILE)
    if os.path.exists(OUT_FILE):
        print(f"Resuming: {len(existing)} stories already done")
    total = len(EMOTIONS) * len(TOPICS) * stories_per_pair
    done = len(existing)
    skipped = 0
    needs_newline = _needs_trailing_newline(OUT_FILE)
    # UTF-8 explicitly: records are dumped with ensure_ascii=False, which the
    # platform default encoding (e.g. cp1252) may be unable to encode.
    with open(OUT_FILE, "a", encoding="utf-8") as f:
        if needs_newline:
            # Terminate a record cut off by an earlier crash so the next one
            # starts on its own line instead of being glued onto it.
            f.write("\n")
        for emotion in EMOTIONS:
            for ti, topic in enumerate(TOPICS):
                for si in range(stories_per_pair):
                    if (emotion, ti, si) in existing:
                        continue
                    try:
                        story = generate_story(emotion, topic)
                    except subprocess.TimeoutExpired:
                        # A single hung generation must not abort the run.
                        story = ""
                    if not story or len(story) < 20:
                        print(f"[SKIP] {emotion}/{topic}/{si} - empty")
                        skipped += 1
                        continue
                    record = {
                        "emotion": emotion,
                        "topic_idx": ti,
                        "topic": topic,
                        "story_idx": si,
                        "text": story,
                    }
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
                    # Flush per record so a crash loses at most the story in flight.
                    f.flush()
                    done += 1
                    if done % 10 == 0:
                        print(f"[{done}/{total}] {emotion} / {topic[:30]}...")
    print(f"\nDone. Total stories: {done}")
    if skipped:
        print(f"Skipped: {skipped} (empty, too short or timed out; re-run to retry)")
    print(f"Output: {OUT_FILE}")


if __name__ == "__main__":
    main()