salso committed on
Commit
bc870ea
·
verified ·
1 Parent(s): db86bb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -148
app.py CHANGED
@@ -1,154 +1,162 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
-
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
  ],
150
- outputs=[result, seed],
 
 
 
 
 
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # ZenCtrl Inpainting Playground (Baseten backend)
3
+
4
+ import os, json, base64, requests
5
+ from io import BytesIO
6
+ from PIL import Image, ImageDraw
7
  import gradio as gr
8
+
# ────────── Secrets & endpoints ──────────
# NOTE(review): the URL still contains the YOUR_MODEL_ID placeholder —
# replace with the deployed Baseten model id before shipping.
BASETEN_MODEL_URL = "https://app.baseten.co/models/YOUR_MODEL_ID/predict"
# Optional: when unset, the Baseten request is sent without an Authorization header.
BASETEN_API_KEY = os.getenv("BASETEN_API_KEY")
# Read here but never referenced again in this file; presumably the replicate
# client picks REPLICATE_API_TOKEN up from the environment itself — confirm.
REPLICATE_TOKEN = os.getenv("REPLICATE_API_TOKEN")
13
+
14
+ from florence_sam.detect_and_segment import fill_detected_bboxes
15
+
# ────────── Globals ──────────
ADAPTER_NAME = "inpaint"  # adapter identifier (not referenced in this file's visible code)
ADAPTER_SIZE = 1024  # square resolution (px) used for crop/resize and the backend payload
# NOTE(review): model_config is never referenced in this file — presumably it
# mirrors the server-side adapter configuration; confirm before removing.
model_config = dict(union_cond_attn=True, add_cond_attn=False,
                    latent_lora=False, independent_condition=False)
css = "#col-container {margin:0 auto; max-width:960px;}"  # width cap for the input column
22
+
# ────────── Background prompt via Replicate ──────────
def _gen_bg(prompt: str):
    """Generate a 1:1 background image from *prompt* via Replicate's
    google/imagen-4-fast model and return it as an RGB PIL image.

    Falls back to a generic "cinematic background" prompt when *prompt*
    is empty.

    Raises:
        gr.Error: When the `replicate` package is not installed.
        requests.HTTPError: When downloading the generated image fails.
    """
    # BUG FIX: `replicate` was called without ever being imported, so every
    # invocation raised NameError. Import lazily so the module still loads
    # in environments without the package installed.
    try:
        import replicate
    except ImportError as err:
        raise gr.Error("The `replicate` package is not installed.") from err

    url = replicate.run(
        "google/imagen-4-fast",
        input={"prompt": prompt or "cinematic background", "aspect_ratio": "1:1"},
    )
    # replicate.run may return a single URL or a list of URLs.
    url = url[0] if isinstance(url, list) else url
    resp = requests.get(url, timeout=120)
    resp.raise_for_status()  # surface HTTP failures instead of a PIL decode error
    return Image.open(BytesIO(resp.content)).convert("RGB")
31
+
# ────────── Core generation ──────────
def process_image_and_text(subject_image, adapter_dict, prompt,
                           use_detect=False, detect_prompt="",
                           size=ADAPTER_SIZE, rank=10.0):
    """Inpaint *subject_image* into the green-masked region of a background
    via the Baseten backend.

    Args:
        subject_image: PIL image of the subject to insert.
        adapter_dict: Background image — either a plain PIL image or a
            gr.Image sketch dict with "image" and "mask" entries.
        prompt: Generation prompt forwarded to the backend.
        use_detect: When True, find the edit region with Florence-SAM using
            *detect_prompt* instead of the hand-drawn mask. Defaulted so
            callers that pass only the first three arguments (e.g. the
            gr.Examples presets) still work.
        detect_prompt: Comma-separated object labels for detection.
        size: Square output resolution sent to the backend.
        rank: Adapter rank forwarded to the backend.

    Returns:
        ([[PIL.Image]], PIL.Image): gallery rows and the raw image for state.

    Raises:
        gr.Error: On missing inputs or a backend response without image data.
        requests.HTTPError: When the Baseten call fails.
    """
    if subject_image is None:
        raise gr.Error("Upload a subject image first.")

    # Fixed sampling settings for the playground.
    seed, guidance_scale, steps = 42, 2.5, 28

    if use_detect:
        base_img = adapter_dict["image"] if isinstance(adapter_dict, dict) else adapter_dict
        if base_img is None:
            raise gr.Error("Upload a background image first.")
        # Florence-SAM paints detected (inflated) boxes solid green; the
        # backend treats #00FF00 as the inpaint region.
        adapter_image, _ = fill_detected_bboxes(
            image=base_img, text=detect_prompt,
            inflate_pct=0.15, fill_color="#00FF00",
        )
    else:
        adapter_image = adapter_dict["image"] if isinstance(adapter_dict, dict) else adapter_dict
        if adapter_image is None:
            raise gr.Error("Upload a background image first.")
        if isinstance(adapter_dict, dict) and adapter_dict.get("mask") is not None:
            # Binarize the sketch mask, then expand it to its bounding
            # rectangle so the green fill forms a clean box.
            mask = adapter_dict["mask"].convert("L").point(lambda p: 255 if p else 0)
            if bbox := mask.getbbox():
                rect = Image.new("L", mask.size, 0)
                ImageDraw.Draw(rect).rectangle(bbox, fill=255)
                mask = rect
            green = Image.new("RGB", adapter_image.size, "#00FF00")
            adapter_image = Image.composite(green, adapter_image, mask)

    def _prep(img: Image.Image) -> Image.Image:
        # Center-crop to a square, then resize to the model resolution.
        w, h = img.size
        side = min(w, h)
        box = ((w - side) // 2, (h - side) // 2, (w + side) // 2, (h + side) // 2)
        return img.crop(box).resize((size, size), Image.LANCZOS)

    def _b64(img: Image.Image) -> str:
        # PNG-encode then base64 for the JSON payload.
        buf = BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()

    payload = {
        "prompt": prompt,
        "subject_image": _b64(_prep(subject_image)),
        "adapter_image": _b64(_prep(adapter_image)),
        "height": size, "width": size,
        "steps": steps, "seed": seed,
        "guidance_scale": guidance_scale, "rank": rank,
    }

    headers = {"Content-Type": "application/json"}
    if BASETEN_API_KEY:
        headers["Authorization"] = f"Api-Key {BASETEN_API_KEY}"

    resp = requests.post(BASETEN_MODEL_URL, headers=headers, json=payload, timeout=120)
    resp.raise_for_status()

    # The backend may stream the image bytes directly, or return JSON with
    # an {"image_url": ...} pointer to fetch.
    if resp.headers.get("content-type", "").startswith("image/"):
        raw_img = Image.open(BytesIO(resp.content))
    else:
        url = resp.json().get("image_url")
        if not url:
            raise gr.Error("Baseten response missing image data.")
        raw_img = Image.open(BytesIO(requests.get(url, timeout=120).content))

    return [[raw_img]], raw_img
92
+
# ────────── Header HTML ──────────
# Static banner (title plus Discord / website / X badge links) rendered once
# at the top of the UI via gr.HTML.
header_html = """
<h1>ZenCtrl Inpainting</h1>
<div align="center" style="line-height:1;">
<a href="https://discord.com/invite/b9RuYQ3F8k" target="_blank" style="margin:10px;">
<img src="https://img.shields.io/badge/Discord-Join-7289da.svg?logo=discord" alt="Discord">
</a>
<a href="https://fotographer.ai/zen-control" target="_blank" style="margin:10px;">
<img src="https://img.shields.io/badge/Website-Landing_Page-blue" alt="LP">
</a>
<a href="https://x.com/FotographerAI" target="_blank" style="margin:10px;">
<img src="https://img.shields.io/twitter/follow/FotographerAI?style=social" alt="X">
</a>
</div>
"""
108
 
# ────────── Gradio UI ──────────
with gr.Blocks(css=css, title="ZenCtrl Playground") as demo:
    # Holds the most recent raw PIL result for reuse by future callbacks.
    raw_state = gr.State()

    gr.HTML(header_html)
    gr.Markdown("""
**Generate context-aware images of your subject with ZenCtrl’s inpainting playground.**
Upload a subject + optional mask, write a prompt, and hit **Generate**.
Open *Advanced Settings* to fetch an AI-generated background.
""")

    with gr.Row():
        with gr.Column(scale=2, elem_id="col-container"):
            subj_img = gr.Image(type="pil", label="Subject image")
            # NOTE(review): `tool`/`brush_color` are gradio 3.x kwargs; on
            # gradio 4.x this component would need gr.ImageEditor — confirm
            # the pinned gradio version.
            ref_img = gr.Image(type="pil", label="Background / Mask image",
                               tool="sketch", brush_color="#00FF00",
                               sources=["upload", "clipboard"])
            use_detect_ck = gr.Checkbox(False, label="Detect with Florence-SAM")
            detect_box = gr.Textbox(label="Detection prompt", value="person, chair", visible=False)
            promptbox = gr.Textbox(label="Generation prompt", value="furniture", lines=2)
            run_btn = gr.Button("Generate", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                bgprompt = gr.Textbox(label="Background Prompt", value="Scandinavian living room …")
                bg_btn = gr.Button("Generate BG")

        with gr.Column(scale=2):
            gallery = gr.Gallery(columns=[1], rows=[1], object_fit="contain", height="auto")
            # Hidden sink for generated backgrounds; flip visible=True to preview.
            bg_img = gr.Image(label="Background", visible=False)

    # BUG FIX: the preset rows carried four values (including an output
    # preview) while `inputs` listed only three components; additionally
    # `examples_per_page="all"` is not a valid integer and
    # `cache_examples="lazy"` would invoke the five-argument handler with
    # only three inputs. Examples are now pure input presets.
    gr.Examples(
        examples=[
            ["examples/subject1.png", "examples/bg1.png", "Make the toy sit on a marble table"],
            ["examples/subject2.png", "examples/bg2.png", "Turn the flowers into sunflowers"],
            ["examples/subject3.png", "examples/bg3.png", "Make this monster ride a skateboard on the beach"],
            ["examples/subject4.png", "examples/bg4.png", "Make this cat happy"],
        ],
        inputs=[subj_img, ref_img, promptbox],
        label="Presets (Input · Background · Prompt)",
    )

    run_btn.click(
        process_image_and_text,
        inputs=[subj_img, ref_img, promptbox, use_detect_ck, detect_box],
        outputs=[gallery, raw_state],
    )

    bg_btn.click(_gen_bg, inputs=[bgprompt], outputs=[bg_img])
    # Reveal the detection prompt only while Florence-SAM detection is enabled.
    use_detect_ck.change(lambda v: gr.update(visible=v), inputs=use_detect_ck, outputs=detect_box)

# ────────── Launch ──────────
if __name__ == "__main__":
    # share=True is ignored on HF Spaces but opens a tunnel when run locally.
    demo.launch(show_api=False, share=True)