bh4vay commited on
Commit
00d6cb1
Β·
verified Β·
1 Parent(s): 7ed1eb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -93
app.py CHANGED
@@ -1,114 +1,48 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  from diffusers import StableDiffusionPipeline
5
- from PIL import Image, ImageDraw, ImageFont
6
 
7
- # Check for GPU availability
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- st.write(f"Using device: {device}") # Debug message
10
 
11
- # Load text model (TinyLlama) with error handling
12
  @st.cache_resource
13
  def load_text_model():
14
- try:
15
- st.write("⏳ Loading text model...")
16
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
17
- tokenizer = AutoTokenizer.from_pretrained(model_name)
18
- model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
19
- st.write("βœ… Text model loaded successfully!")
20
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
21
- except Exception as e:
22
- st.error(f"❌ Error loading text model: {e}")
23
- return None
24
 
25
- story_generator = load_text_model()
26
 
27
- # Load image model (Stable Diffusion) with error handling
28
  @st.cache_resource
29
  def load_image_model():
30
- try:
31
- st.write("⏳ Loading image model...")
32
- model_id = "runwayml/stable-diffusion-v1-5"
33
- model = StableDiffusionPipeline.from_pretrained(model_id).to(device)
34
- st.write("βœ… Image model loaded successfully!")
35
- return model
36
- except Exception as e:
37
- st.error(f"❌ Error loading image model: {e}")
38
- return None
39
 
40
  image_generator = load_image_model()
41
 
42
- # Function to generate a short story
43
- def generate_story(prompt):
44
- if not story_generator:
45
- return "❌ Error: Story model not loaded."
46
-
47
- formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
48
-
49
- try:
50
- st.write("⏳ Generating story...")
51
- story_output = story_generator(
52
- formatted_prompt,
53
- max_length=150, # Shorter length for efficiency
54
- do_sample=True,
55
- temperature=0.7,
56
- top_k=30,
57
- num_return_sequences=1
58
- )[0]['generated_text']
59
- st.write("βœ… Story generated successfully!")
60
- return story_output.replace(formatted_prompt, "").strip()
61
- except Exception as e:
62
- st.error(f"❌ Error generating story: {e}")
63
- return "Error generating story."
64
-
65
- # Function to add a speech bubble to an image
66
- def add_speech_bubble(image, text, position=(50, 50)):
67
- draw = ImageDraw.Draw(image)
68
-
69
- try:
70
- font = ImageFont.truetype("arial.ttf", 20)
71
- except IOError:
72
- font = ImageFont.load_default()
73
-
74
- text_bbox = draw.textbbox((0, 0), text, font=font)
75
- text_width = text_bbox[2] - text_bbox[0]
76
- text_height = text_bbox[3] - text_bbox[1]
77
-
78
- bubble_width, bubble_height = text_width + 30, text_height + 20
79
- bubble_x, bubble_y = position
80
-
81
- draw.ellipse([bubble_x, bubble_y, bubble_x + bubble_width, bubble_y + bubble_height], fill="white", outline="black")
82
- draw.text((bubble_x + 15, bubble_y + 10), text, font=font, fill="black")
83
-
84
- return image
85
-
86
  # Streamlit UI
87
- st.title("πŸ¦Έβ€β™‚οΈ AI Comic Story Generator")
88
  st.write("Enter a prompt to generate a comic-style story and image!")
89
 
90
- # User input
91
- user_prompt = st.text_input("πŸ“ Enter your story prompt:")
92
-
93
- if user_prompt:
94
- st.subheader("πŸ“– AI-Generated Story")
95
- generated_story = generate_story(user_prompt)
96
- st.write(generated_story)
97
 
98
- st.subheader("πŸ–ΌοΈ AI-Generated Image")
99
-
100
- if not image_generator:
101
- st.error("❌ Error: Image model not loaded.")
102
- else:
103
- with st.spinner("⏳ Generating image..."):
104
- try:
105
- image = image_generator(user_prompt, num_inference_steps=30).images[0]
106
- st.write("βœ… Image generated successfully!")
107
 
108
- # Extract first sentence (50 characters max) for speech bubble
109
- speech_text = generated_story.split(".")[0][:50]
110
- image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
 
 
111
 
112
- st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)
113
- except Exception as e:
114
- st.error(f"❌ Error generating image: {e}")
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import pipeline
4
  from diffusers import StableDiffusionPipeline
5
+ from PIL import Image
6
 
7
+ # Set device (CPU or GPU)
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
+ # Load text generation model
11
  @st.cache_resource
12
  def load_text_model():
13
+ return pipeline("text-generation", model="gpt2")
 
 
 
 
 
 
 
 
 
14
 
15
+ text_generator = load_text_model()
16
 
17
+ # Load image generation model
18
  @st.cache_resource
19
  def load_image_model():
20
+ pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
21
+ pipe.to(device)
22
+ pipe.to(torch.float16) # Use float16 for speed
23
+ if device == "cuda":
24
+ pipe = torch.compile(pipe) # Optimize for GPU
25
+ return pipe
 
 
 
26
 
27
  image_generator = load_image_model()
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Streamlit UI
30
+ st.title("πŸ€– AI Comic Story Generator")
31
  st.write("Enter a prompt to generate a comic-style story and image!")
32
 
33
+ # Input field for story prompt
34
+ story_prompt = st.text_input("πŸ“ Enter your story prompt:", "")
 
 
 
 
 
35
 
36
+ if story_prompt:
37
+ with st.spinner("⏳ Generating story..."):
38
+ story = text_generator(story_prompt, max_length=100, num_return_sequences=1)[0]["generated_text"]
39
+ st.success("βœ… Story generated successfully!")
40
+ st.write(story)
 
 
 
 
41
 
42
+ # Generate image
43
+ with st.spinner("🎨 Generating image..."):
44
+ image = image_generator(story_prompt, num_inference_steps=15).images[0] # Reduced from 30 β†’ 15
45
+ image = image.resize((512, 512)) # Resize to 512x512 to make it smaller
46
+ st.image(image, caption="πŸ–ΌοΈ AI-Generated Comic Image", use_column_width=False)
47
 
48
+ st.write("πŸš€ Optimized for speed & performance!")