Spaces:
Paused
Paused
Upload fl2basepromptgen.py
Browse files- fl2basepromptgen.py +3 -34
fl2basepromptgen.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 2 |
import spaces
|
| 3 |
-
import re
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
#import subprocess
|
|
@@ -10,37 +9,8 @@ fl_model = AutoModelForCausalLM.from_pretrained('MiaoshouAI/Florence-2-base-Prom
|
|
| 10 |
fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
|
| 11 |
|
| 12 |
|
| 13 |
-
def fl_modify_caption(caption: str) -> str:
|
| 14 |
-
"""
|
| 15 |
-
Removes specific prefixes from captions if present, otherwise returns the original caption.
|
| 16 |
-
Args:
|
| 17 |
-
caption (str): A string containing a caption.
|
| 18 |
-
Returns:
|
| 19 |
-
str: The caption with the prefix removed if it was present, or the original caption.
|
| 20 |
-
"""
|
| 21 |
-
# Define the prefixes to remove
|
| 22 |
-
prefix_substrings = [
|
| 23 |
-
('captured from ', ''),
|
| 24 |
-
('captured at ', '')
|
| 25 |
-
]
|
| 26 |
-
|
| 27 |
-
# Create a regex pattern to match any of the prefixes
|
| 28 |
-
pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
|
| 29 |
-
replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
|
| 30 |
-
|
| 31 |
-
# Function to replace matched prefix with its corresponding replacement
|
| 32 |
-
def replace_fn(match):
|
| 33 |
-
return replacers[match.group(0).lower()]
|
| 34 |
-
|
| 35 |
-
# Apply the regex to the caption
|
| 36 |
-
modified_caption = re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
|
| 37 |
-
|
| 38 |
-
# If the caption was modified, return the modified version; otherwise, return the original
|
| 39 |
-
return modified_caption if modified_caption != caption else caption
|
| 40 |
-
|
| 41 |
-
|
| 42 |
@spaces.GPU
|
| 43 |
-
def
|
| 44 |
task_prompt = "<GENERATE_PROMPT>"
|
| 45 |
prompt = task_prompt + "Describe this image in great detail."
|
| 46 |
|
|
@@ -58,8 +28,7 @@ def fl_run_example(image):
|
|
| 58 |
)
|
| 59 |
generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 60 |
parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
|
| 61 |
-
|
| 62 |
-
return fl_modify_caption(parsed_answer["<PROMPT>"])
|
| 63 |
|
| 64 |
|
| 65 |
def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo: list[str]):
|
|
@@ -71,6 +40,6 @@ def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo:
|
|
| 71 |
|
| 72 |
if not "Use Florence-2-base-PromptGen" in algo:
|
| 73 |
return input_tags
|
| 74 |
-
tag_list = list_uniq(to_list(input_tags) + to_list(
|
| 75 |
tag_list.remove("")
|
| 76 |
return ", ".join(tag_list)
|
|
|
|
| 1 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 2 |
import spaces
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
|
| 5 |
#import subprocess
|
|
|
|
| 9 |
fl_processor = AutoProcessor.from_pretrained('MiaoshouAI/Florence-2-base-PromptGen', trust_remote_code=True)
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
@spaces.GPU
|
| 13 |
+
def fl_run(image):
|
| 14 |
task_prompt = "<GENERATE_PROMPT>"
|
| 15 |
prompt = task_prompt + "Describe this image in great detail."
|
| 16 |
|
|
|
|
| 28 |
)
|
| 29 |
generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 30 |
parsed_answer = fl_processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
|
| 31 |
+
return parsed_answer["<GENERATE_PROMPT>Describe this image in great detail."]
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def predict_tags_fl2_base_prompt_gen(image: Image.Image, input_tags: str, algo: list[str]):
|
|
|
|
| 40 |
|
| 41 |
if not "Use Florence-2-base-PromptGen" in algo:
|
| 42 |
return input_tags
|
| 43 |
+
tag_list = list_uniq(to_list(input_tags) + to_list(fl_run(image) + ", "))
|
| 44 |
tag_list.remove("")
|
| 45 |
return ", ".join(tag_list)
|