jsakshi commited on
Commit
695e71b
·
verified ·
1 Parent(s): 4566613

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -2
app.py CHANGED
@@ -301,7 +301,7 @@ demo.launch()'''
301
 
302
 
303
 
304
- import gradio as gr
305
  from transformers import AutoModelForCausalLM, AutoTokenizer
306
  import torch
307
  import sys
@@ -455,4 +455,128 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
455
  # Exit handling
456
  exit_check.change(fn=lambda x: sys.exit() if x else None, inputs=[exit_check])
457
 
458
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
 
303
 
304
+ '''import gradio as gr
305
  from transformers import AutoModelForCausalLM, AutoTokenizer
306
  import torch
307
  import sys
 
455
  # Exit handling
456
  exit_check.change(fn=lambda x: sys.exit() if x else None, inputs=[exit_check])
457
 
458
+ demo.launch()'''
459
+
460
+ import gradio as gr
461
+ from transformers import AutoModelForCausalLM, AutoTokenizer
462
+ import torch
463
+ import sys
464
+
465
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
466
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
467
+ model = AutoModelForCausalLM.from_pretrained(
468
+ model_name,
469
+ torch_dtype=torch.float16,
470
+ device_map="auto",
471
+ )
472
+
473
+ def generate_story_continuation(current_story, selected_option=None):
474
+ """
475
+ Generate story continuation based on current story and selected option
476
+ """
477
+ if selected_option:
478
+ prompt = f"""Previous story: {current_story}
479
+ Selected path: {selected_option}
480
+ Continue the story with this choice and provide two new options."""
481
+ else:
482
+ prompt = f"Create a story based on: {current_story}"
483
+
484
+ messages = [{"role": "user", "content": prompt}]
485
+ formatted_prompt = tokenizer.apply_chat_template(
486
+ messages,
487
+ tokenize=False,
488
+ add_generation_prompt=True
489
+ )
490
+
491
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
492
+ outputs = model.generate(
493
+ **inputs,
494
+ max_length=200,
495
+ temperature=0.7,
496
+ do_sample=True,
497
+ pad_token_id=tokenizer.eos_token_id
498
+ )
499
+
500
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
501
+ response = response.split("Assistant: ")[-1].strip()
502
+
503
+ try:
504
+ parts = response.split("Option 1:")
505
+ story = parts[0].strip()
506
+ options_text = "Option 1:" + parts[1]
507
+ option1 = options_text.split("Option 2:")[0].replace("Option 1:", "").strip()
508
+ option2 = options_text.split("Option 2:")[1].strip()
509
+ except:
510
+ story = response
511
+ option1 = "Path A: Continue the adventure"
512
+ option2 = "Path B: Take a different direction"
513
+
514
+ return story, option1, option2
515
+
516
+ def update_story(story, selected_option):
517
+ """
518
+ Update story based on selected option
519
+ """
520
+ new_story, new_option1, new_option2 = generate_story_continuation(story, selected_option)
521
+ full_story = f"{story}\n\nChosen path: {selected_option}\n\n{new_story}"
522
+ return full_story, new_option1, new_option2
523
+
524
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
525
+ gr.Markdown("# 🌟 Interactive Story Generator")
526
+
527
+ with gr.Column():
528
+ story_prompt = gr.Textbox(
529
+ label="Start your story",
530
+ placeholder="Enter your story prompt here...",
531
+ lines=3
532
+ )
533
+
534
+ generate_btn = gr.Button(
535
+ "Generate Story",
536
+ variant="primary",
537
+ scale=1,
538
+ min_width=200,
539
+ size="lg"
540
+ )
541
+
542
+ story_output = gr.Textbox(
543
+ label="Story",
544
+ lines=10,
545
+ show_copy_button=True
546
+ )
547
+
548
+ with gr.Row():
549
+ option1_btn = gr.Button("Option 1")
550
+ option2_btn = gr.Button("Option 2")
551
+
552
+ exit_check = gr.Checkbox(
553
+ label="Exit Story",
554
+ value=False
555
+ )
556
+
557
+ def handle_option_selection(story, option_text):
558
+ return update_story(story, option_text)
559
+
560
+ generate_btn.click(
561
+ fn=generate_story_continuation,
562
+ inputs=[story_prompt],
563
+ outputs=[story_output, option1_btn, option2_btn]
564
+ )
565
+
566
+ option1_btn.click(
567
+ fn=handle_option_selection,
568
+ inputs=[story_output, option1_btn],
569
+ outputs=[story_output, option1_btn, option2_btn]
570
+ )
571
+
572
+ option2_btn.click(
573
+ fn=handle_option_selection,
574
+ inputs=[story_output, option2_btn],
575
+ outputs=[story_output, option1_btn, option2_btn]
576
+ )
577
+
578
+ exit_check.change(fn=lambda x: sys.exit() if x else None, inputs=[exit_check])
579
+
580
+ demo.launch()
581
+
582
+