Commit 2915058 (parent: 7883098): debug

Files changed:
- app.py (+9 -9)
- data/mm_data/ocr_dataset.py (+10 -4)
app.py
CHANGED
@@ -70,7 +70,7 @@ def get_images(img: str, reader: ReaderLite, **kwargs):
     return results
 
 
-def draw_boxes(image, bounds, color='red', width=…
+def draw_boxes(image, bounds, color='red', width=10):
     draw = ImageDraw.Draw(image)
     for i, bound in enumerate(bounds):
         p0, p1, p2, p3 = bound
@@ -102,7 +102,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
     _patch_resize_transform = transforms.Compose(
         [
             lambda image: ocr_resize(
-                image, patch_image_size, is_document=is_document
+                image, patch_image_size, is_document=is_document, split='test',
             ),
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
@@ -113,7 +113,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
 
 
 reader = ReaderLite()
-overrides={"eval_cider": False, "beam": …
+overrides={"eval_cider": False, "beam": 4, "max_len_b": 32, "patch_image_size": 480,
            "orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 7}
 models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
     utils.split_paths('checkpoints/ocr_general_clean.pt'),
@@ -163,9 +163,9 @@ def apply_half(t):
     return t
 
 
-def ocr(…
-    out_img = Image.open(…
-    results = get_images(…
+def ocr(Image):
+    out_img = Image.open(Image)
+    results = get_images(Image, reader, link_threshold=0.2)
     box_list, image_list = zip(*results)
     draw_boxes(out_img, box_list)
 
@@ -191,9 +191,9 @@ description = "Gradio Demo for OFA-OCR. Upload your own image or click any one o
 article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
           "Repo</a></p> "
 examples = [['lihe.png']]
-io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath'),
-                  outputs=[gr.outputs.Image(type='pil'), gr.outputs.Textbox(label="OCR result")],
+io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
+                  outputs=[gr.outputs.Image(type='pil', label='Image'), gr.outputs.Textbox(label="OCR result")],
                   title=title, description=description, article=article, examples=examples,
-                  allow_flagging=…
+                  allow_flagging='never', allow_screenshot=False)
 io.launch(cache_examples=True)
 
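A note on the rewritten entry point above: the Gradio callback is now declared as def ocr(Image), so inside the function the name Image is bound to the uploaded file path (a plain str, because the input component uses type='filepath') and it shadows PIL's Image module, which the handler still needs for Image.open. The call Image.open(Image) therefore resolves against the string and raises AttributeError the moment the demo is invoked. A minimal sketch of the same handler with the shadowing removed is below; get_images, draw_boxes and reader are the names defined elsewhere in app.py, while the parameter name img_path and the trailing recognition step are placeholders for code outside this diff.

from PIL import Image

def ocr(img_path):
    # The PIL module is no longer shadowed, so Image.open works on the path string.
    out_img = Image.open(img_path)
    # Detect text regions with the ReaderLite instance created earlier in app.py.
    results = get_images(img_path, reader, link_threshold=0.2)
    box_list, image_list = zip(*results)
    # Annotate the uploaded image with the detected boxes.
    draw_boxes(out_img, box_list)
    # ...recognition over image_list continues in the unchanged part of app.py...
    return out_img
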
data/mm_data/ocr_dataset.py
CHANGED
@@ -82,7 +82,7 @@ def collate(samples, pad_idx, eos_idx):
     return batch
 
 
-def ocr_resize(img, patch_image_size, is_document=False):
+def ocr_resize(img, patch_image_size, is_document=False, split='train'):
     img = img.convert("RGB")
     width, height = img.size
 
@@ -92,13 +92,19 @@ def ocr_resize(img, patch_image_size, is_document=False):
     if width >= height:
         new_width = max(64, patch_image_size)
         new_height = max(64, int(patch_image_size * (height / width)))
-        …
+        if split != 'train':
+            top = int((patch_image_size - new_height) // 2)
+        else:
+            top = random.randint(0, patch_image_size - new_height)
         bottom = patch_image_size - new_height - top
         left, right = 0, 0
     else:
         new_height = max(64, patch_image_size)
         new_width = max(64, int(patch_image_size * (width / height)))
-        …
+        if split != 'train':
+            left = int((patch_image_size - new_width) // 2)
+        else:
+            left = random.randint(0, patch_image_size - new_width)
         right = patch_image_size - new_width - left
         top, bottom = 0, 0
 
@@ -151,7 +157,7 @@ class OcrDataset(OFADataset):
         self.patch_resize_transform = transforms.Compose(
             [
                 lambda image: ocr_resize(
-                    image, patch_image_size, is_document=is_document
+                    image, patch_image_size, is_document=is_document, split=split,
                 ),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=mean, std=std),