panelforge commited on
Commit
5f66166
Β·
verified Β·
1 Parent(s): 325abfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -44
app.py CHANGED
@@ -1,40 +1,40 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import numpy as np
5
  import random
6
  import torch
7
- import spaces # Hugging Face Spaces ZeroGPU support
8
  from diffusers import DiffusionPipeline
9
- from tags import TAGS
10
 
11
- # Model loading
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
14
- model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
21
 
22
- # ⬇ Inference Function with @spaces.GPU
 
 
 
23
  @spaces.GPU
24
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
25
- guidance_scale, num_inference_steps, *tag_selections, active_tab, progress=gr.Progress(track_tqdm=True)):
 
26
 
27
  if active_tab == "Prompt Input":
28
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
29
  else:
30
- all_tags = []
31
- for (group_name, tag_dict), selected_keys in zip(TAGS.items(), tag_selections):
32
- all_tags.extend([tag_dict[key] for key in selected_keys])
33
- tag_text = ", ".join(all_tags)
34
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_text}"
35
 
36
- additional_negatives = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
37
- full_negative_prompt = f"{additional_negatives}, {negative_prompt}"
38
 
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
@@ -53,13 +53,35 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
53
 
54
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
55
 
56
- # ⬇ Gradio UI
57
- with gr.Blocks(css="""
58
- #col-container { max-width: 1280px; margin: auto; }
59
- #left-column, #right-column { display: inline-block; vertical-align: top; width: 48%; padding: 1%; }
60
- #run-button { width: 100%; }
61
- """) as demo:
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  with gr.Row():
64
  with gr.Column(elem_id="left-column"):
65
  gr.Markdown("# Rainbow Media X")
@@ -68,16 +90,16 @@ with gr.Blocks(css="""
68
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
69
 
70
  with gr.Accordion("Advanced Settings", open=False):
71
- negative_prompt = gr.Textbox(label="Negative Prompt", max_lines=1, placeholder="Enter a negative prompt")
72
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
73
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
74
 
75
  with gr.Row():
76
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=768)
77
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=768)
78
 
79
  with gr.Row():
80
- guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=7.0)
81
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=35)
82
 
83
  run_button = gr.Button("Run", elem_id="run-button")
@@ -87,26 +109,23 @@ with gr.Blocks(css="""
87
 
88
  with gr.Tabs() as tabs:
89
  with gr.TabItem("Prompt Input") as prompt_tab:
90
- prompt = gr.Textbox(label="Prompt", placeholder="Enter a custom prompt", lines=3)
91
  prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)
92
 
93
  with gr.TabItem("Tag Selection") as tag_tab:
94
- tag_checkboxes = []
95
- for group_name, tag_dict in TAGS.items():
96
- checkbox = gr.CheckboxGroup(choices=list(tag_dict.keys()), label=group_name)
97
- tag_checkboxes.append(checkbox)
98
  tag_tab.select(lambda: "Tag Selection", outputs=active_tab)
99
 
100
- run_button.click(
101
- fn=infer,
102
- inputs=[
103
- prompt, negative_prompt, seed, randomize_seed,
104
- width, height, guidance_scale, num_inference_steps,
105
- active_tab, # Moved BEFORE *tag_checkboxes
106
- *tag_checkboxes
107
- ],
108
- outputs=[result, seed, prompt_info]
109
- )
110
-
111
 
112
  demo.queue().launch()
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
  import torch
5
+ import spaces
6
  from diffusers import DiffusionPipeline
7
+ from tags import TAGS # Centralized dictionary
8
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
11
 
12
+ model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
13
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 1024
17
 
18
+ # Prepare keys for each tag category for UI and loop usage
19
+ tag_categories = list(TAGS.keys()) # e.g. ["Participant", "Tribe", "Skin Tone", ...]
20
+ tag_checkboxes = [gr.CheckboxGroup(choices=list(TAGS[k].keys()), label=f"{k} Tags") for k in tag_categories]
21
+
22
  @spaces.GPU
23
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
24
+ guidance_scale, num_inference_steps, active_tab, *tag_selections,
25
+ progress=gr.Progress(track_tqdm=True)):
26
 
27
  if active_tab == "Prompt Input":
28
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
29
  else:
30
+ combined_tags = []
31
+ for (tag_name, tag_dict), selected in zip(TAGS.items(), tag_selections):
32
+ combined_tags.extend([tag_dict[tag] for tag in selected])
33
+ tag_string = ", ".join(combined_tags)
34
+ final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
35
 
36
+ negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
37
+ full_negative_prompt = f"{negative_base}, {negative_prompt}"
38
 
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
 
53
 
54
  return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
55
 
 
 
 
 
 
 
56
 
57
+ css = """
58
+ #col-container {
59
+ margin: 0 auto;
60
+ max-width: 1280px;
61
+ }
62
+
63
+ #left-column {
64
+ width: 50%;
65
+ display: inline-block;
66
+ padding: 20px;
67
+ vertical-align: top;
68
+ }
69
+
70
+ #right-column {
71
+ width: 50%;
72
+ display: inline-block;
73
+ vertical-align: top;
74
+ padding: 20px;
75
+ margin-top: 53px;
76
+ }
77
+
78
+ #run-button {
79
+ width: 100%;
80
+ margin-top: 10px;
81
+ }
82
+ """
83
+
84
+ with gr.Blocks(css=css) as demo:
85
  with gr.Row():
86
  with gr.Column(elem_id="left-column"):
87
  gr.Markdown("# Rainbow Media X")
 
90
  prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
91
 
92
  with gr.Accordion("Advanced Settings", open=False):
93
+ negative_prompt = gr.Textbox(label="Negative prompt", placeholder="Enter negative prompt")
94
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
95
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
96
 
97
  with gr.Row():
98
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
99
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
100
 
101
  with gr.Row():
102
+ guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=10, step=0.1, value=7)
103
  num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=35)
104
 
105
  run_button = gr.Button("Run", elem_id="run-button")
 
109
 
110
  with gr.Tabs() as tabs:
111
  with gr.TabItem("Prompt Input") as prompt_tab:
112
+ prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt")
113
  prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)
114
 
115
  with gr.TabItem("Tag Selection") as tag_tab:
116
+ for tag_box in tag_checkboxes:
117
+ tag_box.render()
 
 
118
  tag_tab.select(lambda: "Tag Selection", outputs=active_tab)
119
 
120
+ run_button.click(
121
+ fn=infer,
122
+ inputs=[
123
+ prompt, negative_prompt, seed, randomize_seed,
124
+ width, height, guidance_scale, num_inference_steps,
125
+ active_tab,
126
+ *tag_checkboxes
127
+ ],
128
+ outputs=[result, seed, prompt_info]
129
+ )
 
130
 
131
  demo.queue().launch()