8BitStudio committed
Commit d12fd5f · verified · 1 Parent(s): 7573aab

Update app.py

Files changed (1)
  1. app.py +192 -35
app.py CHANGED
@@ -1,49 +1,206 @@
- import gradio as gr
  import torch
- from huggingface_hub import hf_hub_download
- from diffusers import DDIMScheduler, DPMSolverMultistepScheduler
- from PIL import Image
+ import torch.nn.functional as F
+ import numpy as np
+ from pathlib import Path
+ from PIL import Image, ImageEnhance, ImageFilter
+ import gradio as gr
+
+ # ── Config ────────────────────────────────────────────────────────────────────
+ HF_REPO_ID = "8BitStudio/Aniimage-1"
+ VAE_ID = "stabilityai/sd-vae-ft-mse"
+ CLIP_ID = "openai/clip-vit-large-patch14"
+
+ UNET_CONFIG = dict(
+     sample_size=32,
+     in_channels=4,
+     out_channels=4,
+     block_out_channels=(256, 512, 768, 1024),
+     layers_per_block=2,
+     cross_attention_dim=768,
+     attention_head_dim=8,
+     down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
+                       "CrossAttnDownBlock2D", "DownBlock2D"),
+     up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
+                     "CrossAttnUpBlock2D", "UpBlock2D"),
+ )
+
+ DEFAULT_NEGATIVE = (
+     "low quality, ugly, blurry, distorted, deformed, bad anatomy, "
+     "bad proportions, extra limbs, missing limbs, watermark, text, "
+     "signature, washed out, flat colors, manga panel, disfigured, "
+     "poorly drawn, jpeg artifacts, cropped, out of frame"
+ )
+
+ SCHEDULER_LIST = ["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler a", "Euler", "DDIM"]
+
+ # ── Generator ─────────────────────────────────────────────────────────────────
+ class Generator:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.vae = None
+         self.text_encoder = None
+         self.tokenizer = None
+         self.unet = None
+         self.scheduler_name = "DPM++ 2M Karras"
+         self.latent_size = 32
+         self.output_size = 256
+
+     def load(self):
+         if self.unet is not None:
+             return
+         from diffusers import AutoencoderKL, UNet2DConditionModel
+         from transformers import CLIPTextModel, CLIPTokenizer
+         from huggingface_hub import hf_hub_download
+         from safetensors.torch import load_file
+         import shutil
+
+         print("Loading VAE...")
+         self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
+         self.vae.eval()
+
+         print("Loading CLIP...")
+         self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
+         self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
+         self.text_encoder.eval()
+
+         print("Loading UNet...")
+         weights_path = Path("unet_weights.safetensors")
+         if not weights_path.exists():
+             dl = hf_hub_download(repo_id=HF_REPO_ID,
+                                  filename="diffusion_pytorch_model.safetensors")
+             shutil.copy2(dl, weights_path)
+
+         self.unet = UNet2DConditionModel(**UNET_CONFIG).to(self.device)
+         state = load_file(str(weights_path), device=str(self.device))
+         self.unet.load_state_dict(state)
+         self.unet.eval()
+         print(f"Ready! Running on {self.device.upper()}")
+
+     def _make_scheduler(self, name):
+         from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
+                                EulerAncestralDiscreteScheduler,
+                                EulerDiscreteScheduler)
+         base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
+                     prediction_type="epsilon")
+         if name == "DPM++ 2M Karras":
+             return DPMSolverMultistepScheduler(
+                 **base, algorithm_type="dpmsolver++",
+                 solver_order=2, use_karras_sigmas=True)
+         elif name == "DPM++ SDE Karras":
+             return DPMSolverMultistepScheduler(
+                 **base, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
+         elif name == "Euler a":
+             return EulerAncestralDiscreteScheduler(**base)
+         elif name == "Euler":
+             return EulerDiscreteScheduler(**base)
+         else:
+             return DDIMScheduler(**base, clip_sample=False, set_alpha_to_one=False)

- # --- Load your pipeline (adjust imports to match your generate_hf.py) ---
- # This assumes your generate_hf.py defines a pipeline or model class
- import sys
- sys.path.append(".")
- from generate_hf import load_pipeline  # adjust to whatever your file exposes
+     def _decode_latents(self, latents):
+         scaled = latents / self.vae.config.scaling_factor
+         with torch.no_grad():
+             image = self.vae.decode(scaled.float()).sample
+         image = (image.float() / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+         image = (image * 255).round().astype("uint8")
+         img = Image.fromarray(image)
+         img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
+         img = ImageEnhance.Contrast(img).enhance(1.06)
+         img = ImageEnhance.Color(img).enhance(1.10)
+         return img

- device = "cuda" if torch.cuda.is_available() else "cpu"
- pipe = load_pipeline("8BitStudio/Aniimage-1", device=device)
+     def _sharpen_latents(self, latents, amount=0.08):
+         blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
+         return latents + amount * (latents - blurred)

- DEFAULT_NEGATIVE = "low quality, ugly, blurry, distorted, deformed, bad anatomy, bad proportions, extra limbs, missing limbs, watermark, text, signature, washed out, flat colors, manga panel, disfigured, poorly drawn, jpeg artifacts, cropped, out of frame"
+     @torch.no_grad()
+     def generate(self, prompt, negative_prompt="", steps=25,
+                  guidance_scale=7.5, seed=-1, scheduler_name="DPM++ 2M Karras"):
+         self.load()

- def generate(prompt, negative_prompt, steps, cfg, scheduler):
-     if scheduler == "DPM++ 2M":
-         pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-     else:
-         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+         if seed < 0:
+             seed = torch.randint(0, 2**32, (1,)).item()
+         gen = torch.Generator(device=self.device).manual_seed(seed)

-     image = pipe(
+         tok = self.tokenizer(prompt, padding="max_length",
+                              max_length=self.tokenizer.model_max_length,
+                              truncation=True, return_tensors="pt")
+         text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]
+
+         tok_neg = self.tokenizer(negative_prompt, padding="max_length",
+                                  max_length=self.tokenizer.model_max_length,
+                                  truncation=True, return_tensors="pt")
+         neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]
+
+         combined = torch.cat([neg_emb, text_emb])
+         scheduler = self._make_scheduler(scheduler_name)
+         scheduler.set_timesteps(steps, device=self.device)
+
+         latents = torch.randn(1, 4, self.latent_size, self.latent_size,
+                               generator=gen, device=self.device)
+         latents = latents * scheduler.init_noise_sigma
+
+         for t in scheduler.timesteps:
+             inp = torch.cat([latents] * 2)
+             inp = scheduler.scale_model_input(inp, t)
+             with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
+                                 enabled=(self.device == "cuda")):
+                 pred = self.unet(inp, t, encoder_hidden_states=combined).sample
+             pred_neg, pred_text = pred.chunk(2)
+             pred = pred_neg + guidance_scale * (pred_text - pred_neg)
+             latents = scheduler.step(pred, t, latents).prev_sample
+
+         latents = self._sharpen_latents(latents)
+         return self._decode_latents(latents), seed
+
+
+ # ── Load model once at startup ────────────────────────────────────────────────
+ gen = Generator()
+
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
+ def run(prompt, negative, steps, cfg, scheduler, seed):
+     if not prompt.strip():
+         return None, "Please enter a prompt!"
+     image, used_seed = gen.generate(
          prompt=prompt,
-         negative_prompt=negative_prompt,
-         num_inference_steps=steps,
-         guidance_scale=cfg,
-     ).images[0]
-     return image
+         negative_prompt=negative,
+         steps=int(steps),
+         guidance_scale=float(cfg),
+         seed=int(seed),
+         scheduler_name=scheduler,
+     )
+     return image, f"Seed: {used_seed}"

  with gr.Blocks(title="Aniimage-1 by 8BitStudio") as demo:
-     gr.Markdown("# 🎨 Aniimage-1\nAnime image generation model by 8BitStudio. Use plain English prompts.")
-
+     gr.Markdown("# 🎨 Aniimage-1\nAnime image generator by **8BitStudio** · 256×256 · Trained from scratch on 830k Danbooru images\n\nUse plain English: *\"A smiling anime girl with red hair and a school uniform\"*")
+
      with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Prompt", placeholder="A smiling anime girl with red hair and a school uniform")
-             negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE)
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(label="Prompt", lines=3,
+                                 placeholder="A smiling anime girl with red hair and a school uniform")
+             negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE, lines=2)
              with gr.Row():
                  steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
-                 cfg = gr.Slider(1, 15, value=7.5, step=0.5, label="CFG Scale")
-             scheduler = gr.Radio(["DPM++ 2M", "DDIM"], value="DPM++ 2M", label="Scheduler")
-             btn = gr.Button("Generate", variant="primary")
-         with gr.Column():
-             output = gr.Image(label="Generated Image")
-
-     btn.click(generate, inputs=[prompt, negative, steps, cfg, scheduler], outputs=output)
+                 cfg = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
+             with gr.Row():
+                 scheduler = gr.Dropdown(SCHEDULER_LIST, value="DPM++ 2M Karras", label="Scheduler")
+                 seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
+             btn = gr.Button(" Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             output = gr.Image(label="Generated Image", type="pil")
+             seed_out = gr.Textbox(label="Used Seed", interactive=False)
+
+     btn.click(run, inputs=[prompt, negative, steps, cfg, scheduler, seed],
+               outputs=[output, seed_out])
+
+     gr.Examples(
+         examples=[
+             ["A smiling anime girl with red hair and a school uniform", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
+             ["A mysterious anime girl with silver hair under a night sky with stars", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
+             ["An anime girl in a maid dress holding a teacup, cherry blossoms in the background", DEFAULT_NEGATIVE, 30, 7.5, "DPM++ 2M Karras", -1],
+         ],
+         inputs=[prompt, negative, steps, cfg, scheduler, seed],
+     )

  demo.launch()
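
For reference, the rewritten app.py exposes generation through the Generator class instead of the old generate_hf.load_pipeline helper, so it can also be driven without the Gradio UI. A minimal sketch under that assumption (importing app.py directly would also start demo.launch(), so the class would need to be copied into a session or factored into its own module first; the output path below is only illustrative):

    # Hypothetical harness: assumes Generator and DEFAULT_NEGATIVE from app.py are in scope.
    g = Generator()

    image, used_seed = g.generate(
        prompt="A smiling anime girl with red hair and a school uniform",
        negative_prompt=DEFAULT_NEGATIVE,
        steps=25,
        guidance_scale=7.5,
        seed=-1,                          # -1 = random; the seed actually used is returned
        scheduler_name="DPM++ 2M Karras",
    )
    image.save("sample.png")              # generate() returns (PIL.Image, seed)
    print("Seed:", used_seed)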