rain1955 committed on
Commit
6afd58f
·
verified ·
1 Parent(s): c339092

Add generate_stories.py

Browse files
Files changed (1) hide show
  1. generate_stories.py +91 -0
generate_stories.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate emotion-labeled stories for emotion vector extraction.
3
+ Uses a local Gemma 4 model (gemma4:e4b) via Ollama to generate short stories."""
4
+
5
+ import json
6
+ import os
7
+ import subprocess
8
+ import random
9
+ import time
10
+
11
# Output is written next to this script, one JSON record per line.
OUT_DIR = os.path.dirname(os.path.abspath(__file__))
OUT_FILE = os.path.join(OUT_DIR, "emotion_stories.jsonl")

# Emotion labels; each one is combined with every entry of TOPICS.
EMOTIONS = [
    "happy", "sad", "angry", "afraid",
    "calm", "desperate", "loving", "guilty",
    "surprised", "nervous", "proud", "inspired",
    "spiteful", "brooding", "playful", "anxious",
    "confused", "disgusted", "lonely", "hopeful",
]

# Scenario descriptions interpolated into the generation prompt.
# Records store the list index as "topic_idx", so the order matters for
# resuming from an existing output file.
TOPICS = [
    "a student preparing for an exam",
    "a chef cooking a meal for guests",
    "a parent watching their child play",
    "a soldier returning home",
    "an artist finishing a painting",
    "a driver stuck in traffic",
    "a doctor delivering news to a patient",
    "a traveler arriving in a new city",
    "a musician performing on stage",
    "a shopkeeper closing for the day",
]
33
+
34
def generate_story(emotion, topic, model="gemma4:e4b", timeout=60):
    """Ask a local Ollama model for one short, emotion-labeled story.

    Runs ``ollama run <model> <prompt>`` as a subprocess and returns what the
    model printed on stdout.

    Args:
        emotion: Emotion word the character in the story should be feeling.
        topic: Short scenario description the story is about.
        model: Ollama model tag to run.
        timeout: Maximum number of seconds to wait for the subprocess.

    Returns:
        The model output with surrounding whitespace stripped, or ``""`` when
        the call timed out or exited with a non-zero status. An empty string
        lets the caller treat the attempt as a skip.

    Raises:
        OSError: If the ``ollama`` executable cannot be started (for example
            ``FileNotFoundError`` when it is not installed). This is left to
            propagate because every later call would fail the same way.
    """
    prompt = (
        f"Write a short paragraph (4-6 sentences) about {topic}.\n"
        f"The character in the story is feeling {emotion}.\n"
        "Make the emotion clear through their actions, thoughts, and reactions.\n"
        "Write in English. Only output the story paragraph, nothing else."
    )

    try:
        result = subprocess.run(
            ["ollama", "run", model, prompt],
            capture_output=True,
            text=True,
            # Decode as UTF-8 regardless of the platform locale; never crash
            # on a stray undecodable byte.
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # One slow generation must not abort a long batch run.
        print(f"[WARN] ollama timed out after {timeout}s ({emotion} / {topic})")
        return ""

    if result.returncode != 0:
        # A failed run may leave partial text on stdout; do not store it.
        detail = (result.stderr or "").strip()[:200]
        print(f"[WARN] ollama exited with status {result.returncode}: {detail}")
        return ""

    return result.stdout.strip()
45
+
46
+
47
def _load_existing_keys(path):
    """Return the (emotion, topic_idx, story_idx) keys already stored in *path*.

    Blank lines and lines that are not valid records (for example a final
    line truncated by an interrupted run) are reported and ignored instead of
    aborting the resume.
    """
    keys = set()
    try:
        handle = open(path, "r", encoding="utf-8")
    except FileNotFoundError:
        return keys
    with handle:
        for lineno, raw in enumerate(handle, 1):
            raw = raw.strip()
            if not raw:
                continue
            try:
                rec = json.loads(raw)
                keys.add((rec["emotion"], rec["topic_idx"], rec["story_idx"]))
            except (json.JSONDecodeError, KeyError, TypeError):
                print(f"[WARN] ignoring malformed line {lineno} in {path}")
    return keys


def _missing_trailing_newline(path):
    """Return True if *path* exists, is non-empty and does not end with a newline."""
    try:
        with open(path, "rb") as handle:
            handle.seek(0, os.SEEK_END)
            if handle.tell() == 0:
                return False
            handle.seek(-1, os.SEEK_END)
            return handle.read(1) != b"\n"
    except FileNotFoundError:
        return False


def main():
    """Generate every missing (emotion, topic, story index) story.

    Records are appended to OUT_FILE as JSON lines and flushed after each
    write, so an interrupted run can be restarted and will only generate the
    combinations that are not in the file yet.
    """
    stories_per_pair = 5

    resuming = os.path.exists(OUT_FILE)
    existing = _load_existing_keys(OUT_FILE)
    if resuming:
        print(f"Resuming: {len(existing)} stories already done")

    total = len(EMOTIONS) * len(TOPICS) * stories_per_pair
    done = len(existing)

    # Checked before opening for append: a truncated last line must be
    # terminated, or the next record would be glued onto it.
    pad_newline = _missing_trailing_newline(OUT_FILE)

    # ensure_ascii=False writes raw non-ASCII characters, so the file
    # encoding must be explicit rather than locale dependent.
    with open(OUT_FILE, "a", encoding="utf-8") as f:
        if pad_newline:
            f.write("\n")

        for emotion in EMOTIONS:
            for ti, topic in enumerate(TOPICS):
                for si in range(stories_per_pair):
                    if (emotion, ti, si) in existing:
                        continue

                    try:
                        story = generate_story(emotion, topic)
                    except subprocess.TimeoutExpired:
                        # One slow generation must not abort the whole run.
                        print(f"[SKIP] {emotion}/{topic}/{si} - timeout")
                        continue

                    if not story or len(story) < 20:
                        print(f"[SKIP] {emotion}/{topic}/{si} - empty")
                        continue

                    record = {
                        "emotion": emotion,
                        "topic_idx": ti,
                        "topic": topic,
                        "story_idx": si,
                        "text": story,
                    }
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
                    # Flush per record so a crash loses at most one story.
                    f.flush()
                    done += 1

                    if done % 10 == 0:
                        print(f"[{done}/{total}] {emotion} / {topic[:30]}...")

    print(f"\nDone. Total stories: {done}")
    print(f"Output: {OUT_FILE}")


if __name__ == "__main__":
    main()