John6666 commited on
Commit
07302a5
Β·
verified Β·
1 Parent(s): 862d3ae

Upload 10 files

Browse files
Files changed (7) hide show
  1. README.md +2 -2
  2. app.py +8 -2
  3. fl2basepromptgen.py +9 -3
  4. fl2flux.py +90 -0
  5. fl2sd3longcap.py +9 -3
  6. promptenhancer.py +22 -5
  7. tagger.py +11 -4
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Prompt Enhancer with WD Tagger & Florence 2 SD3 Captioner
3
  emoji: πŸƒπŸ“¦
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: Prompt Enhancer with WD Tagger & Florence 2 Flux/SD3 Captioner
3
  emoji: πŸƒπŸ“¦
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -16,12 +16,13 @@ from tagger import (
16
  )
17
  from fl2sd3longcap import predict_tags_fl2_sd3
18
  from fl2basepromptgen import predict_tags_fl2_base_prompt_gen
 
19
  from promptenhancer import prompt_enhancer
20
 
21
  def description_ui():
22
  gr.Markdown(
23
  """
24
- ## Prompt Enhancer with WD Tagger & SD3 Long Captioner
25
  (Image =>) Prompt => Upsampled longer prompt
26
  """
27
  )
@@ -33,8 +34,11 @@ def description_ui2():
33
  [Florence-2-SD3-Captioner](https://huggingface.co/spaces/gokaygokay/Florence-2-SD3-Captioner).
34
  - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf),\
35
  gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner),\
 
 
36
  [Lamini-Prompt-Enchance](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance),\
37
  [Lamini-Prompt-Enchance-Long](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long),\
 
38
  MiaoshouAI's [Florence-2-base-PromptGen](https://huggingface.co/MiaoshouAI/Florence-2-base-PromptGen).
39
  """
40
  )
@@ -51,7 +55,7 @@ def main():
51
  input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
52
  recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
53
  keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
54
- image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
55
  generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
56
  with gr.Group():
57
  with gr.Row():
@@ -98,6 +102,8 @@ def main():
98
  predict_tags_fl2_base_prompt_gen,
99
  [input_image, input_general, image_algorithms],
100
  [input_general],
 
 
101
  ).success(
102
  remove_specific_prompt, [input_general, keep_tags], [input_general], queue=False,
103
  ).success(
 
16
  )
17
  from fl2sd3longcap import predict_tags_fl2_sd3
18
  from fl2basepromptgen import predict_tags_fl2_base_prompt_gen
19
+ from fl2flux import predict_tags_fl2_flux
20
  from promptenhancer import prompt_enhancer
21
 
22
  def description_ui():
23
  gr.Markdown(
24
  """
25
+ ## Prompt Enhancer with WD Tagger & Flux/SD3 Captioner
26
  (Image =>) Prompt => Upsampled longer prompt
27
  """
28
  )
 
34
  [Florence-2-SD3-Captioner](https://huggingface.co/spaces/gokaygokay/Florence-2-SD3-Captioner).
35
  - Models: p1atdev's [wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf),\
36
  gokaygokay's [Florence-2-SD3-Captioner](https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner),\
37
+ gokaygokay's [Florence-2-Flux](https://huggingface.co/gokaygokay/Florence-2-Flux),\
38
+ gokaygokay's [Florence-2-Flux-Large](https://huggingface.co/gokaygokay/Florence-2-Flux-Large),\
39
  [Lamini-Prompt-Enchance](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance),\
40
  [Lamini-Prompt-Enchance-Long](https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long),\
41
+ [Flux-Prompt-Enhance](https://huggingface.co/gokaygokay/Flux-Prompt-Enhance),\
42
  MiaoshouAI's [Florence-2-base-PromptGen](https://huggingface.co/MiaoshouAI/Florence-2-base-PromptGen).
43
  """
44
  )
 
55
  input_tag_type = gr.Radio(label="Convert tags to", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru")
56
  recom_prompt = gr.Radio(label="Insert reccomended prompt", choices=["None", "Animagine", "Pony"], value="None", interactive=True)
57
  keep_tags = gr.Radio(label="Remove tags leaving only the following", choices=["body", "dress", "all"], value="all")
58
+ image_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner", "Use Florence-2-base-PromptGen", "Use Florence-2-Flux","Use Florence-2-Flux-Large"], label="Algorithms", value=["Use WD Tagger", "Use Florence-2-SD3-Long-Captioner"])
59
  generate_from_image_btn = gr.Button(value="GENERATE TAGS FROM IMAGE", size="lg", variant="primary")
60
  with gr.Group():
61
  with gr.Row():
 
102
  predict_tags_fl2_base_prompt_gen,
103
  [input_image, input_general, image_algorithms],
104
  [input_general],
105
+ ).success(
106
+ predict_tags_fl2_flux, [input_image, input_general, image_algorithms], [input_general],
107
  ).success(
108
  remove_specific_prompt, [input_general, keep_tags], [input_general], queue=False,
109
  ).success(
fl2basepromptgen.py CHANGED
@@ -7,11 +7,15 @@ import subprocess
7
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).to(device).eval()
11
- fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
12
 
 
 
 
 
 
 
13
 
14
- @spaces.GPU
15
  def fl_run(image):
16
  task_prompt = "<GENERATE_PROMPT>"
17
  prompt = task_prompt + "Describe this image in great detail."
@@ -20,6 +24,7 @@ def fl_run(image):
20
  if image.mode != "RGB":
21
  image = image.convert("RGB")
22
 
 
23
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
24
  generated_ids = fl_model.generate(
25
  input_ids=inputs["input_ids"],
@@ -28,6 +33,7 @@ def fl_run(image):
28
  do_sample=False,
29
  num_beams=3
30
  )
 
31
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
32
  parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
33
  return parsed_answer["<GENERATE_PROMPT>Describe this image in great detail."]
 
7
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
10
 
11
+ try:
12
+ fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True).to("cpu").eval()
13
+ fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
14
+ except Exception as e:
15
+ print(e)
16
+ fl_model = fl_processor = None
17
 
18
+ @spaces.GPU(duration=30)
19
  def fl_run(image):
20
  task_prompt = "<GENERATE_PROMPT>"
21
  prompt = task_prompt + "Describe this image in great detail."
 
24
  if image.mode != "RGB":
25
  image = image.convert("RGB")
26
 
27
+ fl_model.to(device)
28
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
29
  generated_ids = fl_model.generate(
30
  input_ids=inputs["input_ids"],
 
33
  do_sample=False,
34
  num_beams=3
35
  )
36
+ fl_model.to("cpu")
37
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
38
  parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
39
  return parsed_answer["<GENERATE_PROMPT>Describe this image in great detail."]
fl2flux.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor, AutoModelForCausalLM
2
+ import spaces
3
+ import re
4
+ from PIL import Image
5
+ import torch
6
+
7
+ import subprocess
8
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ try:
13
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True).to("cpu").eval()
14
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True)
15
+ fl_model_large = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True).to("cpu").eval()
16
+ fl_processor_large = AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True)
17
+ except Exception as e:
18
+ fl_model = fl_processor = fl_model_large = fl_processor_large = None
19
+ print(e)
20
+
21
+ def fl_modify_caption(caption: str) -> str:
22
+ """
23
+ Removes specific prefixes from captions if present, otherwise returns the original caption.
24
+ Args:
25
+ caption (str): A string containing a caption.
26
+ Returns:
27
+ str: The caption with the prefix removed if it was present, or the original caption.
28
+ """
29
+ # Define the prefixes to remove
30
+ prefix_substrings = [
31
+ ('captured from ', ''),
32
+ ('captured at ', '')
33
+ ]
34
+
35
+ # Create a regex pattern to match any of the prefixes
36
+ pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
37
+ replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
38
+
39
+ # Function to replace matched prefix with its corresponding replacement
40
+ def replace_fn(match):
41
+ return replacers[match.group(0).lower()]
42
+
43
+ # Apply the regex to the caption
44
+ modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
45
+
46
+ # If the caption was modified, return the modified version; otherwise, return the original
47
+ return modified_caption if modified_caption != caption else caption
48
+
49
+
50
+ @spaces.GPU(duration=30)
51
+ def fl_run_example(image, algo):
52
+ task_prompt = "<DESCRIPTION>"
53
+ prompt = task_prompt + "Describe this image in great detail."
54
+ #prompt = task_prompt
55
+
56
+ # Ensure the image is in RGB mode
57
+ if image.mode != "RGB": image = image.convert("RGB")
58
+
59
+ if algo == "Use Florence-2-Flux-Large":
60
+ model = fl_model_large
61
+ processor = fl_processor_large
62
+ else:
63
+ model = fl_model
64
+ processor = fl_processor
65
+ model.to(device)
66
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
67
+ generated_ids = model.generate(
68
+ input_ids=inputs["input_ids"],
69
+ pixel_values=inputs["pixel_values"],
70
+ max_new_tokens=1024,
71
+ num_beams=3
72
+ )
73
+ model.to("cpu")
74
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
75
+ parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
76
+ return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
77
+
78
+
79
+ def predict_tags_fl2_flux(image: Image.Image, input_tags: str, algo: list[str]):
80
+ def to_list(s):
81
+ return [x.strip() for x in s.split(",") if not s == ""]
82
+
83
+ def list_uniq(l):
84
+ return sorted(set(l), key=l.index)
85
+
86
+ if "Use Florence-2-Flux" not in algo and "Use Florence-2-Flux-Large" not in algo:
87
+ return input_tags
88
+ tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image, algo) + ", "))
89
+ tag_list.remove("")
90
+ return ", ".join(tag_list)
fl2sd3longcap.py CHANGED
@@ -8,9 +8,13 @@ import subprocess
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
12
- fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
13
 
 
 
 
 
 
 
14
 
15
  def fl_modify_caption(caption: str) -> str:
16
  """
@@ -41,7 +45,7 @@ def fl_modify_caption(caption: str) -> str:
41
  return modified_caption if modified_caption != caption else caption
42
 
43
 
44
- @spaces.GPU
45
  def fl_run_example(image):
46
  task_prompt = "<DESCRIPTION>"
47
  prompt = task_prompt + "Describe this image in great detail."
@@ -50,6 +54,7 @@ def fl_run_example(image):
50
  if image.mode != "RGB":
51
  image = image.convert("RGB")
52
 
 
53
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
54
  generated_ids = fl_model.generate(
55
  input_ids=inputs["input_ids"],
@@ -57,6 +62,7 @@ def fl_run_example(image):
57
  max_new_tokens=1024,
58
  num_beams=3
59
  )
 
60
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
61
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
62
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
 
8
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
11
 
12
+ try:
13
+ fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to("cpu").eval()
14
+ fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
15
+ except Exception as e:
16
+ print(e)
17
+ fl_model = fl_processor = None
18
 
19
  def fl_modify_caption(caption: str) -> str:
20
  """
 
45
  return modified_caption if modified_caption != caption else caption
46
 
47
 
48
+ @spaces.GPU(duration=30)
49
  def fl_run_example(image):
50
  task_prompt = "<DESCRIPTION>"
51
  prompt = task_prompt + "Describe this image in great detail."
 
54
  if image.mode != "RGB":
55
  image = image.convert("RGB")
56
 
57
+ fl_model.to(device)
58
  inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
59
  generated_ids = fl_model.generate(
60
  input_ids=inputs["input_ids"],
 
62
  max_new_tokens=1024,
63
  num_beams=3
64
  )
65
+ fl_model.to("cpu")
66
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
67
  parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
68
  return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
promptenhancer.py CHANGED
@@ -1,22 +1,32 @@
1
  import spaces
2
  import gradio as gr
3
- from transformers import pipeline
4
  import re
5
  import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  def load_models():
10
- enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
11
- enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
12
- return enhancer_medium, enhancer_long
 
 
 
 
 
 
 
 
13
 
14
- enhancer_medium, enhancer_long = load_models()
15
 
16
  @spaces.GPU
17
  def enhance_prompt(input_prompt, model_choice):
18
  if model_choice == "Medium":
 
19
  result = enhancer_medium("Enhance the description: " + input_prompt)
 
20
  enhanced_text = result[0]['summary_text']
21
 
22
  pattern = r'^.*?of\s+(.*?(?:\.|$))'
@@ -26,8 +36,15 @@ def enhance_prompt(input_prompt, model_choice):
26
  remaining_text = enhanced_text[match.end():].strip()
27
  modified_sentence = match.group(1).capitalize()
28
  enhanced_text = modified_sentence + ' ' + remaining_text
 
 
 
 
 
29
  else: # Long
 
30
  result = enhancer_long("Enhance the description: " + input_prompt)
 
31
  enhanced_text = result[0]['summary_text']
32
 
33
  return enhanced_text
 
1
  import spaces
2
  import gradio as gr
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
  import re
5
  import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  def load_models():
10
+ try:
11
+ enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device="cpu")
12
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device="cpu")
13
+ model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
15
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device="cpu")
16
+ enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device="cpu")
17
+ except Exception as e:
18
+ print(e)
19
+ enhancer_medium = enhancer_long = enhancer_flux = None
20
+ return enhancer_medium, enhancer_long, enhancer_flux
21
 
22
+ enhancer_medium, enhancer_long, enhancer_flux = load_models()
23
 
24
  @spaces.GPU
25
  def enhance_prompt(input_prompt, model_choice):
26
  if model_choice == "Medium":
27
+ enhancer_medium.to(device=device)
28
  result = enhancer_medium("Enhance the description: " + input_prompt)
29
+ enhancer_medium.to(device="cpu")
30
  enhanced_text = result[0]['summary_text']
31
 
32
  pattern = r'^.*?of\s+(.*?(?:\.|$))'
 
36
  remaining_text = enhanced_text[match.end():].strip()
37
  modified_sentence = match.group(1).capitalize()
38
  enhanced_text = modified_sentence + ' ' + remaining_text
39
+ elif model_choice == "Flux":
40
+ enhancer_flux.to(device=device)
41
+ result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
42
+ enhancer_flux.to(device="cpu")
43
+ enhanced_text = result[0]['generated_text']
44
  else: # Long
45
+ enhancer_long.to(device=device)
46
  result = enhancer_long("Enhance the description: " + input_prompt)
47
+ enhancer_long.to(device="cpu")
48
  enhanced_text = result[0]['summary_text']
49
 
50
  return enhanced_text
tagger.py CHANGED
@@ -12,10 +12,15 @@ from pathlib import Path
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
- wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
16
- wd_model.to("cuda" if torch.cuda.is_available() else "cpu")
17
- wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
18
 
 
 
 
 
 
 
19
 
20
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
21
  return (
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
506
  return ", ".join(all_tags)
507
 
508
 
509
- @spaces.GPU()
510
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
511
  inputs = wd_processor.preprocess(image, return_tensors="pt")
512
 
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
514
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
515
 
516
  # get probabilities
 
517
  results = {
518
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
519
  }
 
520
  # rating, character, general
521
  rating, character, general = postprocess_results(
522
  results, general_threshold, character_threshold
 
12
  WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
13
  WD_MODEL_NAME = WD_MODEL_NAMES[0]
14
 
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ default_device = device
 
17
 
18
+ try:
19
+ wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
20
+ wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
21
+ except Exception as e:
22
+ print(e)
23
+ wd_model = wd_processor = None
24
 
25
  def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
26
  return (
 
511
  return ", ".join(all_tags)
512
 
513
 
514
+ @spaces.GPU(duration=30)
515
  def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
516
  inputs = wd_processor.preprocess(image, return_tensors="pt")
517
 
 
519
  logits = torch.sigmoid(outputs.logits[0]) # take the first logits
520
 
521
  # get probabilities
522
+ if device != default_device: wd_model.to(device=device)
523
  results = {
524
  wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
525
  }
526
+ if device != default_device: wd_model.to(device=default_device)
527
  # rating, character, general
528
  rating, character, general = postprocess_results(
529
  results, general_threshold, character_threshold