qver3nc1a commited on
Commit
15d6e65
·
verified ·
1 Parent(s): 84106f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -34
app.py CHANGED
@@ -1,19 +1,21 @@
1
  import gradio as gr
2
  import os
3
- import re
4
  from PIL import Image, ImageDraw
5
- import requests
6
  from io import BytesIO
 
7
  from huggingface_hub import InferenceClient
8
- from dotenv import load_dotenv
 
9
 
10
- load_dotenv()
11
 
12
- hf_token = os.getenv("HF_API_TOKEN")
13
- if not hf_token:
14
- raise ValueError("Set your HF_API_TOKEN environment variable before running.")
 
 
15
 
16
- client = InferenceClient(token=hf_token)
17
 
18
  def screenwriter(prompt: str) -> str:
19
  instructions = f"""
@@ -34,15 +36,18 @@ def screenwriter(prompt: str) -> str:
34
  """
35
 
36
  response = client.text_generation(
37
- model="tiiuae/falcon-7b-instruct",
38
- inputs=instructions,
39
- max_new_tokens=250
 
40
  )
41
- return response[0]['generated_text']
42
 
43
- def remove_think_block(text:str):
 
44
  return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
45
 
 
46
  def parse_screenwriter_output(output: str):
47
  cleaned_output = remove_think_block(output)
48
  delimiter = '---'
@@ -57,34 +62,32 @@ def parse_screenwriter_output(output: str):
57
  character = lines[-1]
58
  return story, character
59
 
 
60
  def error_image(message):
61
  img = Image.new("RGB", (512, 512), color=(255, 255, 255))
62
  d = ImageDraw.Draw(img)
63
  d.text((10, 250), message, fill=(255, 0, 0))
64
  return img
65
 
 
66
  def illustrator(story: str, character: str):
67
  if not story or not character:
68
  raise ValueError('Could not parse story or character from input.')
69
 
70
  scenes = [s.strip() for s in story.split('.') if s.strip()]
71
-
72
  images = []
 
73
  for idx, scene in enumerate(scenes):
74
- prompt = f'Comic book illustration of the scene. No text. Scene: {scene}. Character: {character}'
75
  try:
76
- response = client.text_to_image(
77
- model="stabilityai/stable-diffusion-2",
78
- inputs=prompt
79
- )
80
- image_url = response['generated_image_url']
81
- image = Image.open(BytesIO(requests.get(image_url).content))
82
  images.append((image, scene))
83
  except Exception as e:
84
  images.append((error_image(f'Error: {str(e)}'), f'Error in scene {idx + 1}'))
85
  return images
86
 
87
- def pipeline(prompt: str):
 
88
  output = screenwriter(prompt)
89
  story, character = parse_screenwriter_output(output)
90
  if not story or not character:
@@ -92,20 +95,17 @@ def pipeline(prompt: str):
92
  images = illustrator(story, character)
93
  return f"{story}\n---\n{character}", images
94
 
95
- with gr.Blocks(theme=gr.themes.Ocean(),
96
- title='Comic Generator') as demo:
97
- gr.Markdown(
98
- '''
99
- # Comic Generator
100
- Generates a comic off of your prompt.
101
- ''')
102
  with gr.Row():
103
  story_input = gr.Textbox(label='Story Prompt', placeholder='A unicorn named Jeff discovers a mysterious dish')
104
- generated_story = gr.Button('Generate Story')
105
  with gr.Row():
106
- story_output = gr.Textbox(label='Screenwriter', lines=5)
107
- gallery = gr.Gallery(label='Comic Scenes')
108
- generated_story.click(pipeline, inputs=story_input, outputs=[story_output, gallery])
 
109
 
110
  if __name__ == "__main__":
111
- demo.launch(mcp_server=True)
 
1
  import gradio as gr
2
  import os
 
3
  from PIL import Image, ImageDraw
4
+ import re
5
  from io import BytesIO
6
+
7
  from huggingface_hub import InferenceClient
8
+ from diffusers import StableDiffusionPipeline
9
+ import torch
10
 
11
+ client = InferenceClient()
12
 
13
+ pipe = StableDiffusionPipeline.from_pretrained(
14
+ "CompVis/stable-diffusion-v1-4",
15
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
16
+ )
17
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
18
 
 
19
 
20
  def screenwriter(prompt: str) -> str:
21
  instructions = f"""
 
36
  """
37
 
38
  response = client.text_generation(
39
+ model="tiiuae/falcon-rw-1b", # free-tier supported model
40
+ prompt=instructions,
41
+ max_new_tokens=250,
42
+ temperature=0.7
43
  )
44
+ return response
45
 
46
+
47
+ def remove_think_block(text: str):
48
  return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
49
 
50
+
51
  def parse_screenwriter_output(output: str):
52
  cleaned_output = remove_think_block(output)
53
  delimiter = '---'
 
62
  character = lines[-1]
63
  return story, character
64
 
65
+
66
  def error_image(message):
67
  img = Image.new("RGB", (512, 512), color=(255, 255, 255))
68
  d = ImageDraw.Draw(img)
69
  d.text((10, 250), message, fill=(255, 0, 0))
70
  return img
71
 
72
+
73
  def illustrator(story: str, character: str):
74
  if not story or not character:
75
  raise ValueError('Could not parse story or character from input.')
76
 
77
  scenes = [s.strip() for s in story.split('.') if s.strip()]
 
78
  images = []
79
+
80
  for idx, scene in enumerate(scenes):
81
+ prompt = f"Comic book style illustration. No text. Scene: {scene}. Character: {character}"
82
  try:
83
+ image = pipe(prompt).images[0]
 
 
 
 
 
84
  images.append((image, scene))
85
  except Exception as e:
86
  images.append((error_image(f'Error: {str(e)}'), f'Error in scene {idx + 1}'))
87
  return images
88
 
89
+
90
+ def comic_pipeline(prompt: str):
91
  output = screenwriter(prompt)
92
  story, character = parse_screenwriter_output(output)
93
  if not story or not character:
 
95
  images = illustrator(story, character)
96
  return f"{story}\n---\n{character}", images
97
 
98
+
99
+ with gr.Blocks(theme=gr.themes.Ocean(), title='Comic Generator') as demo:
100
+ gr.Markdown("# Comic Generator\nGive a prompt and get a comic!")
 
 
 
 
101
  with gr.Row():
102
  story_input = gr.Textbox(label='Story Prompt', placeholder='A unicorn named Jeff discovers a mysterious dish')
103
+ generate_btn = gr.Button('Generate Comic')
104
  with gr.Row():
105
+ story_output = gr.Textbox(label='Screenwriter Output', lines=6)
106
+ gallery = gr.Gallery(label='Comic Scenes').style(grid=[2], height='auto')
107
+ generate_btn.click(comic_pipeline, inputs=story_input, outputs=[story_output, gallery])
108
+
109
 
110
  if __name__ == "__main__":
111
+ demo.launch()