jsakshi committed on
Commit
f522662
·
verified ·
1 Parent(s): 2178b75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -55
app.py CHANGED
@@ -589,54 +589,79 @@ demo.launch()'''
589
 
590
 
591
  import gradio as gr
592
- from transformers import pipeline
 
593
 
594
- # Load a lightweight model (distilgpt2) for initial story generation
595
- generator = pipeline("text-generation", model="distilgpt2", max_length=50, truncation=True)
 
 
596
 
597
  # Initialize story state
598
  story_history = []
599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
  def generate_initial_story(prompt):
601
  """Generate a coherent initial story paragraph from the prompt."""
602
  story_prompt = (
603
- f"Write a short paragraph (30-40 words) introducing a story based on: '{prompt}'. "
604
- "Keep it simple, coherent, and descriptive."
605
  )
606
- story_output = generator(story_prompt, max_length=50, num_return_sequences=1, do_sample=True, temperature=0.7)[0]["generated_text"]
607
- story_output = story_output.replace(story_prompt, "").strip()
608
 
609
- # Ensure prompt relevance
610
- prompt_words = set(prompt.lower().split())
611
- if not any(word in story_output.lower() for word in prompt_words):
612
  story_output = f"{prompt}. {story_output}"
613
 
614
  return story_output
615
 
616
- def generate_options(story):
617
- """Generate three specific, relevant options based on the story."""
618
- # Extract key elements from the story for relevance
619
- story_lower = story.lower()
620
- character = story.split()[0] # Assume first word is the character (e.g., "Priya")
 
 
 
621
 
622
- # Define option templates based on common story beats
623
- if "town" in story_lower or "village" in story_lower:
624
- options = [
625
- f"{character} moves to a big city for a new job opportunity.",
626
- f"{character} meets someone special who changes her life.",
627
- f"{character} discovers a family secret that shocks her."
628
- ]
629
- elif "city" in story_lower:
630
- options = [
631
- f"{character} lands a dream job that tests her skills.",
632
- f"{character} encounters a mysterious stranger in the crowd.",
633
- f"{character} finds an old letter in her apartment."
634
- ]
635
- else:
636
  options = [
637
- f"{character} embarks on an unexpected adventure.",
638
- f"{character} meets a friend who shares a bold idea.",
639
- f"{character} uncovers a hidden truth about her past."
640
  ]
641
 
642
  return options
@@ -645,26 +670,39 @@ def start_story(initial_prompt):
645
  """Start the story with the user's initial prompt."""
646
  global story_history
647
  story_history = [initial_prompt]
 
 
648
  segment = generate_initial_story(initial_prompt)
649
  story_history.append(segment)
 
 
650
  options = generate_options(segment)
651
- return f"{segment}\n\n🔹 What happens next? Choose an option:\n1️ {options[0]}\n2️ {options[1]}\n3️ {options[2]}", options[0], options[1], options[2]
 
652
 
653
  def continue_story(choice):
654
  """Continue the story based on the user's choice."""
655
  global story_history
656
- story_history.append(f"You chose: {choice}")
657
- # Generate a new segment based on the choice
658
- story_prompt = (
659
- f"Continue this story: '{story_history[-1]} {choice}'. "
660
- "Write a short paragraph (30-40 words) that builds on it logically."
 
 
661
  )
662
- story_output = generator(story_prompt, max_length=50, num_return_sequences=1, do_sample=True, temperature=0.7)[0]["generated_text"]
663
- story_output = story_output.replace(story_prompt, "").strip()
664
 
665
- story_history.append(story_output)
666
- options = generate_options(story_output)
667
- return f"{'\\n'.join(story_history)}\n\n🔹 What happens next? Choose an option:\n1️ {options[0]}\n2️ {options[1]}\n3️ {options[2]}", options[0], options[1], options[2]
 
 
 
 
 
 
 
 
668
 
669
  def reset_story():
670
  """Reset the story to start fresh."""
@@ -674,26 +712,38 @@ def reset_story():
674
 
675
  # Gradio interface
676
  with gr.Blocks(title="Story Adventure") as demo:
677
- gr.Markdown("# Story Adventure")
678
- gr.Markdown("Begin your tale and shape it with your choices!")
679
 
680
  with gr.Row():
681
- prompt_input = gr.Textbox(label="Enter your story prompt (e.g., 'Priya lives in a small town')", placeholder="Type here...")
682
- start_button = gr.Button("Start Story")
 
 
 
 
 
 
683
 
684
- story_output = gr.Textbox(label="Your Story", lines=10, interactive=False)
685
- choice_button_1 = gr.Button(value="", visible=False)
686
- choice_button_2 = gr.Button(value="", visible=False)
687
- choice_button_3 = gr.Button(value="", visible=False)
688
 
689
- reset_button = gr.Button("Reset Story")
690
 
 
691
  start_button.click(
692
  fn=start_story,
693
  inputs=prompt_input,
694
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
695
  ).then(
696
- fn=lambda story, c1, c2, c3: (story, gr.update(value=c1, visible=True), gr.update(value=c2, visible=True), gr.update(value=c3, visible=True)),
 
 
 
 
 
697
  inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
698
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
699
  )
@@ -704,7 +754,12 @@ with gr.Blocks(title="Story Adventure") as demo:
704
  inputs=button,
705
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
706
  ).then(
707
- fn=lambda story, c1, c2, c3: (story, gr.update(value=c1, visible=True), gr.update(value=c2, visible=True), gr.update(value=c3, visible=True)),
 
 
 
 
 
708
  inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
709
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
710
  )
@@ -714,4 +769,5 @@ with gr.Blocks(title="Story Adventure") as demo:
714
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
715
  )
716
 
717
- demo.launch()
 
 
589
 
590
 
591
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the TinyLlama chat model and tokenizer.
# NOTE: the hub repo id requires the version suffix — the bare
# "TinyLlama/TinyLlama-1.1B-Chat" id does not exist on the Hub, so
# from_pretrained() would fail to resolve it at startup.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 halves the memory footprint; device_map="auto" places weights
# on the GPU when one is available, otherwise on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Accumulated story segments and recorded choices.
# Module-level state: shared across all sessions of this Gradio app.
story_history = []
602
 
603
def generate_text(prompt, max_length=100):
    """Generate a chat completion for *prompt* with the TinyLlama model.

    Args:
        prompt: Plain-text instruction for the model.
        max_length: Maximum number of NEW tokens to sample (despite the
            name, this is passed as ``max_new_tokens``).

    Returns:
        The model's response text only, with surrounding whitespace stripped.
    """
    # Wrap the request in the ChatML-style template the chat model expects.
    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask avoids pad/attn ambiguity
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt tokens + completion tokens. Decoding the
    # whole sequence and then str.replace()-ing the formatted prompt never
    # matches, because skip_special_tokens strips the <|im_start|> markers
    # from the decoded text — the prompt would leak into the output.
    # Slice off the prompt tokens and decode only the new ones instead.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
627
+
628
def generate_initial_story(prompt):
    """Produce the opening paragraph of a story seeded by *prompt*."""
    request = (
        f"Write a short, engaging paragraph (40-50 words) for a story starting with: '{prompt}'. "
        "Make it descriptive and ensure it relates directly to the prompt."
    )
    opening = generate_text(request, max_length=75)

    # If the generated text shares no word with the prompt, anchor it by
    # prefixing the prompt verbatim so the opening stays on-theme.
    shares_a_word = any(
        word in opening.lower() for word in prompt.lower().split()
    )
    if not shares_a_word:
        opening = f"{prompt}. {opening}"

    return opening
641
 
642
def generate_options(story_so_far):
    """Generate three context-aware "what happens next?" options.

    Asks the model for a numbered list and parses it; falls back to
    generic, character-anchored options when parsing fails.

    Args:
        story_so_far: The latest story segment to branch from.

    Returns:
        A list of exactly three option strings.
    """
    option_prompt = (
        f"Based on this story: '{story_so_far}'\n"
        "Generate 3 distinct, exciting options for what happens next. "
        "Each option should be a single sentence and directly relate to the story's context. "
        "Format: 1. [Option 1] 2. [Option 2] 3. [Option 3]"
    )
    options_text = generate_text(option_prompt, max_length=100)

    # Parse lines shaped like "1. ..." / "2. ..." / "3. ...".
    # maxsplit=1 keeps option text intact even when it contains ". "
    # itself (the previous split(". ")[1] silently truncated such
    # sentences). Guarding on '". " in line' replaces the bare `except:`
    # that previously swallowed every error, including real bugs.
    options = []
    for line in options_text.split("\n"):
        line = line.strip()
        if line.startswith(("1.", "2.", "3.")) and ". " in line:
            options.append(line.split(". ", 1)[1].strip())

    if len(options) != 3:
        # Fallback: anchor generic story beats on the first word of the
        # story (guard against an empty story instead of raising IndexError).
        words = story_so_far.split()
        character = words[0] if words else "The hero"
        options = [
            f"{character} faces an unexpected challenge that tests their resolve.",
            f"{character} discovers something that changes everything.",
            f"{character} makes a decision that leads to adventure.",
        ]
    return options
 
670
  """Start the story with the user's initial prompt."""
671
  global story_history
672
  story_history = [initial_prompt]
673
+
674
+ # Generate initial segment
675
  segment = generate_initial_story(initial_prompt)
676
  story_history.append(segment)
677
+
678
+ # Generate options
679
  options = generate_options(segment)
680
+
681
+ return format_story_output(segment, options), *options
682
 
683
  def continue_story(choice):
684
  """Continue the story based on the user's choice."""
685
  global story_history
686
+ story_history.append(f"Choice: {choice}")
687
+
688
+ # Generate new segment
689
+ continuation_prompt = (
690
+ f"Continue this story coherently:\n{story_history[-2]}\n"
691
+ f"The chosen path: {choice}\n"
692
+ "Write a natural 40-50 word continuation that builds on these events."
693
  )
 
 
694
 
695
+ new_segment = generate_text(continuation_prompt, max_length=75)
696
+ story_history.append(new_segment)
697
+
698
+ # Generate new options
699
+ options = generate_options(new_segment)
700
+
701
+ return format_story_output("\n".join(story_history), options), *options
702
+
703
def format_story_output(story, options):
    """Render the accumulated story followed by the three numbered choices."""
    pieces = [f"{story}\n\n🔹 What happens next?"]
    for badge, idx in (("1️⃣", 0), ("2️⃣", 1), ("3️⃣", 2)):
        pieces.append(f"{badge} {options[idx]}")
    return "\n".join(pieces)
706
 
707
  def reset_story():
708
  """Reset the story to start fresh."""
 
712
 
713
  # Gradio interface
714
  with gr.Blocks(title="Story Adventure") as demo:
715
+ gr.Markdown("# 📚 Story Adventure")
716
+ gr.Markdown("Create your own adventure! Enter a prompt and shape the story through your choices.")
717
 
718
  with gr.Row():
719
+ prompt_input = gr.Textbox(
720
+ label="Story Prompt",
721
+ placeholder="e.g., 'Sarah discovers a mysterious letter in her attic'",
722
+ lines=2
723
+ )
724
+ start_button = gr.Button("🎮 Start Story", variant="primary")
725
+
726
+ story_output = gr.Textbox(label="Your Story", lines=12, interactive=False)
727
 
728
+ with gr.Row():
729
+ choice_button_1 = gr.Button(value="", visible=False, variant="secondary")
730
+ choice_button_2 = gr.Button(value="", visible=False, variant="secondary")
731
+ choice_button_3 = gr.Button(value="", visible=False, variant="secondary")
732
 
733
+ reset_button = gr.Button("🔄 Reset Story", variant="stop")
734
 
735
+ # Event handlers
736
  start_button.click(
737
  fn=start_story,
738
  inputs=prompt_input,
739
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
740
  ).then(
741
+ fn=lambda story, c1, c2, c3: (
742
+ story,
743
+ gr.update(value=c1, visible=True),
744
+ gr.update(value=c2, visible=True),
745
+ gr.update(value=c3, visible=True)
746
+ ),
747
  inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
748
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
749
  )
 
754
  inputs=button,
755
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
756
  ).then(
757
+ fn=lambda story, c1, c2, c3: (
758
+ story,
759
+ gr.update(value=c1, visible=True),
760
+ gr.update(value=c2, visible=True),
761
+ gr.update(value=c3, visible=True)
762
+ ),
763
  inputs=[story_output, choice_button_1, choice_button_2, choice_button_3],
764
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
765
  )
 
769
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
770
  )
771
 
772
+ if __name__ == "__main__":
773
+ demo.launch()