import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# --- Load tokenizer & model ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",     # use GPU if available
    torch_dtype="auto"     # pick the best dtype
)
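# Note: device_map="auto" relies on the optional `accelerate` package; if it
# is not installed, loading a plain CPU model is a safe fallback (an
# assumption, not part of the original script):
#   model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)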

# Ensure pad token is set for safe generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# --- Pipeline ---
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False
)
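# With return_full_text=False the pipeline returns only the newly generated
# continuation, not the prompt itself -- convenient here, since the user's
# prompt is already visible in the input box.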

# --- Prompt builder ---
def build_prompt(user_text: str) -> str:
    user_text = (user_text or "").strip()
    messages = [
        {"role": "system", "content": "You are a helpful storyteller that writes engaging prose."},
        {"role": "user", "content": user_text}
    ]
    # Use chat template if available, else fallback
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    return (
        "System: You are a helpful storyteller that writes engaging prose.\n"
        f"User: {user_text}\n"
        "Assistant:"
    )
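
# For TinyLlama-1.1B-Chat the chat template renders to a Zephyr-style prompt,
# roughly: <|system|>\n...</s>\n<|user|>\n...</s>\n<|assistant|>\n
# (illustrative sketch of the format, not copied from the model card).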

# --- Generation function ---
def generate_story(prompt, max_tokens=300, temperature=0.8):
    try:
        prompt_str = build_prompt(prompt)  # build the chat-formatted prompt string
        outputs = generator(
            prompt_str,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        # pipelines return a list of dicts with "generated_text"
        return outputs[0].get("generated_text", "")
    except Exception as e:
        return f"Error during generation: {type(e).__name__}: {e}"

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 📖 Interactive Story Generator (TinyLlama/TinyLlama-1.1B-Chat-v1.0)")
    gr.Markdown("Type a prompt and let the AI continue your story with a compact chat model.")

    prompt = gr.Textbox(
        label="My Story Prompt",
        placeholder="e.g., In the far future, humanity discovered a hidden planet...",
        lines=3
    )
    max_length = gr.Slider(50, 1000, value=300, step=50, label="Story length (new tokens)")
    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity")

    generate_btn = gr.Button("✨ Generate Story")
    output = gr.Textbox(label="Generated Story", lines=20)

    generate_btn.click(
        fn=generate_story,
        inputs=[prompt, max_length, temperature],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
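
# To get a temporary public URL (handy in notebooks), Gradio also supports:
#   demo.launch(share=True)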