Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|
| 8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 9 |
st.write(f"Using device: {device}") # Debug message
|
| 10 |
|
| 11 |
-
# Load text model (TinyLlama) with
|
| 12 |
@st.cache_resource
|
| 13 |
def load_text_model():
|
| 14 |
try:
|
|
@@ -24,7 +24,7 @@ def load_text_model():
|
|
| 24 |
|
| 25 |
story_generator = load_text_model()
|
| 26 |
|
| 27 |
-
# Load image model (Stable Diffusion) with
|
| 28 |
@st.cache_resource
|
| 29 |
def load_image_model():
|
| 30 |
try:
|
|
@@ -50,11 +50,12 @@ def generate_story(prompt):
|
|
| 50 |
st.write("⏳ Generating story...")
|
| 51 |
story_output = story_generator(
|
| 52 |
formatted_prompt,
|
| 53 |
-
max_length=
|
| 54 |
do_sample=True,
|
| 55 |
temperature=0.7,
|
| 56 |
top_k=30,
|
| 57 |
-
num_return_sequences=1
|
|
|
|
| 58 |
)[0]['generated_text']
|
| 59 |
st.write("✅ Story generated successfully!")
|
| 60 |
return story_output.replace(formatted_prompt, "").strip()
|
|
@@ -102,11 +103,10 @@ if user_prompt:
|
|
| 102 |
else:
|
| 103 |
with st.spinner("⏳ Generating image..."):
|
| 104 |
try:
|
| 105 |
-
image = image_generator(user_prompt, num_inference_steps=
|
| 106 |
-
image = image.resize((512, 512)) # Resize to smaller 512x512
|
| 107 |
st.write("✅ Image generated successfully!")
|
| 108 |
|
| 109 |
-
# Extract first sentence (
|
| 110 |
speech_text = generated_story.split(".")[0][:50]
|
| 111 |
image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
|
| 112 |
|
|
|
|
| 8 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 9 |
st.write(f"Using device: {device}") # Debug message
|
| 10 |
|
| 11 |
+
# Load text model (TinyLlama) with error handling
|
| 12 |
@st.cache_resource
|
| 13 |
def load_text_model():
|
| 14 |
try:
|
|
|
|
| 24 |
|
| 25 |
story_generator = load_text_model()
|
| 26 |
|
| 27 |
+
# Load image model (Stable Diffusion) with error handling
|
| 28 |
@st.cache_resource
|
| 29 |
def load_image_model():
|
| 30 |
try:
|
|
|
|
| 50 |
st.write("⏳ Generating story...")
|
| 51 |
story_output = story_generator(
|
| 52 |
formatted_prompt,
|
| 53 |
+
max_length=150, # Shorter length for efficiency
|
| 54 |
do_sample=True,
|
| 55 |
temperature=0.7,
|
| 56 |
top_k=30,
|
| 57 |
+
num_return_sequences=1,
|
| 58 |
+
truncation = True
|
| 59 |
)[0]['generated_text']
|
| 60 |
st.write("✅ Story generated successfully!")
|
| 61 |
return story_output.replace(formatted_prompt, "").strip()
|
|
|
|
| 103 |
else:
|
| 104 |
with st.spinner("⏳ Generating image..."):
|
| 105 |
try:
|
| 106 |
+
image = image_generator(user_prompt, num_inference_steps=12).images[0]
|
|
|
|
| 107 |
st.write("✅ Image generated successfully!")
|
| 108 |
|
| 109 |
+
# Extract first sentence (50 characters max) for speech bubble
|
| 110 |
speech_text = generated_story.split(".")[0][:50]
|
| 111 |
image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
|
| 112 |
|