#!/usr/bin/env python3
"""Generate emotion-labeled stories for emotion vector extraction.

Uses a local Gemma 4 model run through the Ollama CLI (see ``MODEL``) to
generate short stories. Every (emotion, topic) pair gets ``STORIES_PER_PAIR``
stories, appended one JSON object per line to ``emotion_stories.jsonl`` next
to this script. Re-running the script resumes: keys already present in the
output file are not regenerated.
"""
import json
import os
import subprocess
import random
import time

OUT_DIR = os.path.dirname(os.path.abspath(__file__))
OUT_FILE = os.path.join(OUT_DIR, "emotion_stories.jsonl")

# Ollama model tag passed to `ollama run`.
# NOTE(review): the previous module docstring said "Gemma4-31B", but the
# command runs "gemma4:e4b" — confirm which model is intended.
MODEL = "gemma4:e4b"
# Seconds to wait for a single `ollama run` call before skipping that story.
TIMEOUT_S = 60
# Number of stories generated for each (emotion, topic) pair.
STORIES_PER_PAIR = 5
# Stripped outputs shorter than this are treated as failed generations.
MIN_STORY_CHARS = 20

EMOTIONS = [
    "happy", "sad", "angry", "afraid", "calm",
    "desperate", "loving", "guilty", "surprised", "nervous",
    "proud", "inspired", "spiteful", "brooding", "playful",
    "anxious", "confused", "disgusted", "lonely", "hopeful",
]

TOPICS = [
    "a student preparing for an exam",
    "a chef cooking a meal for guests",
    "a parent watching their child play",
    "a soldier returning home",
    "an artist finishing a painting",
    "a driver stuck in traffic",
    "a doctor delivering news to a patient",
    "a traveler arriving in a new city",
    "a musician performing on stage",
    "a shopkeeper closing for the day",
]


def build_prompt(emotion, topic):
    """Return the generation prompt asking for a story about *topic* that conveys *emotion*."""
    return (
        f"Write a short paragraph (4-6 sentences) about {topic}.\n"
        f"The character in the story is feeling {emotion}.\n"
        "Make the emotion clear through their actions, thoughts, and reactions.\n"
        "Write in English.\n"
        "Only output the story paragraph, nothing else."
    )


def generate_story(emotion, topic, model=MODEL, timeout=TIMEOUT_S):
    """Ask the Ollama model for one story and return its stripped text.

    Args:
        emotion: Emotion label the story should convey.
        topic: Scenario the story is about.
        model: Ollama model tag to run.
        timeout: Seconds to wait before abandoning this generation.

    Returns:
        The stripped model output. Returns ``""`` if the call times out or
        ``ollama`` exits with a non-zero status, so the caller skips the item
        and a later run retries it.

    Raises:
        FileNotFoundError: if the ``ollama`` executable cannot be found. This
            is deliberately not caught, because no story could ever be produced.
    """
    try:
        result = subprocess.run(
            ["ollama", "run", model, build_prompt(emotion, topic)],
            capture_output=True,
            text=True,
            # Decode explicitly instead of relying on the locale's default
            # codec, which may not be UTF-8 (e.g. on Windows).
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child on timeout; skip instead of aborting the run.
        print(f"[WARN] ollama timed out after {timeout}s ({emotion}/{topic})")
        return ""
    if result.returncode != 0:
        print(f"[WARN] ollama exited with status {result.returncode}: "
              f"{result.stderr.strip()[:200]}")
        return ""
    return result.stdout.strip()


def load_existing(path):
    """Return the set of ``(emotion, topic_idx, story_idx)`` keys already stored in *path*.

    A missing file yields an empty set. Blank lines and lines that are not
    complete JSON records (e.g. a line truncated because a previous run was
    killed mid-write) are ignored, so those items are regenerated instead of
    crashing the resume.
    """
    existing = set()
    if not os.path.exists(path):
        return existing
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                d = json.loads(line)
                existing.add((d["emotion"], d["topic_idx"], d["story_idx"]))
            except (json.JSONDecodeError, KeyError, TypeError):
                print(f"[WARN] ignoring malformed line in {path}")
    return existing


def ensure_trailing_newline(path):
    """Append a newline to *path* if it is non-empty and does not already end with one.

    This ensures the next appended record starts on its own line, rather than
    being glued to a truncated line left by an interrupted run.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return
    with open(path, "rb+") as f:
        f.seek(-1, os.SEEK_END)
        if f.read(1) != b"\n":
            f.seek(0, os.SEEK_END)
            f.write(b"\n")


def main():
    """Generate every missing (emotion, topic, story_idx) record and append it to ``OUT_FILE``."""
    existing = load_existing(OUT_FILE)
    if os.path.exists(OUT_FILE):
        print(f"Resuming: {len(existing)} stories already done")

    total = len(EMOTIONS) * len(TOPICS) * STORIES_PER_PAIR
    done = len(existing)

    ensure_trailing_newline(OUT_FILE)
    # utf-8 is required because records are dumped with ensure_ascii=False.
    with open(OUT_FILE, "a", encoding="utf-8") as f:
        for emotion in EMOTIONS:
            for ti, topic in enumerate(TOPICS):
                for si in range(STORIES_PER_PAIR):
                    if (emotion, ti, si) in existing:
                        continue

                    story = generate_story(emotion, topic)
                    if len(story) < MIN_STORY_CHARS:
                        print(f"[SKIP] {emotion}/{topic}/{si} - empty or too short")
                        continue

                    record = {
                        "emotion": emotion,
                        "topic_idx": ti,
                        "topic": topic,
                        "story_idx": si,
                        "text": story,
                    }
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
                    # Flush per record so an interruption loses at most one story.
                    f.flush()
                    done += 1

                    if done % 10 == 0:
                        print(f"[{done}/{total}] {emotion} / {topic[:30]}...")

    print(f"\nDone. Total stories: {done}")
    print(f"Output: {OUT_FILE}")


if __name__ == "__main__":
    main()