Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -589,54 +589,79 @@ demo.launch()'''
|
|
| 589 |
|
| 590 |
|
| 591 |
import gradio as gr
|
| 592 |
-
from transformers import
|
|
|
|
| 593 |
|
| 594 |
-
# Load
|
| 595 |
-
|
|
|
|
|
|
|
| 596 |
|
| 597 |
# Initialize story state
|
| 598 |
story_history = []
|
| 599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
def generate_initial_story(prompt):
|
| 601 |
"""Generate a coherent initial story paragraph from the prompt."""
|
| 602 |
story_prompt = (
|
| 603 |
-
f"Write a short paragraph (
|
| 604 |
-
"
|
| 605 |
)
|
| 606 |
-
story_output =
|
| 607 |
-
story_output = story_output.replace(story_prompt, "").strip()
|
| 608 |
|
| 609 |
-
# Ensure prompt
|
| 610 |
-
|
| 611 |
-
if not any(word in story_output.lower() for word in prompt_words):
|
| 612 |
story_output = f"{prompt}. {story_output}"
|
| 613 |
|
| 614 |
return story_output
|
| 615 |
|
| 616 |
-
def generate_options(
|
| 617 |
-
"""Generate three
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
options
|
| 631 |
-
|
| 632 |
-
f"{character} encounters a mysterious stranger in the crowd.",
|
| 633 |
-
f"{character} finds an old letter in her apartment."
|
| 634 |
-
]
|
| 635 |
-
else:
|
| 636 |
options = [
|
| 637 |
-
f"{character}
|
| 638 |
-
f"{character}
|
| 639 |
-
f"{character}
|
| 640 |
]
|
| 641 |
|
| 642 |
return options
|
|
@@ -645,26 +670,39 @@ def start_story(initial_prompt):
|
|
| 645 |
"""Start the story with the user's initial prompt."""
|
| 646 |
global story_history
|
| 647 |
story_history = [initial_prompt]
|
|
|
|
|
|
|
| 648 |
segment = generate_initial_story(initial_prompt)
|
| 649 |
story_history.append(segment)
|
|
|
|
|
|
|
| 650 |
options = generate_options(segment)
|
| 651 |
-
|
|
|
|
| 652 |
|
| 653 |
def continue_story(choice):
|
| 654 |
"""Continue the story based on the user's choice."""
|
| 655 |
global story_history
|
| 656 |
-
story_history.append(f"
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
"
|
|
|
|
|
|
|
| 661 |
)
|
| 662 |
-
story_output = generator(story_prompt, max_length=50, num_return_sequences=1, do_sample=True, temperature=0.7)[0]["generated_text"]
|
| 663 |
-
story_output = story_output.replace(story_prompt, "").strip()
|
| 664 |
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
def reset_story():
|
| 670 |
"""Reset the story to start fresh."""
|
|
@@ -674,26 +712,38 @@ def reset_story():
|
|
| 674 |
|
| 675 |
# Gradio interface
|
| 676 |
with gr.Blocks(title="Story Adventure") as demo:
|
| 677 |
-
gr.Markdown("# Story Adventure")
|
| 678 |
-
gr.Markdown("
|
| 679 |
|
| 680 |
with gr.Row():
|
| 681 |
-
prompt_input = gr.Textbox(
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
|
| 689 |
-
reset_button = gr.Button("Reset Story")
|
| 690 |
|
|
|
|
| 691 |
start_button.click(
|
| 692 |
fn=start_story,
|
| 693 |
inputs=prompt_input,
|
| 694 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 695 |
).then(
|
| 696 |
-
fn=lambda story, c1, c2, c3: (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
|
| 698 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 699 |
)
|
|
@@ -704,7 +754,12 @@ with gr.Blocks(title="Story Adventure") as demo:
|
|
| 704 |
inputs=button,
|
| 705 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 706 |
).then(
|
| 707 |
-
fn=lambda story, c1, c2, c3: (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
|
| 709 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 710 |
)
|
|
@@ -714,4 +769,5 @@ with gr.Blocks(title="Story Adventure") as demo:
|
|
| 714 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 715 |
)
|
| 716 |
|
| 717 |
-
|
|
|
|
|
|
| 589 |
|
| 590 |
|
| 591 |
import gradio as gr
|
| 592 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 593 |
+
import torch
|
| 594 |
|
| 595 |
+
# Load TinyLlama model and tokenizer
|
| 596 |
+
model_name = "TinyLlama/TinyLlama-1.1B-Chat"
|
| 597 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 598 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
|
| 599 |
|
| 600 |
# Initialize story state
|
| 601 |
story_history = []
|
| 602 |
|
| 603 |
+
def generate_text(prompt, max_length=100):
|
| 604 |
+
"""Generate text using TinyLlama model with optimized settings."""
|
| 605 |
+
# Format prompt for chat
|
| 606 |
+
formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
| 607 |
+
|
| 608 |
+
# Tokenize input
|
| 609 |
+
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
| 610 |
+
|
| 611 |
+
# Generate with optimized parameters
|
| 612 |
+
with torch.no_grad():
|
| 613 |
+
outputs = model.generate(
|
| 614 |
+
inputs.input_ids,
|
| 615 |
+
max_new_tokens=max_length,
|
| 616 |
+
do_sample=True,
|
| 617 |
+
temperature=0.7,
|
| 618 |
+
top_p=0.9,
|
| 619 |
+
num_return_sequences=1,
|
| 620 |
+
pad_token_id=tokenizer.eos_token_id
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
# Decode and clean up response
|
| 624 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 625 |
+
response = response.replace(formatted_prompt, "").strip()
|
| 626 |
+
return response
|
| 627 |
+
|
| 628 |
def generate_initial_story(prompt):
|
| 629 |
"""Generate a coherent initial story paragraph from the prompt."""
|
| 630 |
story_prompt = (
|
| 631 |
+
f"Write a short, engaging paragraph (40-50 words) for a story starting with: '{prompt}'. "
|
| 632 |
+
"Make it descriptive and ensure it relates directly to the prompt."
|
| 633 |
)
|
| 634 |
+
story_output = generate_text(story_prompt, max_length=75)
|
|
|
|
| 635 |
|
| 636 |
+
# Ensure the story starts with the prompt theme
|
| 637 |
+
if not any(word in story_output.lower() for word in prompt.lower().split()):
|
|
|
|
| 638 |
story_output = f"{prompt}. {story_output}"
|
| 639 |
|
| 640 |
return story_output
|
| 641 |
|
| 642 |
+
def generate_options(story_so_far):
|
| 643 |
+
"""Generate three context-aware options based on the current story."""
|
| 644 |
+
option_prompt = (
|
| 645 |
+
f"Based on this story: '{story_so_far}'\n"
|
| 646 |
+
"Generate 3 distinct, exciting options for what happens next. "
|
| 647 |
+
"Each option should be a single sentence and directly relate to the story's context. "
|
| 648 |
+
"Format: 1. [Option 1] 2. [Option 2] 3. [Option 3]"
|
| 649 |
+
)
|
| 650 |
|
| 651 |
+
options_text = generate_text(option_prompt, max_length=100)
|
| 652 |
+
|
| 653 |
+
# Parse options or provide fallback options if generation fails
|
| 654 |
+
try:
|
| 655 |
+
options = [opt.split(". ")[1].strip() for opt in options_text.split("\n") if opt.strip().startswith(("1.", "2.", "3."))]
|
| 656 |
+
if len(options) != 3:
|
| 657 |
+
raise ValueError("Invalid options generated")
|
| 658 |
+
except:
|
| 659 |
+
# Fallback options based on common story elements
|
| 660 |
+
character = story_so_far.split()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
options = [
|
| 662 |
+
f"{character} faces an unexpected challenge that tests their resolve.",
|
| 663 |
+
f"{character} discovers something that changes everything.",
|
| 664 |
+
f"{character} makes a decision that leads to adventure."
|
| 665 |
]
|
| 666 |
|
| 667 |
return options
|
|
|
|
| 670 |
"""Start the story with the user's initial prompt."""
|
| 671 |
global story_history
|
| 672 |
story_history = [initial_prompt]
|
| 673 |
+
|
| 674 |
+
# Generate initial segment
|
| 675 |
segment = generate_initial_story(initial_prompt)
|
| 676 |
story_history.append(segment)
|
| 677 |
+
|
| 678 |
+
# Generate options
|
| 679 |
options = generate_options(segment)
|
| 680 |
+
|
| 681 |
+
return format_story_output(segment, options), *options
|
| 682 |
|
| 683 |
def continue_story(choice):
|
| 684 |
"""Continue the story based on the user's choice."""
|
| 685 |
global story_history
|
| 686 |
+
story_history.append(f"Choice: {choice}")
|
| 687 |
+
|
| 688 |
+
# Generate new segment
|
| 689 |
+
continuation_prompt = (
|
| 690 |
+
f"Continue this story coherently:\n{story_history[-2]}\n"
|
| 691 |
+
f"The chosen path: {choice}\n"
|
| 692 |
+
"Write a natural 40-50 word continuation that builds on these events."
|
| 693 |
)
|
|
|
|
|
|
|
| 694 |
|
| 695 |
+
new_segment = generate_text(continuation_prompt, max_length=75)
|
| 696 |
+
story_history.append(new_segment)
|
| 697 |
+
|
| 698 |
+
# Generate new options
|
| 699 |
+
options = generate_options(new_segment)
|
| 700 |
+
|
| 701 |
+
return format_story_output("\n".join(story_history), options), *options
|
| 702 |
+
|
| 703 |
+
def format_story_output(story, options):
|
| 704 |
+
"""Format the story and options for display."""
|
| 705 |
+
return f"{story}\n\n🔹 What happens next?\n1️⃣ {options[0]}\n2️⃣ {options[1]}\n3️⃣ {options[2]}"
|
| 706 |
|
| 707 |
def reset_story():
|
| 708 |
"""Reset the story to start fresh."""
|
|
|
|
| 712 |
|
| 713 |
# Gradio interface
|
| 714 |
with gr.Blocks(title="Story Adventure") as demo:
|
| 715 |
+
gr.Markdown("# 📚 Story Adventure")
|
| 716 |
+
gr.Markdown("Create your own adventure! Enter a prompt and shape the story through your choices.")
|
| 717 |
|
| 718 |
with gr.Row():
|
| 719 |
+
prompt_input = gr.Textbox(
|
| 720 |
+
label="Story Prompt",
|
| 721 |
+
placeholder="e.g., 'Sarah discovers a mysterious letter in her attic'",
|
| 722 |
+
lines=2
|
| 723 |
+
)
|
| 724 |
+
start_button = gr.Button("🎮 Start Story", variant="primary")
|
| 725 |
+
|
| 726 |
+
story_output = gr.Textbox(label="Your Story", lines=12, interactive=False)
|
| 727 |
|
| 728 |
+
with gr.Row():
|
| 729 |
+
choice_button_1 = gr.Button(value="", visible=False, variant="secondary")
|
| 730 |
+
choice_button_2 = gr.Button(value="", visible=False, variant="secondary")
|
| 731 |
+
choice_button_3 = gr.Button(value="", visible=False, variant="secondary")
|
| 732 |
|
| 733 |
+
reset_button = gr.Button("🔄 Reset Story", variant="stop")
|
| 734 |
|
| 735 |
+
# Event handlers
|
| 736 |
start_button.click(
|
| 737 |
fn=start_story,
|
| 738 |
inputs=prompt_input,
|
| 739 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 740 |
).then(
|
| 741 |
+
fn=lambda story, c1, c2, c3: (
|
| 742 |
+
story,
|
| 743 |
+
gr.update(value=c1, visible=True),
|
| 744 |
+
gr.update(value=c2, visible=True),
|
| 745 |
+
gr.update(value=c3, visible=True)
|
| 746 |
+
),
|
| 747 |
inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
|
| 748 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 749 |
)
|
|
|
|
| 754 |
inputs=button,
|
| 755 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 756 |
).then(
|
| 757 |
+
fn=lambda story, c1, c2, c3: (
|
| 758 |
+
story,
|
| 759 |
+
gr.update(value=c1, visible=True),
|
| 760 |
+
gr.update(value=c2, visible=True),
|
| 761 |
+
gr.update(value=c3, visible=True)
|
| 762 |
+
),
|
| 763 |
inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
|
| 764 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 765 |
)
|
|
|
|
| 769 |
outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
|
| 770 |
)
|
| 771 |
|
| 772 |
+
if __name__ == "__main__":
|
| 773 |
+
demo.launch()
|