panelforge commited on
Commit
7322075
·
verified ·
1 Parent(s): 83fcdda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -27
app.py CHANGED
@@ -14,7 +14,6 @@ PROMPT_PREFIXES = {
14
  "Straight": "score_9, score_8_up, score_7_up, source_anime, ",
15
  "Lesbian": "score_9, score_8_up, score_7_up, source_anime, ",
16
  "Gay": "score_9, score_8_up, score_7_up, source_anime, yaoi, "
17
- # Add more tabs if needed
18
  }
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -26,38 +25,52 @@ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
28
 
29
- # Create checkbox groups for each tag set
30
  def create_checkboxes(tag_dict, suffix):
31
  categories = list(tag_dict.keys())
32
  return [gr.CheckboxGroup(choices=list(tag_dict[cat].keys()), label=f"{cat} Tags ({suffix})") for cat in categories], categories
33
 
34
- straight_checkboxes, straight_categories = create_checkboxes(TAGS_STRAIGHT, "Straight")
35
- lesbian_checkboxes, lesbian_categories = create_checkboxes(TAGS_LESBIAN, "Lesbian")
36
- gay_checkboxes, gay_categories = create_checkboxes(TAGS_GAY, "Gay")
37
-
38
 
39
  @spaces.GPU
40
- prefix = PROMPT_PREFIXES.get(active_tab, "score_9, score_8_up, score_7_up, source_anime")
41
-
42
- if active_tab == "Prompt Input":
43
- final_prompt = f"{prefix}, {prompt}"
44
- else:
45
- combined_tags = []
46
-
47
- if active_tab == "Straight":
48
- for (tag_name, tag_dict), selected in zip(TAGS_STRAIGHT.items(), tag_selections[:len(TAGS_STRAIGHT)]):
49
- combined_tags.extend([tag_dict[tag] for tag in selected])
50
- elif active_tab == "Lesbian":
51
- offset = len(TAGS_STRAIGHT)
52
- for (tag_name, tag_dict), selected in zip(TAGS_LESBIAN.items(), tag_selections[offset:offset+len(TAGS_LESBIAN)]):
53
- combined_tags.extend([tag_dict[tag] for tag in selected])
54
- elif active_tab == "Gay":
55
- offset = len(TAGS_STRAIGHT) + len(TAGS_LESBIAN)
56
- for (tag_name, tag_dict), selected in zip(TAGS_GAY.items(), tag_selections[offset:offset+len(TAGS_GAY)]):
57
- combined_tags.extend([tag_dict[tag] for tag in selected])
58
-
59
- tag_string = ", ".join(combined_tags)
60
- final_prompt = f"{prefix}, {tag_string}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  image = pipe(
63
  prompt=final_prompt,
 
14
  "Straight": "score_9, score_8_up, score_7_up, source_anime, ",
15
  "Lesbian": "score_9, score_8_up, score_7_up, source_anime, ",
16
  "Gay": "score_9, score_8_up, score_7_up, source_anime, yaoi, "
 
17
  }
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
25
  MAX_SEED = np.iinfo(np.int32).max
26
  MAX_IMAGE_SIZE = 1024
27
 
 
28
  def create_checkboxes(tag_dict, suffix):
29
  categories = list(tag_dict.keys())
30
  return [gr.CheckboxGroup(choices=list(tag_dict[cat].keys()), label=f"{cat} Tags ({suffix})") for cat in categories], categories
31
 
32
+ straight_checkboxes, _ = create_checkboxes(TAGS_STRAIGHT, "Straight")
33
+ lesbian_checkboxes, _ = create_checkboxes(TAGS_LESBIAN, "Lesbian")
34
+ gay_checkboxes, _ = create_checkboxes(TAGS_GAY, "Gay")
 
35
 
36
  @spaces.GPU
37
+ def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
38
+ guidance_scale, num_inference_steps, active_tab, *tag_selections,
39
+ progress=gr.Progress(track_tqdm=True)):
40
+
41
+ prefix = PROMPT_PREFIXES.get(active_tab, "score_9, score_8_up, score_7_up, source_anime")
42
+
43
+ if active_tab == "Prompt Input":
44
+ final_prompt = f"{prefix}, {prompt}"
45
+ else:
46
+ combined_tags = []
47
+
48
+ straight_len = len(TAGS_STRAIGHT)
49
+ lesbian_len = len(TAGS_LESBIAN)
50
+ gay_len = len(TAGS_GAY)
51
+
52
+ if active_tab == "Straight":
53
+ for (tag_name, tag_dict), selected in zip(TAGS_STRAIGHT.items(), tag_selections[:straight_len]):
54
+ combined_tags.extend([tag_dict[tag] for tag in selected])
55
+ elif active_tab == "Lesbian":
56
+ offset = straight_len
57
+ for (tag_name, tag_dict), selected in zip(TAGS_LESBIAN.items(), tag_selections[offset:offset+lesbian_len]):
58
+ combined_tags.extend([tag_dict[tag] for tag in selected])
59
+ elif active_tab == "Gay":
60
+ offset = straight_len + lesbian_len
61
+ for (tag_name, tag_dict), selected in zip(TAGS_GAY.items(), tag_selections[offset:offset+gay_len]):
62
+ combined_tags.extend([tag_dict[tag] for tag in selected])
63
+
64
+ tag_string = ", ".join(combined_tags)
65
+ final_prompt = f"{prefix} {tag_string}"
66
+
67
+ negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
68
+ full_negative_prompt = f"{negative_base}, {negative_prompt}"
69
+
70
+ if randomize_seed:
71
+ seed = random.randint(0, MAX_SEED)
72
+
73
+ generator = torch.Generator().manual_seed(seed)
74
 
75
  image = pipe(
76
  prompt=final_prompt,