File size: 4,072 Bytes
8f0cbfc
3e15764
 
8f0cbfc
3e15764
47740c1
bf0d6fa
8f0cbfc
3e15764
 
 
 
 
 
 
 
8f0cbfc
 
bf0d6fa
31999e5
f903acf
 
bf0d6fa
 
f903acf
 
 
bf0d6fa
 
f903acf
 
 
bf0d6fa
8f0cbfc
 
3e15764
 
 
 
 
 
 
 
 
 
 
31999e5
3e15764
bf0d6fa
3e15764
 
 
 
 
 
47740c1
3e15764
 
8f0cbfc
 
3e15764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2300292
8f0cbfc
 
 
bf0d6fa
 
3e15764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f0cbfc
 
31999e5
3e15764
 
8f0cbfc
 
 
47740c1
8f0cbfc
 
bf0d6fa
8f0cbfc
3e15764
bf0d6fa
 
8f0cbfc
bf0d6fa
 
 
 
 
 
 
 
 
 
 
8f0cbfc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch

# Model identifier passed to both the tokenizer and the model loader below.
MODEL_NAME = "Qwen/Qwen1.5-1.8B-Chat"

# Seed the random number generators once, at import time, before any sampling.
set_seed(42)

# Load the tokenizer and the causal language model once at import time so that
# every request handled by the app reuses the same instances.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # load the weights in full precision
    device_map="auto"  # let the loader choose device placement
)

# Follow-up actions offered to the reader, keyed by story theme.
_THEME_ACTIONS = {
    "Fantasy": (
        "open the glowing door",
        "follow the floating lanterns",
        "climb the shimmering tree",
    ),
    "Sci-Fi": (
        "press the glowing button",
        "talk to the friendly robot",
        "enter the humming pod",
    ),
    "Mystery": (
        "read the hidden note",
        "peek through the dusty window",
        "open the old wooden box",
    ),
}


def get_actions(theme):
    """Return the list of actions a reader may choose for ``theme``.

    Args:
        theme: Theme name, one of ``"Fantasy"``, ``"Sci-Fi"`` or ``"Mystery"``.
            May be ``None`` when no theme has been selected yet.

    Returns:
        A new list of action phrases. The list is empty when ``theme`` is
        ``None`` or not a known theme, so callers never have to handle a
        ``KeyError`` for an unselected theme.
    """
    # Copy into a fresh list so callers can mutate the result freely without
    # touching the shared table.
    return list(_THEME_ACTIONS.get(theme, ()))

def generate_story(messages):
    """Generate a story passage from a list of chat messages.

    Args:
        messages: Conversation as a list of ``{"role": ..., "content": ...}``
            dicts, rendered through the tokenizer's chat template.

    Returns:
        The newly generated text only (the prompt is excluded), decoded with
        special tokens skipped and surrounding whitespace stripped.
    """
    # add_generation_prompt=True appends the assistant-turn header so the model
    # writes a reply rather than continuing the user's message.
    # return_dict=True also returns the attention mask, which is forwarded to
    # generate() together with the input ids.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    output_ids = model.generate(
        **inputs,
        max_new_tokens=400,        # LONG story
        temperature=0.95,          # creative but calm
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens produced after the prompt.
    output = tokenizer.decode(
        output_ids[0][prompt_length:],
        skip_special_tokens=True
    )

    return output.strip()

def start_story(name, theme):
    """Generate the opening passage of a new story.

    Args:
        name: Name of the main character. Blank or missing names fall back
            to ``"You"``.
        theme: Selected story theme; may be ``None`` when nothing has been
            selected yet.

    Returns:
        A 3-tuple ``(story_text, action_dropdown_update, no_op_update)``.
        The third element is an empty ``gr.update()`` that changes nothing:
        the tuple keeps its original length for callers that wire three
        outputs, but it no longer carries an empty string that would blank
        the story when the story textbox is wired as an output twice.
    """
    if not theme:
        # Bail out before the expensive generation when no theme is chosen;
        # leave the dropdown untouched.
        return "Please choose a theme to begin the story.", gr.update(), gr.update()

    hero = (name or "").strip() or "You"

    messages = [
        {
            "role": "system",
            "content": (
                "You are a gentle, imaginative storyteller.\n"
                "Write ONLY the story.\n"
                "Never ask questions.\n"
                "Never request feedback.\n"
                "No violence, threats, or dark content.\n"
                "Keep the tone cozy, magical, and adventurous."
            )
        },
        {
            "role": "user",
            "content": (
                f"Begin a vivid {theme} story.\n"
                f"The main character is named {hero}.\n"
                "Use rich descriptions and warm emotions."
            )
        }
    ]

    story = generate_story(messages)
    # value=None clears any selection left over from a previous theme, whose
    # action would not be among the new choices.
    actions = gr.update(choices=get_actions(theme), value=None, visible=True)
    return story, actions, gr.update()

def next_step(name, theme, action):
    """Continue the story after the reader picks an action.

    Args:
        name: Name of the main character. Blank or missing names fall back
            to ``"You"``.
        theme: Story theme, interpolated into the prompt.
        action: Action phrase chosen by the reader; falsy when nothing is
            selected.

    Returns:
        A 2-tuple ``(story_text, dropdown_value)``. ``dropdown_value`` is
        always ``None`` so the action dropdown is cleared; an empty string
        is avoided because it is not one of the dropdown's choices.
    """
    if not action:
        return "Please select an action to continue the story.", None

    hero = (name or "").strip() or "You"

    # NOTE: the prompt carries only the theme, the character name and the
    # chosen action; no earlier story text is included.
    messages = [
        {
            "role": "system",
            "content": (
                "You are continuing a cozy, imaginative story.\n"
                "Write ONLY the story.\n"
                "No violence or threatening language."
            )
        },
        {
            "role": "user",
            "content": (
                f"The story is set in a {theme} world.\n"
                f"The main character is {hero}.\n"
                f"{hero} chooses to {action}.\n"
                "Continue the story with wonder and detail."
            )
        }
    ]

    story = generate_story(messages)
    return story, None

# Build the two-column UI: story setup on the left, story text and the
# follow-up action picker on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as demo:
    gr.Markdown("# 📖 Cozy Creative Story Generator")
    gr.Markdown("Safe, imaginative, long-form storytelling ✨")

    with gr.Row():
        with gr.Column():
            name = gr.Textbox(label="Your Name", value="You")
            theme = gr.Radio(["Fantasy", "Sci-Fi", "Mystery"], label="Choose a Theme")
            start_btn = gr.Button("🚀 Start Story")

        with gr.Column():
            story_out = gr.Textbox(label="Story", lines=14, interactive=False)
            # Hidden with no choices until a story has been started.
            action_choice = gr.Dropdown(label="Your Action", choices=[], visible=False)
            continue_btn = gr.Button("✨ Continue Story")

    # Each output component is listed exactly once: listing story_out a second
    # time would let a later return value overwrite the story just written.
    start_btn.click(
        start_story,
        inputs=[name, theme],
        outputs=[story_out, action_choice],
    )

    continue_btn.click(
        next_step,
        inputs=[name, theme, action_choice],
        outputs=[story_out, action_choice],
    )

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()