bh4vay commited on
Commit
7e60bc2
·
verified ·
1 Parent(s): 63b8570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -29
app.py CHANGED
@@ -4,43 +4,65 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  from diffusers import StableDiffusionPipeline
5
  from PIL import Image, ImageDraw, ImageFont
6
 
7
- # Check if CUDA is available for GPU acceleration
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
- # Load the text generation model (TinyLlama)
11
  @st.cache_resource
12
  def load_text_model():
13
- st.write("Loading text model...") # Debug message
14
- model_name = "distilgpt2"
15
- tokenizer = AutoTokenizer.from_pretrained(model_name)
16
- model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
17
- st.write("Model loaded!") # Confirm model is ready
18
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
 
 
 
 
19
 
20
  story_generator = load_text_model()
21
 
22
- # Load the image generation model (Stable Diffusion Turbo)
23
  @st.cache_resource
24
  def load_image_model():
25
- model_id = "runwayml/stable-diffusion-v1-5"
26
- return StableDiffusionPipeline.from_pretrained(model_id).to(device)
 
 
 
 
 
 
 
27
 
28
  image_generator = load_image_model()
29
 
30
  # Function to generate a short story
31
  def generate_story(prompt):
 
 
 
32
  formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
33
- story_output = story_generator(
34
- formatted_prompt,
35
- max_length=150, # Reduce text generation length
36
- do_sample=True,
37
- temperature=0.7,
38
- top_k=30, # Lower value for efficiency
39
- num_return_sequences=1
40
- )[0]['generated_text']
41
- return story_output.replace(formatted_prompt, "").strip()
42
-
43
- # Function to add a speech bubble to the image
 
 
 
 
 
 
 
44
  def add_speech_bubble(image, text, position=(50, 50)):
45
  draw = ImageDraw.Draw(image)
46
 
@@ -74,10 +96,19 @@ if user_prompt:
74
  st.write(generated_story)
75
 
76
  st.subheader("🖼️ AI-Generated Image")
77
- with st.spinner("Generating image..."):
78
- image = image_generator(user_prompt, num_inference_steps=30).images[0]
79
-
80
- speech_text = generated_story.split(".")[0][:50]
81
- image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
82
-
83
- st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)
 
 
 
 
 
 
 
 
 
 
4
  from diffusers import StableDiffusionPipeline
5
  from PIL import Image, ImageDraw, ImageFont
6
 
7
+ # Check for GPU availability
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ st.write(f"Using device: {device}") # Debug message
10
 
11
+ # Load text model (TinyLlama) with error handling
12
  @st.cache_resource
13
  def load_text_model():
14
+ try:
15
+ st.write("⏳ Loading text model...")
16
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
19
+ st.write(" Text model loaded successfully!")
20
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
21
+ except Exception as e:
22
+ st.error(f"❌ Error loading text model: {e}")
23
+ return None
24
 
25
  story_generator = load_text_model()
26
 
27
+ # Load image model (Stable Diffusion) with error handling
28
  @st.cache_resource
29
  def load_image_model():
30
+ try:
31
+ st.write("⏳ Loading image model...")
32
+ model_id = "runwayml/stable-diffusion-v1-5"
33
+ model = StableDiffusionPipeline.from_pretrained(model_id).to(device)
34
+ st.write("✅ Image model loaded successfully!")
35
+ return model
36
+ except Exception as e:
37
+ st.error(f"❌ Error loading image model: {e}")
38
+ return None
39
 
40
  image_generator = load_image_model()
41
 
42
  # Function to generate a short story
43
  def generate_story(prompt):
44
+ if not story_generator:
45
+ return "❌ Error: Story model not loaded."
46
+
47
  formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
48
+
49
+ try:
50
+ st.write("⏳ Generating story...")
51
+ story_output = story_generator(
52
+ formatted_prompt,
53
+ max_length=150, # Shorter length for efficiency
54
+ do_sample=True,
55
+ temperature=0.7,
56
+ top_k=30,
57
+ num_return_sequences=1
58
+ )[0]['generated_text']
59
+ st.write("✅ Story generated successfully!")
60
+ return story_output.replace(formatted_prompt, "").strip()
61
+ except Exception as e:
62
+ st.error(f"❌ Error generating story: {e}")
63
+ return "Error generating story."
64
+
65
+ # Function to add a speech bubble to an image
66
  def add_speech_bubble(image, text, position=(50, 50)):
67
  draw = ImageDraw.Draw(image)
68
 
 
96
  st.write(generated_story)
97
 
98
  st.subheader("🖼️ AI-Generated Image")
99
+
100
+ if not image_generator:
101
+ st.error("❌ Error: Image model not loaded.")
102
+ else:
103
+ with st.spinner("⏳ Generating image..."):
104
+ try:
105
+ image = image_generator(user_prompt, num_inference_steps=30).images[0]
106
+ st.write("✅ Image generated successfully!")
107
+
108
+ # Extract first sentence (50 characters max) for speech bubble
109
+ speech_text = generated_story.split(".")[0][:50]
110
+ image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
111
+
112
+ st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)
113
+ except Exception as e:
114
+ st.error(f"❌ Error generating image: {e}")