bh4vay commited on
Commit
eb07b95
·
verified ·
1 Parent(s): 00d6cb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -27
app.py CHANGED
@@ -1,48 +1,115 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import pipeline
4
  from diffusers import StableDiffusionPipeline
5
- from PIL import Image
6
 
7
- # Set device (CPU or GPU)
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
- # Load text generation model
11
  @st.cache_resource
12
  def load_text_model():
13
- return pipeline("text-generation", model="gpt2")
 
 
 
 
 
 
 
 
 
14
 
15
- text_generator = load_text_model()
16
 
17
- # Load image generation model
18
  @st.cache_resource
19
  def load_image_model():
20
- pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
21
- pipe.to(device)
22
- pipe.to(torch.float16) # Use float16 for speed
23
- if device == "cuda":
24
- pipe = torch.compile(pipe) # Optimize for GPU
25
- return pipe
 
 
 
26
 
27
  image_generator = load_image_model()
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Streamlit UI
30
- st.title("πŸ€– AI Comic Story Generator")
31
  st.write("Enter a prompt to generate a comic-style story and image!")
32
 
33
- # Input field for story prompt
34
- story_prompt = st.text_input("πŸ“ Enter your story prompt:", "")
 
 
 
 
 
35
 
36
- if story_prompt:
37
- with st.spinner("⏳ Generating story..."):
38
- story = text_generator(story_prompt, max_length=100, num_return_sequences=1)[0]["generated_text"]
39
- st.success("βœ… Story generated successfully!")
40
- st.write(story)
 
 
 
 
 
41
 
42
- # Generate image
43
- with st.spinner("🎨 Generating image..."):
44
- image = image_generator(story_prompt, num_inference_steps=15).images[0] # Reduced from 30 β†’ 15
45
- image = image.resize((512, 512)) # Resize to 512x512 to make it smaller
46
- st.image(image, caption="πŸ–ΌοΏ½οΏ½οΏ½ AI-Generated Comic Image", use_column_width=False)
47
 
48
- st.write("πŸš€ Optimized for speed & performance!")
 
 
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  from diffusers import StableDiffusionPipeline
5
+ from PIL import Image, ImageDraw, ImageFont
6
 
7
# Check for GPU availability: use CUDA when torch can see a GPU, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Debug message — rendered in the app at startup so users can see which
# backend inference will run on.
st.write(f"Using device: {device}")
10
 
11
# Load text model (TinyLlama) with caching
@st.cache_resource
def load_text_model():
    """Build and cache a TinyLlama text-generation pipeline.

    Returns:
        The HF text-generation pipeline on success, or None when loading
        fails (the failure is surfaced in the UI via st.error).
    """
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    try:
        st.write("⏳ Loading text model...")
        tok = AutoTokenizer.from_pretrained(checkpoint)
        lm = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
        st.write("βœ… Text model loaded successfully!")
        return pipeline("text-generation", model=lm, tokenizer=tok)
    except Exception as e:
        st.error(f"❌ Error loading text model: {e}")
        return None
24
 
25
# Built once at module import (cached by st.cache_resource); None on failure.
story_generator = load_text_model()
26
 
27
# Load image model (Stable Diffusion) with caching
@st.cache_resource
def load_image_model():
    """Build and cache the Stable Diffusion v1.5 pipeline.

    Returns:
        The diffusers pipeline moved to `device` on success, or None when
        loading fails (the failure is surfaced in the UI via st.error).
    """
    checkpoint = "runwayml/stable-diffusion-v1-5"
    try:
        st.write("⏳ Loading image model...")
        sd_pipe = StableDiffusionPipeline.from_pretrained(checkpoint).to(device)
        st.write("βœ… Image model loaded successfully!")
        return sd_pipe
    except Exception as e:
        st.error(f"❌ Error loading image model: {e}")
        return None
39
 
40
# Built once at module import (cached by st.cache_resource); None on failure.
image_generator = load_image_model()
41
 
42
# Function to generate a short story
def generate_story(prompt):
    """Generate a short comic-style story for *prompt*.

    Args:
        prompt: the user's free-text story idea.

    Returns:
        The generated story text with the instruction prefix stripped, or a
        human-readable error string when the model is unavailable or
        generation fails (this function never raises).
    """
    if not story_generator:
        return "❌ Error: Story model not loaded."

    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"

    try:
        st.write("⏳ Generating story...")
        story_output = story_generator(
            formatted_prompt,
            # max_new_tokens budgets only the *generated* text. The previous
            # max_length=120 also counted the prompt tokens, so a long prompt
            # could leave no room for the story and trigger truncation warnings.
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_k=30,
            num_return_sequences=1
        )[0]['generated_text']
        st.write("βœ… Story generated successfully!")
        # Strip only the leading instruction prefix; a blanket str.replace
        # would also delete an accidental repetition of the prompt that the
        # model might emit inside the story body.
        if story_output.startswith(formatted_prompt):
            story_output = story_output[len(formatted_prompt):]
        return story_output.strip()
    except Exception as e:
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."
64
+
65
# Function to add a speech bubble to an image
def add_speech_bubble(image, text, position=(50, 50)):
    """Draw *text* inside a white elliptical speech bubble on *image*.

    The image is drawn on in place and also returned for convenience.
    `position` is the (x, y) of the bubble's top-left corner.
    """
    canvas = ImageDraw.Draw(image)

    # Prefer Arial; fall back to PIL's built-in bitmap font when the TTF
    # file is not available (e.g. on a bare Linux container).
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()

    left, top, right, bottom = canvas.textbbox((0, 0), text, font=font)
    w = right - left
    h = bottom - top

    x, y = position
    # Ellipse pads the text box by 30 px horizontally and 20 px vertically.
    canvas.ellipse([x, y, x + w + 30, y + h + 20], fill="white", outline="black")
    canvas.text((x + 15, y + 10), text, font=font, fill="black")

    return image
85
+
86
# Streamlit UI
st.title("πŸ¦Έβ€β™‚οΈ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("πŸ“ Enter your story prompt:")

if user_prompt:
    # Story first, so its opening sentence can caption the image below.
    st.subheader("πŸ“– AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("πŸ–ΌοΈ AI-Generated Image")

    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                # Few inference steps + 512x512 output keep generation quick.
                frame = image_generator(user_prompt, num_inference_steps=10).images[0]
                frame = frame.resize((512, 512))
                st.write("βœ… Image generated successfully!")

                # First sentence, capped at 50 chars, goes in the speech bubble.
                bubble_text = generated_story.split(".")[0][:50]
                final_panel = add_speech_bubble(frame, bubble_text, position=(50, 50))

                st.image(final_panel, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:
                st.error(f"❌ Error generating image: {e}")