HAL1993 committed on
Commit
a301d6c
·
verified ·
1 Parent(s): 987a464

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -154
app.py CHANGED
@@ -4,176 +4,155 @@ import random
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
-
8
- from tags_straight import TAGS_STRAIGHT
9
- from tags_lesbian import TAGS_LESBIAN
10
- from tags_gay import TAGS_GAY
11
-
12
- PROMPT_PREFIXES = {
13
- "Prompt Input": "score_9, score_8_up, score_7_up, source_anime",
14
- "Straight": "score_9, score_8_up, score_7_up, source_anime, ",
15
- "Lesbian": "score_9, score_8_up, score_7_up, source_anime, ",
16
- "Gay": "score_9, score_8_up, score_7_up, source_anime, yaoi, "
17
- }
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
21
 
22
-
23
-
24
- # model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v8-sdxl"
25
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v140-sdxl"
26
-
27
-
28
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
29
 
30
  MAX_SEED = np.iinfo(np.int32).max
31
- MAX_IMAGE_SIZE = 1024
32
-
33
- def create_checkboxes(tag_dict, suffix):
34
- categories = list(tag_dict.keys())
35
- return [gr.CheckboxGroup(choices=list(tag_dict[cat].keys()), label=f"{cat} Tags ({suffix})") for cat in categories], categories
36
-
37
- straight_checkboxes, _ = create_checkboxes(TAGS_STRAIGHT, "Straight")
38
- lesbian_checkboxes, _ = create_checkboxes(TAGS_LESBIAN, "Lesbian")
39
- gay_checkboxes, _ = create_checkboxes(TAGS_GAY, "Gay")
40
 
 
41
  @spaces.GPU
42
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
43
- guidance_scale, num_inference_steps, active_tab, *tag_selections,
44
- progress=gr.Progress(track_tqdm=True)):
45
-
46
- prefix = PROMPT_PREFIXES.get(active_tab, "score_9, score_8_up, score_7_up, source_anime")
47
-
48
- if active_tab == "Prompt Input":
49
- final_prompt = f"{prefix}, {prompt}"
50
- else:
51
- combined_tags = []
52
-
53
- straight_len = len(TAGS_STRAIGHT)
54
- lesbian_len = len(TAGS_LESBIAN)
55
- gay_len = len(TAGS_GAY)
56
-
57
- if active_tab == "Straight":
58
- for (tag_name, tag_dict), selected in zip(TAGS_STRAIGHT.items(), tag_selections[:straight_len]):
59
- combined_tags.extend([tag_dict[tag] for tag in selected])
60
- elif active_tab == "Lesbian":
61
- offset = straight_len
62
- for (tag_name, tag_dict), selected in zip(TAGS_LESBIAN.items(), tag_selections[offset:offset+lesbian_len]):
63
- combined_tags.extend([tag_dict[tag] for tag in selected])
64
- elif active_tab == "Gay":
65
- offset = straight_len + lesbian_len
66
- for (tag_name, tag_dict), selected in zip(TAGS_GAY.items(), tag_selections[offset:offset+gay_len]):
67
- combined_tags.extend([tag_dict[tag] for tag in selected])
68
-
69
- tag_string = ", ".join(combined_tags)
70
- final_prompt = f"{prefix} {tag_string}"
71
-
72
- negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
73
- full_negative_prompt = f"{negative_base}, {negative_prompt}"
74
-
75
- if randomize_seed:
76
- seed = random.randint(0, MAX_SEED)
77
-
 
 
78
  generator = torch.Generator().manual_seed(seed)
79
 
80
  image = pipe(
81
  prompt=final_prompt,
82
- negative_prompt=full_negative_prompt,
83
- guidance_scale=guidance_scale,
84
- num_inference_steps=num_inference_steps,
85
  width=width,
86
  height=height,
87
  generator=generator
88
  ).images[0]
89
 
90
- return image, seed, f"Prompt used: {final_prompt}\nNegative prompt used: {full_negative_prompt}"
91
-
92
-
93
- css = """
94
- #col-container {
95
- margin: 0 auto;
96
- max-width: 1280px;
97
- }
98
-
99
- #left-column {
100
- width: 50%;
101
- display: inline-block;
102
- padding: 20px;
103
- vertical-align: top;
104
- }
105
-
106
- #right-column {
107
- width: 50%;
108
- display: inline-block;
109
- vertical-align: top;
110
- padding: 20px;
111
- margin-top: 53px;
112
- }
113
-
114
- #run-button {
115
- width: 100%;
116
- margin-top: 10px;
117
- }
118
- """
119
-
120
- with gr.Blocks(css=css) as demo:
121
- with gr.Row():
122
- with gr.Column(elem_id="left-column"):
123
- gr.Markdown("# Rainbow Media X")
124
-
125
- result = gr.Image(label="Result", show_label=False)
126
- prompt_info = gr.Textbox(label="Prompts Used", lines=3, interactive=False)
127
-
128
- with gr.Accordion("Advanced Settings", open=False):
129
- negative_prompt = gr.Textbox(label="Negative prompt", placeholder="Enter negative prompt")
130
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
131
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
132
-
133
- with gr.Row():
134
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
135
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
136
-
137
- with gr.Row():
138
- guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=10, step=0.1, value=7)
139
- num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=35)
140
-
141
- run_button = gr.Button("Run", elem_id="run-button")
142
-
143
- with gr.Column(elem_id="right-column"):
144
- active_tab = gr.State("Prompt Input")
145
-
146
- with gr.Tabs() as tabs:
147
- with gr.TabItem("Prompt Input") as prompt_tab:
148
- prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt")
149
- prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)
150
-
151
- with gr.TabItem("Straight") as straight_tab:
152
- for cb in straight_checkboxes:
153
- cb.render()
154
- straight_tab.select(lambda: "Straight", outputs=active_tab)
155
-
156
- with gr.TabItem("Lesbian") as lesbian_tab:
157
- for cb in lesbian_checkboxes:
158
- cb.render()
159
- lesbian_tab.select(lambda: "Lesbian", outputs=active_tab)
160
-
161
- with gr.TabItem("Gay") as gay_tab:
162
- for cb in gay_checkboxes:
163
- cb.render()
164
- gay_tab.select(lambda: "Gay", outputs=active_tab)
165
-
166
- run_button.click(
167
- fn=infer,
168
- inputs=[
169
- prompt, negative_prompt, seed, randomize_seed,
170
- width, height, guidance_scale, num_inference_steps,
171
- active_tab,
172
- *straight_checkboxes,
173
- *lesbian_checkboxes,
174
- *gay_checkboxes
175
- ],
176
- outputs=[result, seed, prompt_info]
177
- )
178
-
179
- demo.queue().launch()
 
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
+ import requests
 
 
 
 
 
 
 
 
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
11
 
 
 
 
12
  model_repo_id = "John6666/wai-ani-nsfw-ponyxl-v140-sdxl"
 
 
13
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
+ MAX_IMAGE_SIZE = 1536
 
 
 
 
 
 
 
 
17
 
18
+ # Translation function
19
  @spaces.GPU
20
+ def translate_albanian_to_english(text):
21
+ if not text.strip():
22
+ return ""
23
+ for attempt in range(2):
24
+ try:
25
+ response = requests.post(
26
+ "https://hal1993-mdftranslation1234567890abcdef1234567890-fc073a6.hf.space/v1/translate",
27
+ json={"from_language": "sq", "to_language": "en", "input_text": text},
28
+ headers={"accept": "application/json", "Content-Type": "application/json"},
29
+ timeout=5
30
+ )
31
+ response.raise_for_status()
32
+ translated = response.json().get("translate", "")
33
+ return translated
34
+ except Exception as e:
35
+ if attempt == 1:
36
+ raise gr.Error(f"Përkthimi dështoi: {str(e)}")
37
+ raise gr.Error("Përkthimi dështoi. Ju lutem provoni përsëri.")
38
+
39
+ # Aspect ratio function
40
+ def update_aspect_ratio(ratio):
41
+ if ratio == "1:1":
42
+ return 1024, 1024
43
+ elif ratio == "9:16":
44
+ return 576, 1024
45
+ elif ratio == "16:9":
46
+ return 1024, 576
47
+ return 1024, 1024
48
+
49
+ @spaces.GPU(duration=120)
50
+ def infer(prompt, width, height, progress=gr.Progress(track_tqdm=True)):
51
+ # Translate prompt
52
+ final_prompt = translate_albanian_to_english(prompt.strip()) if prompt.strip() else ""
53
+ final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {final_prompt}"
54
+
55
+ negative_prompt = "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn, (deformed | distorted | disfigured:1.3), bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers:1.4, disconnected limbs, blurry, amputation"
56
+
57
+ seed = random.randint(0, MAX_SEED)
58
  generator = torch.Generator().manual_seed(seed)
59
 
60
  image = pipe(
61
  prompt=final_prompt,
62
+ negative_prompt=negative_prompt,
63
+ guidance_scale=7,
64
+ num_inference_steps=60,
65
  width=width,
66
  height=height,
67
  generator=generator
68
  ).images[0]
69
 
70
+ return image
71
+
72
+ def create_demo():
73
+ with gr.Blocks() as demo:
74
+ # CSS for 320px gap and download button scaling
75
+ gr.HTML("""
76
+ <style>
77
+ body::before {
78
+ content: "";
79
+ display: block;
80
+ height: 320px;
81
+ background-color: var(--body-background-fill);
82
+ }
83
+ button[aria-label="Fullscreen"], button[aria-label="Fullscreen"]:hover {
84
+ display: none !important;
85
+ visibility: hidden !important;
86
+ opacity: 0 !important;
87
+ pointer-events: none !important;
88
+ }
89
+ button[aria-label="Share"], button[aria-label="Share"]:hover {
90
+ display: none !important;
91
+ }
92
+ button[aria-label="Download"] {
93
+ transform: scale(3);
94
+ transform-origin: top right;
95
+ margin: 0 !important;
96
+ padding: 6px !important;
97
+ }
98
+ </style>
99
+ """)
100
+
101
+ gr.Markdown("# Krijo Media Rainbow")
102
+ gr.Markdown("Gjenero imazhe të reja nga përshkrimi yt me fuqinë e inteligjencës artificiale.")
103
+
104
+ with gr.Column():
105
+ prompt = gr.Textbox(
106
+ label="Përshkrimi",
107
+ placeholder="Shkruani përshkrimin këtu",
108
+ lines=3
109
+ )
110
+ aspect_ratio = gr.Radio(
111
+ label="Raporti i fotos",
112
+ choices=["9:16", "1:1", "16:9"],
113
+ value="1:1"
114
+ )
115
+ generate_button = gr.Button(value="Gjenero")
116
+ result_image = gr.Image(
117
+ label="Imazhi i Gjeneruar",
118
+ interactive=False
119
+ )
120
+
121
+ # Hidden sliders for width and height
122
+ width_slider = gr.Slider(
123
+ value=1024,
124
+ minimum=256,
125
+ maximum=MAX_IMAGE_SIZE,
126
+ step=8,
127
+ visible=False
128
+ )
129
+ height_slider = gr.Slider(
130
+ value=1024,
131
+ minimum=256,
132
+ maximum=MAX_IMAGE_SIZE,
133
+ step=8,
134
+ visible=False
135
+ )
136
+
137
+ # Update hidden sliders based on aspect ratio
138
+ aspect_ratio.change(
139
+ fn=update_aspect_ratio,
140
+ inputs=[aspect_ratio],
141
+ outputs=[width_slider, height_slider],
142
+ queue=False
143
+ )
144
+
145
+ # Bind the generate button
146
+ generate_button.click(
147
+ fn=infer,
148
+ inputs=[prompt, width_slider, height_slider],
149
+ outputs=[result_image],
150
+ show_progress="full"
151
+ )
152
+
153
+ return demo
154
+
155
+ if __name__ == "__main__":
156
+ print(f"Gradio version: {gr.__version__}")
157
+ app = create_demo()
158
+ app.queue(max_size=12).launch(server_name='0.0.0.0')