Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
"""Florence-2 processor: image preprocessing, tokenization, task prompts, and post-processing."""
import re
import numpy as np
import torch
def preprocess_image(image, size=768, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Resize and normalize an image tensor for the vision encoder.

    Args:
        image: float tensor with values in [0, 1], shaped ``[B, C, H, W]`` or
            ``[C, H, W]`` (a batch dimension is added for 3-D input).
        size: output spatial size. An int gives a square ``size x size``
            result; a ``(height, width)`` tuple/list gives a rectangular one.
        mean, std: per-channel normalization statistics. Their length must
            equal the channel count ``C`` (three by default).

    Returns:
        Tensor ``[B, C, H_out, W_out]`` equal to ``(resized - mean) / std``.
    """
    if image.ndim == 3:
        image = image.unsqueeze(0)
    out_size = tuple(size) if isinstance(size, (tuple, list)) else (size, size)
    # Bicubic interpolation can overshoot the input range; clamp back to [0, 1]
    # before normalizing.
    image = torch.nn.functional.interpolate(
        image, size=out_size, mode='bicubic', align_corners=False,
    ).clamp(0, 1)
    # view(1, -1, 1, 1) broadcasts the statistics over batch and spatial dims
    # for any channel count rather than assuming exactly three channels.
    mean_t = torch.tensor(mean, device=image.device, dtype=image.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=image.device, dtype=image.dtype).view(1, -1, 1, 1)
    return (image - mean_t) / std_t
class BoxQuantizer:
    """Map quantized box bin indices back to pixel coordinates.

    ``bins`` is ``(bins_w, bins_h)``: the number of bins the image width and
    height are each divided into. ``mode`` is the quantization rule the bins
    were produced with: ``'floor'`` (bin index, mapped back to the bin centre)
    or ``'round'`` (nearest bin boundary, mapped back without offset).
    """

    def __init__(self, mode, bins):
        self.mode = mode
        self.bins = bins

    def dequantize(self, boxes, size):
        """Convert ``[..., 4]`` bin indices (xmin, ymin, xmax, ymax) to pixels.

        Args:
            boxes: tensor whose last dimension holds xmin, ymin, xmax, ymax bins.
            size: ``(width, height)`` of the target image in pixels.

        Returns:
            Float tensor with the same shape holding pixel coordinates.

        Raises:
            ValueError: if ``self.mode`` is neither ``'floor'`` nor ``'round'``.
        """
        bins_w, bins_h = self.bins
        size_w, size_h = size
        if self.mode == 'floor':
            # A floor-quantized value identifies a bin; use the bin's centre.
            offset = 0.5
        elif self.mode == 'round':
            offset = 0.0
        else:
            raise ValueError(f'Unsupported quantization mode: {self.mode!r}')
        xmin, ymin, xmax, ymax = boxes.split(1, dim=-1)
        return torch.cat((
            (xmin + offset) * size_w / bins_w, (ymin + offset) * size_h / bins_h,
            (xmax + offset) * size_w / bins_w, (ymax + offset) * size_h / bins_h,
        ), dim=-1)
class CoordinatesQuantizer:
    """Map quantized point bin indices back to pixel coordinates.

    ``bins`` is ``(bins_w, bins_h)``. ``mode`` is the quantization rule the
    bins were produced with: ``'floor'`` (mapped back to the bin centre) or
    ``'round'`` (mapped back without offset).
    """

    def __init__(self, mode, bins):
        self.mode = mode
        self.bins = bins

    def dequantize(self, coordinates, size):
        """Convert ``[..., 2]`` (x, y) bin indices to pixel coordinates.

        Args:
            coordinates: tensor whose last dimension holds x and y bins.
            size: ``(width, height)`` of the target image in pixels.

        Returns:
            Float tensor with the same shape holding pixel coordinates.

        Raises:
            ValueError: if ``self.mode`` is neither ``'floor'`` nor ``'round'``.
        """
        bins_w, bins_h = self.bins
        size_w, size_h = size
        if self.mode == 'floor':
            # A floor-quantized value identifies a bin; use the bin's centre.
            offset = 0.5
        elif self.mode == 'round':
            offset = 0.0
        else:
            raise ValueError(f'Unsupported quantization mode: {self.mode!r}')
        x, y = coordinates.split(1, dim=-1)
        return torch.cat(((x + offset) * size_w / bins_w, (y + offset) * size_h / bins_h), dim=-1)
class PostProcessor:
    """Regex-based parsing of Florence-2 text outputs into structured results.

    Each task named in ``config['PARSE_TASKS']`` has a parser that turns the
    decoded text into a list of instances (boxes, quad boxes or polygons with
    labels). ``__call__`` dispatches to the parsers for the requested tasks.
    """

    # Phrase text up to (not including) the first structural tag or <loc_ token.
    _PHRASE_PATTERN = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
    # Four consecutive location tokens, i.e. one box (xmin, ymin, xmax, ymax).
    _BOX_PATTERN = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
    # A run of non-tag text followed by at least four location tokens.
    _PHRASE_WITH_BOXES_PATTERN = r"([^<]+(?:<loc_\d+>){4,})"
    # At least four location tokens with no preceding phrase.
    _BOXES_ONLY_PATTERN = r"(?:(?:<loc_\d+>){4,})"

    def __init__(self, tokenizer=None):
        config = {
            'NUM_BBOX_HEIGHT_BINS': 1000, 'NUM_BBOX_WIDTH_BINS': 1000, 'BOX_QUANTIZATION_MODE': 'floor',
            'COORDINATES_HEIGHT_BINS': 1000, 'COORDINATES_WIDTH_BINS': 1000, 'COORDINATES_QUANTIZATION_MODE': 'floor',
            'PARSE_TASKS': [
                # Single backslash: in a raw string '\\d' would match a literal
                # backslash followed by 'd' and never match real output.
                {'TASK_NAME': 'od', 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'},
                {'TASK_NAME': 'ocr', 'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>', 'AREA_THRESHOLD': 0.00},
                {'TASK_NAME': 'phrase_grounding', 'FILTER_BY_BLACK_LIST': True},
                {'TASK_NAME': 'pure_text'}, {'TASK_NAME': 'description_with_bboxes'},
                {'TASK_NAME': 'description_with_polygons'}, {'TASK_NAME': 'polygons'},
                {'TASK_NAME': 'bboxes'}, {'TASK_NAME': 'description_with_bboxes_or_polygons'},
            ],
        }
        self.config = config
        self.parse_tasks = [t['TASK_NAME'] for t in config['PARSE_TASKS']]
        self.parse_tasks_configs = {t['TASK_NAME']: t for t in config['PARSE_TASKS']}
        self.tokenizer = tokenizer
        if tokenizer is not None:
            self.all_special_tokens = set(tokenizer.all_special_tokens)
        # Build the quantizers from the config so mode and bin counts are
        # defined in a single place.
        self.box_quantizer = BoxQuantizer(
            config['BOX_QUANTIZATION_MODE'],
            (config['NUM_BBOX_WIDTH_BINS'], config['NUM_BBOX_HEIGHT_BINS']),
        )
        self.coordinates_quantizer = CoordinatesQuantizer(
            config['COORDINATES_QUANTIZATION_MODE'],
            (config['COORDINATES_WIDTH_BINS'], config['COORDINATES_HEIGHT_BINS']),
        )
        self.black_list_of_phrase_grounding = set()
        if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding'].get('FILTER_BY_BLACK_LIST'):
            # Phrases matching these exactly (after stripping) are dropped from
            # phrase-grounding results.
            self.black_list_of_phrase_grounding = {
                'it', 'I', 'me', 'mine', 'you', 'your', 'yours', 'he', 'him', 'his',
                'she', 'her', 'hers', 'they', 'them', 'their', 'theirs', 'one', 'oneself',
                'we', 'us', 'our', 'ours', 'its',
                'myself', 'yourself', 'himself', 'herself',
                'itself', 'ourselves', 'yourselves', 'themselves', 'this', 'that',
                'these', 'those', 'who', 'whom', 'whose', 'which', 'what',
                'all', 'another', 'any', 'anybody', 'anyone', 'anything',
                'each', 'everybody', 'everyone', 'everything',
                'few', 'many', 'nobody', 'none', 'several',
                'some', 'somebody', 'someone', 'something',
                'each other', 'one another',
                'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
                'other objects', 'lots', 'a set',
            }

    @staticmethod
    def _strip_sequence_tokens(text):
        """Remove ``<s>``, ``</s>`` and ``<pad>`` markers from ``text``."""
        return text.replace('<s>', '').replace('</s>', '').replace('<pad>', '')

    def parse_od_from_text_and_spans(self, text, pattern, image_size, phrase_centric=False):
        """Parse ``name + 4 boxes bins`` matches of ``pattern`` into box instances.

        With ``phrase_centric=True`` group 1 is the category name and groups
        2-5 are the box bins; otherwise groups 1-4 are bins and group 5 the name.
        Returns a list of ``{'bbox': [x0, y0, x1, y1], 'cat_name': str}``.
        """
        instances = []
        for m in re.finditer(pattern, text):
            if phrase_centric:
                bbox_bins = [int(m.group(j)) for j in range(2, 6)]
                cat_name = m.group(1).lower().strip()
            else:
                bbox_bins = [int(m.group(j)) for j in range(1, 5)]
                cat_name = m.group(5).lower().strip()
            instances.append({
                'bbox': self.box_quantizer.dequantize(torch.tensor(bbox_bins), image_size).tolist(),
                'cat_name': cat_name,
            })
        return instances

    def parse_ocr_from_text_and_spans(self, text, pattern, image_size, area_threshold=-1.0):
        """Parse ``text + 8 location bins`` matches into quad-box instances.

        Args:
            text: decoded model output.
            pattern: regex whose group 1 is the text and groups 2-9 the bins
                of four (x, y) corners.
            image_size: ``(width, height)`` in pixels.
            area_threshold: when > 0, quads whose area is below
                ``area_threshold * width * height`` are dropped.

        Returns:
            List of ``{'quad_box': [x0, y0, ..., x3, y3], 'text': str}``.
        """
        text = text.replace('<s>', '')
        parsed = re.findall(pattern, text)
        instances = []
        image_width, image_height = image_size
        for ocr_line in parsed:
            ocr_content = ocr_line[0]
            quad_box = [int(i) for i in ocr_line[1:]]
            quad_box = self.coordinates_quantizer.dequantize(
                torch.tensor(np.array(quad_box).reshape(-1, 2)), image_size,
            ).reshape(-1).tolist()
            if area_threshold > 0:
                x_coords, y_coords = quad_box[0::2], quad_box[1::2]
                n = len(x_coords)
                # Shoelace formula over the closed polygon; the (i + 1) % n
                # index supplies the closing edge from the last vertex back
                # to the first.
                area = 0.5 * abs(sum(
                    x_coords[i] * y_coords[(i + 1) % n] - x_coords[(i + 1) % n] * y_coords[i]
                    for i in range(n)
                ))
                if area < (image_width * image_height) * area_threshold:
                    continue
            instances.append({'quad_box': quad_box, 'text': ocr_content})
        return instances

    def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
        """Parse ``phrase<loc_..>x4[...]`` spans into per-phrase box lists.

        ``pattern`` is accepted for interface compatibility but not used; the
        span pattern is fixed. Blacklisted phrases are skipped.

        Returns:
            List of ``{'bbox': [[x0, y0, x1, y1], ...], 'cat_name': str}``.
        """
        text = self._strip_sequence_tokens(text)
        instances = []
        for span in re.findall(self._PHRASE_WITH_BOXES_PATTERN, text):
            stripped = span.replace('<ground>', '', 1).replace('<obj>', '', 1)
            if stripped == '':
                continue
            phrase_match = re.search(self._PHRASE_PATTERN, stripped)
            if phrase_match is None:
                continue
            boxes_parsed = list(re.finditer(self._BOX_PATTERN, span))
            if not boxes_parsed:
                continue
            phrase = phrase_match.group().strip()
            if phrase in self.black_list_of_phrase_grounding:
                continue
            bbox_bins = [[int(b.group(j)) for j in range(1, 5)] for b in boxes_parsed]
            # Drop non-ASCII characters from the label.
            phrase = phrase.encode('ascii', errors='ignore').decode('ascii')
            instances.append({
                'bbox': self.box_quantizer.dequantize(torch.tensor(bbox_bins), image_size).tolist(),
                'cat_name': phrase,
            })
        return instances

    def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
        """Parse spans of ``[phrase]<loc_..>x4[...]`` into one instance per box.

        ``pattern`` is accepted for interface compatibility but not used. With
        ``allow_empty_phrase=True`` runs of location tokens without a preceding
        phrase are accepted and get an empty ``cat_name``.

        Returns:
            List of ``{'bbox': [x0, y0, x1, y1], 'cat_name': str}``.
        """
        text = self._strip_sequence_tokens(text)
        # Plain raw strings: a literal '{{4,}}' here would be parsed by `re`
        # as literal braces and never match.
        span_pattern = self._BOXES_ONLY_PATTERN if allow_empty_phrase else self._PHRASE_WITH_BOXES_PATTERN
        instances = []
        for span in re.findall(span_pattern, text):
            stripped = span.replace('<ground>', '', 1).replace('<obj>', '', 1)
            if stripped == '' and not allow_empty_phrase:
                continue
            phrase_match = re.search(self._PHRASE_PATTERN, stripped)
            if phrase_match is None:
                continue
            phrase = phrase_match.group().strip()
            boxes_parsed = list(re.finditer(self._BOX_PATTERN, span))
            if not boxes_parsed:
                continue
            bbox_bins = [[int(b.group(j)) for j in range(1, 5)] for b in boxes_parsed]
            bboxes = self.box_quantizer.dequantize(torch.tensor(bbox_bins), image_size).tolist()
            # Drop non-ASCII characters from the label.
            phrase = phrase.encode('ascii', errors='ignore').decode('ascii')
            for bbox in bboxes:
                instances.append({'bbox': bbox, 'cat_name': phrase})
        return instances

    def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
                                                            allow_empty_phrase=False, polygon_sep_token='<sep>',
                                                            polygon_start_token='<poly>', polygon_end_token='</poly>',
                                                            with_box_at_start=False):
        """Parse ``[phrase]`` + polygon location runs into polygon instances.

        Polygons inside a span are separated by ``polygon_sep_token``. When
        the span contains ``polygon_start_token``/``polygon_end_token``, each
        delimited group becomes its own instance. With ``with_box_at_start``
        the first four bins of the first polygon are taken as a bounding box.
        ``pattern`` is accepted for interface compatibility but not used.

        Returns:
            List of ``{'cat_name': str, 'polygons': [[x0, y0, ...], ...]}``,
            plus ``'bbox'`` when ``with_box_at_start`` produced one.
        """
        text = self._strip_sequence_tokens(text)
        sep = re.escape(polygon_sep_token)
        start = re.escape(polygon_start_token)
        end = re.escape(polygon_end_token)
        if allow_empty_phrase:
            span_pattern = rf"(?:(?:<loc_\d+>|{sep}|{start}|{end}){{4,}})"
        else:
            span_pattern = rf"([^<]+(?:<loc_\d+>|{sep}|{start}|{end}){{4,}})"
        phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
        # One polygon: a run of location tokens ending at a separator or end.
        box_pattern = rf'((?:<loc_\d+>)+)(?:{sep}|$)'
        polygons_instance_pattern = rf'{start}(.*?){end}'
        instances = []
        for phrase_text in re.findall(span_pattern, text):
            phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
            if phrase_text_strip == '' and not allow_empty_phrase:
                continue
            phrase_match = re.search(phrase_string_pattern, phrase_text_strip)
            if phrase_match is None:
                continue
            phrase = phrase_match.group().strip()
            if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
                poly_texts = [m.group(1) for m in re.finditer(polygons_instance_pattern, phrase_text)]
            else:
                poly_texts = [phrase_text]
            for poly_text in poly_texts:
                polygons_parsed = list(re.finditer(box_pattern, poly_text))
                if not polygons_parsed:
                    continue
                bbox, polygons = [], []
                for poly_match in polygons_parsed:
                    coords = [int(m.group(1)) for m in re.finditer(r'<loc_(\d+)>', poly_match.group(1))]
                    if with_box_at_start and not bbox:
                        if len(coords) > 4:
                            bbox = coords[:4]
                            coords = coords[4:]
                        else:
                            bbox = [0, 0, 0, 0]
                    # Drop a trailing unpaired coordinate.
                    if len(coords) % 2 == 1:
                        coords = coords[:-1]
                    polygons.append(self.coordinates_quantizer.dequantize(
                        torch.tensor(np.array(coords).reshape(-1, 2)), image_size,
                    ).reshape(-1).tolist())
                instance = {'cat_name': phrase, 'polygons': polygons}
                if bbox:
                    instance['bbox'] = self.box_quantizer.dequantize(torch.tensor([bbox]), image_size).tolist()[0]
                instances.append(instance)
        return instances

    def __call__(self, text=None, image_size=None, parse_tasks=None):
        """Run the requested parse tasks (all configured tasks by default).

        Args:
            text: decoded model output; required.
            image_size: ``(width, height)`` in pixels, used to dequantize bins.
            parse_tasks: a task name or list of names; ``None`` runs every
                configured task.

        Returns:
            Dict with ``'text'`` plus one entry per task that was run.
        """
        if parse_tasks is not None:
            if isinstance(parse_tasks, str):
                parse_tasks = [parse_tasks]
            for t in parse_tasks:
                assert t in self.parse_tasks, f'parse task {t} not supported'
        assert text is not None, 'text should be provided'
        parsed_dict = {'text': text}
        for task in self.parse_tasks:
            if parse_tasks is not None and task not in parse_tasks:
                continue
            pattern = self.parse_tasks_configs[task].get('PATTERN', None)
            if task == 'od':
                # The configured 'od' pattern captures the name first, then
                # the four box bins, i.e. the phrase-centric layout.
                parsed_dict['od'] = self.parse_od_from_text_and_spans(text, pattern, image_size, phrase_centric=True)
            elif task == 'ocr':
                parsed_dict['ocr'] = self.parse_ocr_from_text_and_spans(text, pattern, image_size, self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.0))
            elif task == 'phrase_grounding':
                parsed_dict['phrase_grounding'] = self.parse_phrase_grounding_from_text_and_spans(text, pattern, image_size)
            elif task == 'pure_text':
                parsed_dict['pure_text'] = text
            elif task == 'description_with_bboxes':
                parsed_dict['description_with_bboxes'] = self.parse_description_with_bboxes_from_text_and_spans(text, pattern, image_size)
            elif task == 'description_with_polygons':
                parsed_dict['description_with_polygons'] = self.parse_description_with_polygons_from_text_and_spans(text, pattern, image_size)
            elif task == 'polygons':
                parsed_dict['polygons'] = self.parse_description_with_polygons_from_text_and_spans(text, pattern, image_size, allow_empty_phrase=True)
            elif task == 'bboxes':
                parsed_dict['bboxes'] = self.parse_description_with_bboxes_from_text_and_spans(text, pattern, image_size, allow_empty_phrase=True)
            elif task == 'description_with_bboxes_or_polygons':
                if '<poly>' in text:
                    parsed_dict['description_with_bboxes_or_polygons'] = self.parse_description_with_polygons_from_text_and_spans(text, pattern, image_size)
                else:
                    parsed_dict['description_with_bboxes_or_polygons'] = self.parse_description_with_bboxes_from_text_and_spans(text, pattern, image_size)
            else:
                raise ValueError(f"task {task} is not supported")
        return parsed_dict
class Processor:
    """Florence-2 front end.

    Turns task tokens into natural-language prompts, tokenizes them,
    preprocesses images, and converts generated text back into task-specific
    results via ``PostProcessor``.
    """

    def __init__(self, model_path):
        from .tokenizer import Florence2Tokenizer
        self.tokenizer = Florence2Tokenizer(model_path)
        # NOTE(review): not read anywhere in this class; presumably the number
        # of image tokens expected by the model — confirm with callers.
        self.image_seq_length = 577
        # Task token -> PostProcessor parse task used to interpret the output.
        self.tasks_answer_post_processing_type = {
            '<OCR>': 'pure_text',
            '<OCR_WITH_REGION>': 'ocr',
            '<CAPTION>': 'pure_text',
            '<DETAILED_CAPTION>': 'pure_text',
            '<MORE_DETAILED_CAPTION>': 'pure_text',
            '<OD>': 'description_with_bboxes',
            '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
            '<CAPTION_TO_PHRASE_GROUNDING>': 'phrase_grounding',
            '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
            '<REGION_TO_SEGMENTATION>': 'polygons',
            '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
            '<REGION_TO_CATEGORY>': 'pure_text',
            '<REGION_TO_DESCRIPTION>': 'pure_text',
            '<REGION_TO_OCR>': 'pure_text',
            '<REGION_PROPOSAL>': 'bboxes',
        }
        # Task tokens that expand to a fixed prompt (checked first, in order).
        self.task_prompts_without_inputs = {
            '<OCR>': 'What is the text in the image?',
            '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
            '<CAPTION>': 'What does the image describe?',
            '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
            '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
            '<OD>': 'Locate the objects with category name in the image.',
            '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
            '<REGION_PROPOSAL>': 'Locate the region proposals in the image.',
        }
        # Task tokens whose prompt embeds the rest of the text as {input}.
        self.task_prompts_with_input = {
            '<CAPTION_TO_PHRASE_GROUNDING>': 'Locate the phrases in the caption: {input}',
            '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
            '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
            '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
            '<REGION_TO_CATEGORY>': 'What is the region {input}?',
            '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
            '<REGION_TO_OCR>': 'What text is in the region {input}?',
        }
        self.post_processor = PostProcessor(tokenizer=self.tokenizer)

    def _construct_prompts(self, text):
        """Expand the first task token found in ``text`` into its prompt.

        Fixed prompts take precedence; for prompts with input, the task token
        is removed from ``text`` and the remainder becomes ``{input}``. Text
        containing no known task token is returned unchanged.
        """
        fixed_prompt = next(
            (prompt for token, prompt in self.task_prompts_without_inputs.items() if token in text),
            None,
        )
        if fixed_prompt is not None:
            return fixed_prompt
        for token, template in self.task_prompts_with_input.items():
            if token in text:
                return template.format(input=text.replace(token, ''))
        return text

    def __call__(self, text, images):
        """Return ``{'input_ids': ..., 'pixel_values': ...}`` for one prompt."""
        input_ids = self.tokenizer.encode(self._construct_prompts(text))['input_ids']
        return {'input_ids': input_ids, 'pixel_values': preprocess_image(images)}

    def batch_decode(self, token_ids, skip_special_tokens=False):
        """Delegate batch decoding to the tokenizer."""
        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)

    def post_process_generation(self, text, task, image_size):
        """Parse generated ``text`` for ``task`` and return ``{task: result}``.

        Unknown task tokens are treated as ``'pure_text'``.

        Raises:
            ValueError: if the task maps to a post-processing type with no
                formatting rule here.
        """
        pp_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
        parsed = self.post_processor(text=text, image_size=image_size, parse_tasks=pp_type)
        result = parsed[pp_type]

        if pp_type == 'pure_text':
            return {task: result.replace('<s>', '').replace('</s>', '')}

        if pp_type in ('od', 'description_with_bboxes', 'bboxes'):
            return {task: {
                'bboxes': [entry['bbox'] for entry in result],
                'labels': [str(entry['cat_name']) for entry in result],
            }}

        if pp_type == 'ocr':
            return {task: {
                'quad_boxes': [entry['quad_box'] for entry in result],
                'labels': [str(entry['text']) for entry in result],
            }}

        if pp_type == 'phrase_grounding':
            # Flatten: each phrase may own several boxes; repeat its label per box.
            pairs = [(box, entry['cat_name']) for entry in result for box in entry['bbox']]
            return {task: {
                'bboxes': [box for box, _ in pairs],
                'labels': [label for _, label in pairs],
            }}

        if pp_type in ('description_with_polygons', 'polygons'):
            return {task: {
                'polygons': [entry['polygons'] for entry in result],
                'labels': [entry['cat_name'] for entry in result],
            }}

        if pp_type == 'description_with_bboxes_or_polygons':
            poly_entries = [entry for entry in result if 'polygons' in entry]
            box_entries = [entry for entry in result if 'polygons' not in entry]
            return {task: {
                'bboxes': [entry['bbox'] for entry in box_entries],
                'bboxes_labels': [entry['cat_name'] for entry in box_entries],
                'polygons': [entry['polygons'] for entry in poly_entries],
                'polygons_labels': [entry['cat_name'] for entry in poly_entries],
            }}

        raise ValueError(f'Unknown post processing type: {pp_type}')