File size: 3,842 Bytes
a74d8e8
132c8c4
84106f1
15d6e65
132c8c4
15d6e65
84106f1
15d6e65
 
c725775
15d6e65
a74d8e8
15d6e65
 
 
 
 
c725775
132c8c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84106f1
8f3b5c7
15d6e65
 
8f3b5c7
84106f1
15d6e65
84106f1
15d6e65
 
132c8c4
 
15d6e65
132c8c4
 
 
 
 
 
 
 
 
 
 
 
 
 
15d6e65
132c8c4
 
 
 
 
 
15d6e65
132c8c4
 
 
 
 
 
15d6e65
132c8c4
15d6e65
132c8c4
15d6e65
132c8c4
 
 
 
 
15d6e65
 
132c8c4
 
 
 
 
 
 
15d6e65
 
 
84106f1
 
15d6e65
84106f1
15d6e65
60bd184
15d6e65
 
a74d8e8
 
15d6e65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import os
from PIL import Image, ImageDraw
import re
from io import BytesIO

from huggingface_hub import InferenceClient
from diffusers import StableDiffusionPipeline
import torch

# Client for hosted inference; screenwriter() uses it for text generation.
client = InferenceClient()

# Query CUDA availability once, then derive both precision and device from it:
# float16 on a CUDA device, float32 on CPU.
_cuda_available = torch.cuda.is_available()
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16 if _cuda_available else torch.float32,
).to("cuda" if _cuda_available else "cpu")


def screenwriter(prompt: str) -> str:
    """Request a comic plot and a main-character description from the text model.

    The user's idea is interpolated into a fixed instruction template that asks
    for one sentence per scene and a '---' delimiter between the story and the
    character description. The template is sent through the module-level
    ``client``.

    Args:
        prompt: Story idea to interpolate into the instruction template.

    Returns:
        The value returned by ``client.text_generation`` for the filled-in
        template.
    """
    # The template's leading whitespace is part of the text sent to the model.
    full_prompt = f"""
    You are a skilled comic book writer.

    TASK:
    Generate a short comic book plot based on the story idea provided below. Also generate a description of the main
    character.
    Generate one sentence per scene, separated by periods. The story should be 3-7 sentences long. 
    IMPORTANT: Do NOT include any commentary, notes, or additional thoughts. Only output the story sentences and character description exactly as requested.

    Your output must include:
    - Story plot with one sentence per scene.
    - Very short description of the main character's appearance.
    - IMPORTANT!!! ALWAYS use a delimiter '---' to separate the story from the character description.

    STORY PROMPT: {prompt}
    """

    return client.text_generation(
        prompt=full_prompt,
        model="mistralai/Mistral-7B-Instruct-v0.3",
        temperature=0.7,
        max_new_tokens=250,
    )


def remove_think_block(text: str):
    """Strip every ``<think>...</think>`` span from *text* and trim the result.

    Matching is non-greedy and spans newlines, so several separate blocks are
    each removed on their own. An opening tag without a closing tag is left
    as-is.

    Args:
        text: Text that may contain ``<think>`` blocks.

    Returns:
        The text without those blocks and without leading/trailing whitespace.
    """
    think_span = re.compile(r'<think>.*?</think>', re.DOTALL)
    without_blocks = think_span.sub('', text)
    return without_blocks.strip()


def parse_screenwriter_output(output: str):
    """Split screenwriter output into a ``(story, character)`` pair.

    ``<think>`` blocks are removed first. If the text contains '---', everything
    before the first occurrence is the story and everything after it is the
    character description. Otherwise the last non-empty line is taken as the
    character description and the preceding non-empty lines, joined with
    spaces, as the story.

    Args:
        output: Text produced by the screenwriter step.

    Returns:
        A ``(story, character)`` tuple of stripped strings, or ``('', '')``
        when there is no delimiter and fewer than two non-empty lines.
    """
    text = remove_think_block(output)

    story_part, separator, character_part = text.partition('---')
    if separator:
        return story_part.strip(), character_part.strip()

    # Fallback when the delimiter is missing: work line by line.
    stripped_lines = (raw.strip() for raw in text.strip().split('\n'))
    non_empty = [line for line in stripped_lines if line]
    if len(non_empty) < 2:
        return '', ''

    *story_lines, character_line = non_empty
    return ' '.join(story_lines), character_line


def error_image(message):
    """Create a 512x512 white image with *message* drawn on it in red.

    Args:
        message: Text to draw, starting at pixel position (10, 250).

    Returns:
        The newly created RGB image.
    """
    white = (255, 255, 255)
    red = (255, 0, 0)
    canvas = Image.new("RGB", (512, 512), color=white)
    ImageDraw.Draw(canvas).text((10, 250), message, fill=red)
    return canvas


def illustrator(story: str, character: str):
    """Render one image per scene of *story*.

    The story is split on '.' and each non-empty, stripped piece is treated as
    a scene. Every scene is combined with the character description into a
    prompt for the module-level ``pipe``. If generating a scene raises, an
    error image is stored for that scene and the remaining scenes are still
    processed.

    Args:
        story: Story text, one scene per period-terminated sentence.
        character: Description of the main character, appended to each prompt.

    Returns:
        A list of ``(image, caption)`` tuples in scene order. The caption is
        the scene text, or ``'Error in scene N'`` for a failed scene.

    Raises:
        ValueError: If *story* or *character* is empty.
    """
    if not (story and character):
        raise ValueError('Could not parse story or character from input.')

    scenes = [part for part in map(str.strip, story.split('.')) if part]

    panels = []
    for number, scene in enumerate(scenes, start=1):
        prompt = f"Comic book style illustration. No text. Scene: {scene}. Character: {character}"
        try:
            rendered = pipe(prompt).images[0]
        except Exception as exc:
            # Best effort: a failed scene becomes a placeholder instead of aborting the comic.
            panels.append((error_image(f'Error: {exc!s}'), f'Error in scene {number}'))
        else:
            panels.append((rendered, scene))
    return panels


def comic_pipeline(prompt: str):
    """Run the full prompt -> script -> images flow behind the Generate button.

    Args:
        prompt: Story idea entered by the user.

    Returns:
        A ``(text, images)`` pair. On success, ``text`` is the parsed story and
        character joined by a '---' line and ``images`` is the illustrator's
        list of ``(image, caption)`` tuples. If parsing yields an empty story
        or character, ``text`` is the unparsed screenwriter output and
        ``images`` holds a single parse-error placeholder.
    """
    raw_output = screenwriter(prompt)
    story, character = parse_screenwriter_output(raw_output)

    if story and character:
        panels = illustrator(story, character)
        return f"{story}\n---\n{character}", panels

    # Show the unparsed model output so the user can see what went wrong.
    placeholder = error_image("Parse error: Could not extract story or character.")
    return raw_output, [(placeholder, 'Parse error')]


# Gradio UI definition: a prompt box and a button, with the results shown as a
# text box next to an image gallery.
with gr.Blocks(theme=gr.themes.Ocean(), title='Comic Generator') as demo:
    gr.Markdown("# Comic Generator\nGive a prompt and get a comic!")
    # Input row: the user's story idea.
    with gr.Row():
        story_input = gr.Textbox(label='Story Prompt', placeholder='A unicorn named Jeff discovers a mysterious dish')
    generate_btn = gr.Button('Generate Comic')
    # Output row: screenwriter text and the generated scene images.
    with gr.Row():
        story_output = gr.Textbox(label='Screenwriter Output', lines=6)
        gallery = gr.Gallery(label='Comic Scenes')
    # comic_pipeline returns (text, images), matching the two outputs in order.
    generate_btn.click(comic_pipeline, inputs=story_input, outputs=[story_output, gallery])


# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()