# comic_generator / app.py
# Author: qver3nc1a (Hugging Face Space)
# Last commit: "Update app.py" (8f3b5c7, verified)
import gradio as gr
import os
from PIL import Image, ImageDraw
import re
from io import BytesIO
from huggingface_hub import InferenceClient
from diffusers import StableDiffusionPipeline
import torch
# Hugging Face Inference API client; `screenwriter` sends its requests through it.
client = InferenceClient()

# Stable Diffusion v1.4 text-to-image pipeline used by `illustrator`:
# float16 weights on the GPU when CUDA is available, float32 on the CPU otherwise.
_cuda_available = torch.cuda.is_available()
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16 if _cuda_available else torch.float32,
).to("cuda" if _cuda_available else "cpu")
def screenwriter(
    prompt: str,
    model: str = "mistralai/Mistral-7B-Instruct-v0.3",
    max_new_tokens: int = 250,
    temperature: float = 0.7,
) -> str:
    """Ask the LLM for a short comic plot plus a main-character description.

    The request goes through the module-level Hugging Face ``client``. The
    model is told to answer with one sentence per scene, then a ``---``
    delimiter, then a very short description of the main character.
    ``parse_screenwriter_output`` depends on that layout.

    Args:
        prompt: The user's story idea.
        model: Hub id of the model to query. The ``[INST]`` wrapper below
            assumes a Mistral-style instruction-tuned model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.

    Returns:
        The generated text returned by ``client.text_generation``.
    """
    instructions = f"""You are a skilled comic book writer.
TASK:
Generate a short comic book plot based on the story idea provided below. Also generate a description of the main character.
Generate one sentence per scene, separated by periods. The story should be 3-7 sentences long.
IMPORTANT: Do NOT include any commentary, notes, or additional thoughts. Only output the story sentences and character description exactly as requested.
Your output must include:
- Story plot with one sentence per scene.
- Very short description of the main character's appearance.
- IMPORTANT!!! ALWAYS use a delimiter '---' to separate the story from the character description.
STORY PROMPT: {prompt}"""
    # text_generation sends the prompt verbatim; no chat template is applied.
    # Mistral-Instruct models were tuned on the "[INST] ... [/INST]" format.
    # Without it the model may simply continue the instruction text instead
    # of answering it.
    response = client.text_generation(
        model=model,
        prompt=f"[INST] {instructions} [/INST]",
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return response
def remove_think_block(text: str):
    """Strip every ``<think>...</think>`` reasoning section from *text*.

    The match is non-greedy and may span newlines. Several separate blocks
    are therefore each removed without swallowing the text between them.
    Leading and trailing whitespace of the result is trimmed.
    """
    think_pattern = r'<think>.*?</think>'
    without_reasoning = re.sub(think_pattern, '', text, flags=re.S)
    return without_reasoning.strip()
def parse_screenwriter_output(output: str):
    """Split raw screenwriter text into ``(story, character)``.

    ``<think>`` blocks are removed first. The text is then split on the
    ``'---'`` delimiter. Blank segments at either end are ignored, so a stray
    leading or trailing ``---`` (e.g. a Markdown horizontal rule) no longer
    leaves the story or the character empty.

    If two or more segments remain, the first is the story. Everything after
    the first remaining delimiter is the character description.

    Otherwise the non-blank lines of the remaining text are used: the last
    line is the character description, and the preceding lines joined with
    spaces form the story.

    Returns:
        ``(story, character)`` as stripped, non-empty strings, or
        ``('', '')`` when two pieces of content cannot be found.
    """
    cleaned_output = remove_think_block(output)
    delimiter = '---'
    segments = cleaned_output.split(delimiter)
    # Ignore blank segments produced by a delimiter at the very start or end.
    while segments and not segments[0].strip():
        segments.pop(0)
    while segments and not segments[-1].strip():
        segments.pop()

    if len(segments) >= 2:
        story = segments[0].strip()
        # Re-join so a later '---' stays inside the character text, matching
        # str.split(delimiter, 1).
        character = delimiter.join(segments[1:]).strip()
        return story, character

    remainder = segments[0] if segments else ''
    lines = [line.strip() for line in remainder.split('\n') if line.strip()]
    if len(lines) < 2:
        return '', ''
    story = ' '.join(lines[:-1])
    character = lines[-1]
    return story, character
def error_image(message, size=(512, 512), wrap_width=60):
    """Return a white placeholder image with *message* drawn in red.

    It is used wherever a real comic panel could not be produced. The message
    is wrapped to at most *wrap_width* characters per line, and existing
    newlines are kept. Long exception texts therefore stay inside the image
    instead of running off the right edge. The text block is vertically
    centred, with a 10px minimum top margin.

    Args:
        message: Text to display; converted with ``str()``.
        size: ``(width, height)`` of the image in pixels.
        wrap_width: Maximum number of characters per wrapped line.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    import textwrap

    img = Image.new("RGB", size, color=(255, 255, 255))
    d = ImageDraw.Draw(img)
    lines = []
    for paragraph in str(message).splitlines() or ['']:
        # textwrap.wrap returns [] for an empty paragraph; keep it as a blank line.
        lines.extend(textwrap.wrap(paragraph, width=wrap_width) or [''])
    text = '\n'.join(lines)
    _, top, _, bottom = d.multiline_textbbox((0, 0), text)
    y = max(10, (size[1] - (bottom - top)) // 2)
    d.multiline_text((10, y), text, fill=(255, 0, 0))
    return img
# Scene boundary: whitespace following '.', '!' or '?', or any run of newlines.
_SCENE_BREAK = re.compile(r'(?<=[.!?])\s+|\n+')


def _split_scenes(story: str):
    """Split *story* into scene sentences, dropping fragments that contain no letters."""
    scenes = []
    for piece in _SCENE_BREAK.split(story):
        scene = piece.strip().rstrip('.!?').strip()
        # Skip letterless fragments such as the "1." of a numbered list, which
        # would otherwise cost a full image generation for a meaningless panel.
        if any(ch.isalpha() for ch in scene):
            scenes.append(scene)
    return scenes


def illustrator(story: str, character: str):
    """Render one comic panel per scene of *story* with the module-level ``pipe``.

    A scene is one sentence: text ending in '.', '!' or '?' followed by
    whitespace, or a line. Decimals such as "3.5" are not split, and
    letterless fragments are dropped.

    Args:
        story: Plot text, one sentence per scene.
        character: Main-character description. It is appended to every panel
            prompt so each image gets the same character description.

    Returns:
        A list of ``(image, caption)`` tuples in story order. A scene whose
        generation raises gets an ``error_image`` placeholder captioned
        ``'Error in scene N'`` (1-based).

    Raises:
        ValueError: If *story* or *character* is empty.
    """
    if not story or not character:
        raise ValueError('Could not parse story or character from input.')
    images = []
    for idx, scene in enumerate(_split_scenes(story)):
        prompt = f"Comic book style illustration. No text. Scene: {scene}. Character: {character}"
        try:
            image = pipe(prompt).images[0]
        except Exception as e:
            # Best effort: one failed panel should not abort the whole comic.
            images.append((error_image(f'Error: {str(e)}'), f'Error in scene {idx + 1}'))
        else:
            images.append((image, scene))
    return images
def comic_pipeline(prompt: str):
    """Run the prompt -> script -> panels flow behind the Gradio button.

    Returns a pair ``(text, panels)``.

    On success, ``text`` is the story and the character description rejoined
    with a ``---`` line, and ``panels`` is the illustrator's
    ``(image, caption)`` list.

    If the screenwriter output cannot be parsed, ``text`` is the raw
    screenwriter output and ``panels`` holds one error placeholder captioned
    ``'Parse error'``.
    """
    raw_script = screenwriter(prompt)
    story, character = parse_screenwriter_output(raw_script)
    if story and character:
        panels = illustrator(story, character)
        return f"{story}\n---\n{character}", panels
    placeholder = error_image("Parse error: Could not extract story or character.")
    return raw_script, [(placeholder, 'Parse error')]
# Gradio UI.
# Row 1: the story-prompt textbox and the generate button.
# Row 2: the screenwriter text output and the gallery of generated panels.
# NOTE(review): the original indentation was lost; this nesting is a
# reconstruction. Confirm which row the button belongs in.
with gr.Blocks(theme=gr.themes.Ocean(), title='Comic Generator') as demo:
    gr.Markdown("# Comic Generator\nGive a prompt and get a comic!")
    with gr.Row():
        story_input = gr.Textbox(label='Story Prompt', placeholder='A unicorn named Jeff discovers a mysterious dish')
        generate_btn = gr.Button('Generate Comic')
    with gr.Row():
        story_output = gr.Textbox(label='Screenwriter Output', lines=6)
        gallery = gr.Gallery(label='Comic Scenes')
    # comic_pipeline returns (text, [(image, caption), ...]). The two values
    # map, in order, onto the textbox and the gallery.
    generate_btn.click(comic_pipeline, inputs=story_input, outputs=[story_output, gallery])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()