File size: 2,978 Bytes
91e4459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import math
import textwrap

import streamlit as st
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont

# Run the models on the GPU when PyTorch can see a CUDA device; otherwise
# fall back to the CPU. Both model loaders below move their models here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the text generation model (TinyLlama). st.cache_resource keeps the
# returned pipeline across Streamlit reruns instead of reloading the weights.
@st.cache_resource
def load_text_model():
    """Build a ``text-generation`` pipeline around TinyLlama-1.1B-Chat.

    The tokenizer and causal-LM weights are fetched with ``from_pretrained``
    and the model is moved to the module-level ``device`` before being wrapped.
    """
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    causal_lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    causal_lm = causal_lm.to(device)
    return pipeline("text-generation", model=causal_lm, tokenizer=text_tokenizer)

story_generator = load_text_model()

# Load the image generation model (Stable Diffusion v1.5). st.cache_resource
# keeps the pipeline across Streamlit reruns instead of reloading the weights.
@st.cache_resource
def load_image_model():
    """Build a Stable Diffusion v1.5 text-to-image pipeline on ``device``.

    On CUDA the weights are loaded in float16, which roughly halves GPU memory
    use; on CPU they stay in float32 because half-precision CPU inference is
    poorly supported.
    """
    # NOTE(review): the "runwayml" organisation reportedly removed this repo
    # from the Hugging Face Hub; if downloads fail, try the mirror
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm before changing.
    model_id = "runwayml/stable-diffusion-v1-5"
    dtype = torch.float16 if device == "cuda" else torch.float32
    return StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)

image_generator = load_image_model()

# Function to generate a short story
def generate_story(prompt, max_new_tokens=200):
    """Generate a short comic-style story about *prompt*.

    Args:
        prompt: Free-text story idea entered by the user.
        max_new_tokens: Maximum number of tokens generated for the story
            itself. Unlike ``max_length``, this budget does not shrink when
            the user's prompt is long.

    Returns:
        The generated continuation only (the instruction prompt is not
        included), with surrounding whitespace stripped.
    """
    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"
    outputs = story_generator(
        formatted_prompt,
        max_new_tokens=max_new_tokens,  # Short story length, excluding the prompt
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
        truncation=True,
        # Ask the pipeline for the continuation only, rather than removing the
        # prompt afterwards with str.replace (which would also delete any
        # identical text the model happened to generate).
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()

# Function to add a speech bubble to the image
def add_speech_bubble(image, text, position=(50, 50), font_size=20):
    """Draw *text* inside a white, black-outlined elliptical bubble on *image*.

    The image is drawn on in place and the same object is returned.

    Args:
        image: PIL image to draw on.
        text: Caption to render inside the bubble.
        position: ``(x, y)`` of the bubble's top-left corner. If the bubble
            would run past the right or bottom edge it is shifted back onto
            the canvas (never past the top/left edge).
        font_size: Size passed to ``ImageFont.truetype``. It is not applied
            when ``arial.ttf`` cannot be loaded and Pillow's default font is
            used instead.

    Returns:
        ``image``, with the bubble drawn on it.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is often not installed (e.g. on Linux); use Pillow's font.
        font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    text_height = bottom - top

    # Text plus 15 px of horizontal padding on each side.
    bubble_width = text_width + 30
    # An ellipse is much narrower than its bounding box near the top/bottom,
    # so with a fixed height the ends of a long caption poke through the
    # outline. A (padded) text box of w x h fits inside an ellipse of
    # width W and height H when (w / W)**2 + (h / H)**2 <= 1, so grow H
    # whenever the plain `text_height + 20` is too small.
    margin = 4
    padded_w = text_width + 2 * margin
    padded_h = text_height + 2 * margin
    width_ratio = padded_w / bubble_width  # always < 1 since 2 * margin < 30
    min_fit_height = padded_h / math.sqrt(1 - width_ratio ** 2)
    bubble_height = max(text_height + 20, math.ceil(min_fit_height))

    bubble_x, bubble_y = position
    # Keep the bubble on the canvas when it fits.
    bubble_x = max(0, min(bubble_x, image.width - bubble_width))
    bubble_y = max(0, min(bubble_y, image.height - bubble_height))

    draw.ellipse(
        [bubble_x, bubble_y, bubble_x + bubble_width, bubble_y + bubble_height],
        fill="white",
        outline="black",
    )
    # Centre the text's ink box in the bubble. textbbox((0, 0)) usually
    # reports a non-zero top (and sometimes left) offset; subtract it so the
    # glyphs themselves are centred rather than their drawing origin.
    text_x = bubble_x + (bubble_width - text_width) / 2 - left
    text_y = bubble_y + (bubble_height - text_height) / 2 - top
    draw.text((int(text_x), int(text_y)), text, font=font, fill="black")

    return image

# Streamlit UI
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# User input
user_prompt = st.text_input("📝 Enter your story prompt:")

if user_prompt:
    st.subheader("📖 AI-Generated Story")
    # Text generation can be slow (especially on CPU); show progress the same
    # way the image step does.
    with st.spinner("Generating story..."):
        generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")
    with st.spinner("Generating image..."):
        image = image_generator(user_prompt, num_inference_steps=30).images[0]

    # Caption the bubble with the story's first sentence. textwrap.shorten
    # collapses newlines/whitespace runs and truncates on a word boundary,
    # instead of slicing mid-word as a raw [:50] would.
    first_sentence = generated_story.split(".")[0]
    speech_text = textwrap.shorten(first_sentence, width=50, placeholder="...")

    # Skip the bubble entirely rather than drawing an empty one.
    if speech_text:
        image_with_bubble = add_speech_bubble(image, speech_text, position=(50, 50))
    else:
        image_with_bubble = image

    st.image(image_with_bubble, caption="Generated Comic Image", use_container_width=True)