John6666 commited on
Commit
e253cd4
·
verified ·
1 Parent(s): 36a9d8c

Upload fl2basepromptgen.py

Browse files
Files changed (1) hide show
  1. fl2basepromptgen.py +3 -34
fl2basepromptgen.py CHANGED
@@ -1,6 +1,5 @@
1
  from transformers import AutoProcessor, AutoModelForCausalLM
2
  import spaces
3
- import re
4
  from PIL import Image
5
 
6
  #import subprocess
@@ -10,37 +9,8 @@ fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-Prom
10
  fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
11
 
12
 
13
- def fl_modify_caption(caption: str) -> str:
14
- """
15
- Removes specific prefixes from captions if present, otherwise returns the original caption.
16
- Args:
17
- caption (str): A string containing a caption.
18
- Returns:
19
- str: The caption with the prefix removed if it was present, or the original caption.
20
- """
21
- # Define the prefixes to remove
22
- prefix_substrings = [
23
- ('captured from ', ''),
24
- ('captured at ', '')
25
- ]
26
-
27
- # Create a regex pattern to match any of the prefixes
28
- pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
29
- replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
30
-
31
- # Function to replace matched prefix with its corresponding replacement
32
- def replace_fn(match):
33
- return replacers[match.group(0).lower()]
34
-
35
- # Apply the regex to the caption
36
- modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
37
-
38
- # If the caption was modified, return the modified version; otherwise, return the original
39
- return modified_caption if modified_caption != caption else caption
40
-
41
-
42
  @spaces.GPU
43
- def fl_run_example(image):
44
  task_prompt = "<GENERATE_PROMPT>"
45
  prompt = task_prompt + "Describe this image in great detail."
46
 
@@ -58,8 +28,7 @@ def fl_run_example(image):
58
  )
59
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
60
  parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
61
- print(parsed_answer)
62
- return fl_modify_caption(parsed_answer["<PROMPT>"])
63
 
64
 
65
  def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo: list[str]):
@@ -71,6 +40,6 @@ def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo:
71
 
72
  if not "Use Florence-2-base-PromptGen" in algo:
73
  return input_tags
74
- tag_list = list_uniq(to_list(input_tags) + to_list(fl_run_example(image) + ", "))
75
  tag_list.remove("")
76
  return ", ".join(tag_list)
 
1
  from transformers import AutoProcessor, AutoModelForCausalLM
2
  import spaces
 
3
  from PIL import Image
4
 
5
  #import subprocess
 
9
  fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @spaces.GPU
13
+ def fl_run(image):
14
  task_prompt = "<GENERATE_PROMPT>"
15
  prompt = task_prompt + "Describe this image in great detail."
16
 
 
28
  )
29
  generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
30
  parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
31
+ return parsed_answer["<GENERATE_PROMPT>Describe this image in great detail."]
 
32
 
33
 
34
  def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo: list[str]):
 
40
 
41
  if not "Use Florence-2-base-PromptGen" in algo:
42
  return input_tags
43
+ tag_list = list_uniq(to_list(input_tags) + to_list(fl_run(image) + ", "))
44
  tag_list.remove("")
45
  return ", ".join(tag_list)