jsakshi commited on
Commit
759b6eb
·
verified ·
1 Parent(s): f522662

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -122
app.py CHANGED
@@ -584,158 +584,206 @@ demo.launch()'''
584
 
585
 
586
 
587
-
588
-
589
-
590
-
591
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load TinyLlama chat model and tokenizer.
# BUG FIX: "TinyLlama/TinyLlama-1.1B-Chat" is not a published Hub repository;
# the released chat checkpoint is "TinyLlama/TinyLlama-1.1B-Chat-v1.0".
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# fp16 weights with device_map="auto" so accelerate places them on GPU when available.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Running transcript of the current story: prompt, segments, and choices in order.
story_history = []
 
602
 
603
def generate_text(prompt, max_length=100):
    """Generate a chat completion for ``prompt`` with TinyLlama.

    Args:
        prompt: Plain-text user message.
        max_length: Maximum number of new tokens to sample.

    Returns:
        The assistant's reply as a stripped string.
    """
    # Wrap the message in the ChatML template TinyLlama-Chat was tuned on.
    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize input and move it to the model's device.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Generate with sampling parameters tuned for coherent short prose.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )

    # BUG FIX: decode only the newly generated tokens. Decoding the full
    # sequence and string-replacing the prompt fails because
    # skip_special_tokens strips the <|im_start|> markers, so
    # formatted_prompt never matches and the echoed prompt leaks into the reply.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
627
 
628
def generate_initial_story(prompt):
    """Produce the opening story paragraph seeded by the user's prompt."""
    request = (
        f"Write a short, engaging paragraph (40-50 words) for a story starting with: '{prompt}'. "
        "Make it descriptive and ensure it relates directly to the prompt."
    )
    opening = generate_text(request, max_length=75)

    # If the model drifted and echoed none of the prompt's words, prefix the
    # prompt so the story still visibly starts from the user's idea.
    prompt_words = prompt.lower().split()
    if not any(word in opening.lower() for word in prompt_words):
        opening = f"{prompt}. {opening}"

    return opening
641
 
642
def generate_options(story_so_far):
    """Generate three context-aware options based on the current story.

    Args:
        story_so_far: The latest story segment to base the options on.

    Returns:
        A list of exactly three single-sentence option strings.
    """
    option_prompt = (
        f"Based on this story: '{story_so_far}'\n"
        "Generate 3 distinct, exciting options for what happens next. "
        "Each option should be a single sentence and directly relate to the story's context. "
        "Format: 1. [Option 1] 2. [Option 2] 3. [Option 3]"
    )

    options_text = generate_text(option_prompt, max_length=100)

    # Parse options, or fall back to stock options if parsing/generation fails.
    try:
        options = [
            # maxsplit=1 keeps the full sentence even if it contains ". ".
            opt.split(". ", 1)[1].strip()
            for opt in options_text.split("\n")
            if opt.strip().startswith(("1.", "2.", "3."))
        ]
        if len(options) != 3:
            raise ValueError("Invalid options generated")
    except (ValueError, IndexError):
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        # Guard against an empty story, which previously made split()[0] raise
        # IndexError *inside* this handler and crash the app.
        words = story_so_far.split()
        character = words[0] if words else "The hero"
        options = [
            f"{character} faces an unexpected challenge that tests their resolve.",
            f"{character} discovers something that changes everything.",
            f"{character} makes a decision that leads to adventure."
        ]

    return options
668
-
669
def start_story(initial_prompt):
    """Begin a new adventure from the user's prompt and offer the first choices."""
    global story_history

    # Fresh transcript: the raw prompt itself is entry zero.
    story_history = [initial_prompt]

    opening = generate_initial_story(initial_prompt)
    story_history.append(opening)

    next_options = generate_options(opening)
    return (format_story_output(opening, next_options), *next_options)
682
 
683
def continue_story(choice):
    """Advance the story one step along the option the user picked."""
    global story_history
    story_history.append(f"Choice: {choice}")

    # story_history[-2] is the last generated segment — the entry just before
    # the choice line we recorded above.
    continuation_prompt = (
        f"Continue this story coherently:\n{story_history[-2]}\n"
        f"The chosen path: {choice}\n"
        "Write a natural 40-50 word continuation that builds on these events."
    )

    segment = generate_text(continuation_prompt, max_length=75)
    story_history.append(segment)

    follow_ups = generate_options(segment)
    full_story = "\n".join(story_history)
    return (format_story_output(full_story, follow_ups), *follow_ups)
702
 
703
def format_story_output(story, options):
    """Format the story and numbered options for display.

    Args:
        story: The story text accumulated so far.
        options: Sequence of follow-up choices (normally three).

    Returns:
        The story followed by an emoji-numbered list of options.
    """
    # Keycap emojis for options 1-3; any extra options fall back to plain
    # numbering instead of raising IndexError on a non-3-element list.
    badges = ("1️⃣", "2️⃣", "3️⃣")
    lines = [
        f"{badges[i] if i < len(badges) else f'{i + 1}.'} {opt}"
        for i, opt in enumerate(options)
    ]
    return f"{story}\n\n🔹 What happens next?\n" + "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
def reset_story():
    """Wipe the running story so the user can begin a new adventure."""
    global story_history
    # Rebind rather than mutate, so any stale references to the old list are dropped.
    story_history = []
    placeholder = "Story reset. Enter a new prompt to begin!"
    # One value per bound output: message box plus three cleared option slots.
    return placeholder, "", "", ""
712
 
713
  # Gradio interface
714
- with gr.Blocks(title="Story Adventure") as demo:
715
- gr.Markdown("# 📚 Story Adventure")
716
- gr.Markdown("Create your own adventure! Enter a prompt and shape the story through your choices.")
717
 
 
718
  with gr.Row():
719
- prompt_input = gr.Textbox(
720
- label="Story Prompt",
721
- placeholder="e.g., 'Sarah discovers a mysterious letter in her attic'",
722
- lines=2
723
- )
724
- start_button = gr.Button("🎮 Start Story", variant="primary")
 
 
 
 
725
 
726
- story_output = gr.Textbox(label="Your Story", lines=12, interactive=False)
 
 
 
 
 
727
 
728
  with gr.Row():
729
- choice_button_1 = gr.Button(value="", visible=False, variant="secondary")
730
- choice_button_2 = gr.Button(value="", visible=False, variant="secondary")
731
- choice_button_3 = gr.Button(value="", visible=False, variant="secondary")
732
 
733
- reset_button = gr.Button("🔄 Reset Story", variant="stop")
734
 
735
  # Event handlers
736
  start_button.click(
737
- fn=start_story,
738
- inputs=prompt_input,
739
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
740
  ).then(
741
  fn=lambda story, c1, c2, c3: (
@@ -751,7 +799,7 @@ with gr.Blocks(title="Story Adventure") as demo:
751
  for button in [choice_button_1, choice_button_2, choice_button_3]:
752
  button.click(
753
  fn=continue_story,
754
- inputs=button,
755
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
756
  ).then(
757
  fn=lambda story, c1, c2, c3: (
@@ -765,9 +813,23 @@ with gr.Blocks(title="Story Adventure") as demo:
765
  )
766
 
767
  reset_button.click(
768
- fn=reset_story,
769
- outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
 
 
 
 
 
 
 
 
770
  )
771
 
 
772
# Launch the Gradio UI only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
584
 
585
 
586
 
 
 
 
 
587
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random

# Load TinyLlama model and tokenizer.
# BUG FIX: "TinyLlama/TinyLlama-1.1B-Chat" is not a published Hub repository;
# the released chat checkpoint is "TinyLlama/TinyLlama-1.1B-Chat-v1.0".
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# fp16 weights with device_map="auto" so accelerate places them on GPU/CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
600
 
601
class StoryState:
    """Mutable container for a single player's adventure session."""

    def __init__(self):
        # Narrative log: opening segment, choice lines, and continuations in order.
        self.history = []
        self.player_name = ""
        self.character_class = ""
        self.current_location = ""
        self.inventory = []
        self.health = 100

    def reset(self):
        """Return every field to its freshly-constructed value."""
        # Explicit re-initialization (equivalent to calling __init__ again).
        self.history = []
        self.player_name = ""
        self.character_class = ""
        self.current_location = ""
        self.inventory = []
        self.health = 100
 
 
 
 
 
 
 
 
 
 
 
 
 
612
 
613
# Initialize game state
# NOTE(review): a single module-level instance is shared across all Gradio
# sessions, so concurrent users would overwrite each other's story — confirm
# whether per-session state (gr.State) is needed.
game_state = StoryState()
 
 
 
 
 
 
 
 
 
 
 
615
 
616
def generate_story_segment(prompt, max_length=150):
    """Generate story text using TinyLlama.

    Args:
        prompt: Plain-text instruction for the model.
        max_length: Maximum number of new tokens to sample.

    Returns:
        The generated text, or a canned fallback segment if generation fails.
    """
    try:
        # Wrap the message in the ChatML template TinyLlama-Chat was tuned on.
        formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

        # Tokenize and move to the model's device.
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Generate with sampling parameters tuned for lively prose.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id
            )

        # BUG FIX: decode only the newly generated tokens. Decoding the full
        # sequence and string-replacing the prompt fails because
        # skip_special_tokens strips the <|im_start|> markers, so
        # formatted_prompt never matches and the echoed prompt leaks into
        # the story.
        new_tokens = outputs[0][inputs.input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception:
        # Deliberate best-effort: any model/tokenizer failure falls back to a
        # pre-written segment so the game keeps flowing.
        return get_fallback_segment()
642
+
643
def get_fallback_segment():
    """Pick a canned story beat for when model generation errors out."""
    canned_segments = (
        "As you proceed cautiously, the path ahead seems to hold both promise and peril.",
        "The atmosphere grows tense as you consider your next move.",
        "Your instincts tell you that this decision could change everything.",
        "The consequences of your choices begin to unfold around you.",
    )
    return random.choice(canned_segments)
652
+
653
def generate_choices(context):
    """Generate three contextual choices based on the current story state.

    Args:
        context: The latest story segment to base the choices on.

    Returns:
        A list of exactly three single-sentence choice strings; stock
        fallbacks if generation or parsing fails.
    """
    try:
        prompt = (
            f"Based on this story context: '{context}'\n"
            "Generate 3 distinct, interesting choices for what the player could do next. "
            "Each choice should be a single sentence and affect the story differently. "
            "Format: 1. [Choice 1] 2. [Choice 2] 3. [Choice 3]"
        )

        choices_text = generate_story_segment(prompt, max_length=100)
        choices = [
            # maxsplit=1 keeps the full sentence even if it contains ". ".
            line.split(". ", 1)[1].strip()
            for line in choices_text.split("\n")
            if line.strip().startswith(("1.", "2.", "3."))
        ]

        if len(choices) != 3:
            raise ValueError("Invalid choices generated")

        return choices
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; any parsing or generation error yields the stock choices.
        return get_fallback_choices()
676
+
677
def get_fallback_choices():
    """Return the stock trio of choices used when generation fails."""
    cautious = "Proceed carefully and gather more information"
    bold = "Take bold action and face the consequences"
    alternative = "Seek an alternative path forward"
    return [cautious, bold, alternative]
684
+
685
def initialize_character(name, character_class):
    """Create the player character and generate the opening scene."""
    # Both fields are required before the adventure can begin.
    if not (name and character_class):
        return "Please enter both name and class!", None, None, None

    # Wipe any previous session, then record the new hero.
    game_state.reset()
    game_state.player_name = name
    game_state.character_class = character_class

    # Ask the model for the opening paragraph of the adventure.
    opening_request = (
        f"Write an engaging opening paragraph for a fantasy adventure story where {name}, "
        f"a brave {character_class}, begins their journey. Make it exciting and descriptive."
    )
    opening_scene = generate_story_segment(opening_request)
    game_state.history.append(opening_scene)

    first_choices = generate_choices(opening_scene)
    return (format_game_output(opening_scene, first_choices), *first_choices)
707
 
708
def continue_story(choice):
    """Extend the adventure along the option the player selected."""
    # Guard: nothing to continue without an active story and a real choice.
    if not (choice and game_state.history):
        return "Please start a new game!", None, None, None

    game_state.history.append(f"You chose to: {choice}")

    # history[-2] is the most recent story segment — the entry just before
    # the choice line recorded above.
    previous_segment = game_state.history[-2]
    hero = game_state.player_name
    role = game_state.character_class
    continuation_request = (
        f"Continue this story naturally:\n{previous_segment}\n"
        f"The player ({hero} the {role}) chose to: {choice}\n"
        "Write an exciting 50-75 word continuation that builds on this choice."
    )

    fresh_segment = generate_story_segment(continuation_request)
    game_state.history.append(fresh_segment)

    next_choices = generate_choices(fresh_segment)
    transcript = "\n\n".join(game_state.history)
    return (format_game_output(transcript, next_choices), *next_choices)
729
 
730
def format_game_output(story, choices):
    """Render the status header, story text, and the three numbered choices."""
    # Inventory reads "None" while empty.
    items = ", ".join(game_state.inventory) or "None"
    header_lines = [
        f"📜 Adventure of {game_state.player_name} the {game_state.character_class}",
        f"❤️ Health: {game_state.health}% | 🎒 Items: {items}",
        "─" * 50,
    ]
    header = "\n".join(header_lines) + "\n\n"

    prompt_block = "\n".join([
        "What will you do?",
        f"1️⃣ {choices[0]}",
        f"2️⃣ {choices[1]}",
        f"3️⃣ {choices[2]}",
    ])
    return f"{header}{story}\n\n{prompt_block}"
745
 
746
def reset_game():
    """Clear the session state and blank out every bound UI component."""
    game_state.reset()
    placeholder = "Start a new adventure!"
    # Seven values, one per entry in the handler's outputs list.
    return placeholder, "", "", "", None, None, None
 
750
 
751
  # Gradio interface
752
+ with gr.Blocks(title="Story Quest") as demo:
753
+ gr.Markdown("# 🎮 Story Quest")
754
+ gr.Markdown("Embark on an AI-powered interactive adventure! Create your character and shape the story through your choices.")
755
 
756
+ # Character Creation
757
  with gr.Row():
758
+ with gr.Column():
759
+ name_input = gr.Textbox(
760
+ label="Character Name",
761
+ placeholder="Enter your character's name"
762
+ )
763
+ class_input = gr.Dropdown(
764
+ choices=["Warrior", "Mage", "Rogue", "Paladin", "Ranger"],
765
+ label="Character Class"
766
+ )
767
+ start_button = gr.Button("⚔️ Begin Adventure", variant="primary")
768
 
769
+ # Story Display and Choices
770
+ story_output = gr.Textbox(
771
+ label="Your Adventure",
772
+ lines=12,
773
+ interactive=False
774
+ )
775
 
776
  with gr.Row():
777
+ choice_button_1 = gr.Button(visible=False, variant="secondary")
778
+ choice_button_2 = gr.Button(visible=False, variant="secondary")
779
+ choice_button_3 = gr.Button(visible=False, variant="secondary")
780
 
781
+ reset_button = gr.Button("🔄 Reset Adventure", variant="stop")
782
 
783
  # Event handlers
784
  start_button.click(
785
+ fn=initialize_character,
786
+ inputs=[name_input, class_input],
787
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
788
  ).then(
789
  fn=lambda story, c1, c2, c3: (
 
799
  for button in [choice_button_1, choice_button_2, choice_button_3]:
800
  button.click(
801
  fn=continue_story,
802
+ inputs=[button],
803
  outputs=[story_output, choice_button_1, choice_button_2, choice_button_3]
804
  ).then(
805
  fn=lambda story, c1, c2, c3: (
 
813
  )
814
 
815
  reset_button.click(
816
+ fn=reset_game,
817
+ outputs=[
818
+ story_output,
819
+ name_input,
820
+ class_input,
821
+ story_output,
822
+ choice_button_1,
823
+ choice_button_2,
824
+ choice_button_3
825
+ ]
826
  )
827
 
828
# Launch the app
# Entry point: starts the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()
831
+
832
+
833
+
834
+
835
+