Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -115,22 +115,22 @@ def load_model(check_type):
|
|
| 115 |
def process_image(model, tokenizer, transform, device, check_type, image, text):
|
| 116 |
global current_vis, current_bpe, current_index
|
| 117 |
src_size = image.size
|
| 118 |
-
# Ensure all processing is done on the correct device
|
| 119 |
-
image = image.to(device)
|
| 120 |
|
|
|
|
| 121 |
if 'TokenOCR' in check_type:
|
|
|
|
| 122 |
images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
|
| 123 |
image_size=model.config.force_image_size,
|
| 124 |
use_thumbnail=model.config.use_thumbnail,
|
| 125 |
return_ratio=True)
|
| 126 |
-
pixel_values = torch.stack([transform(img) for img in images])
|
| 127 |
else:
|
| 128 |
-
|
|
|
|
| 129 |
target_ratio = (1, 1)
|
| 130 |
|
| 131 |
text += ' '
|
| 132 |
-
input_ids = tokenizer(text)
|
| 133 |
-
input_ids = torch.tensor(input_ids, device=device)
|
| 134 |
|
| 135 |
with torch.no_grad():
|
| 136 |
if 'R50' in check_type:
|
|
@@ -147,14 +147,14 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
|
|
| 147 |
resized_size = size1 if size1 is not None else size2
|
| 148 |
|
| 149 |
attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
[tokenizer.decode([i]) for i in input_ids],
|
| 153 |
[], target_ratio, src_size)
|
| 154 |
|
| 155 |
-
current_bpe = [tokenizer.decode([i]) for i in input_ids]
|
| 156 |
current_bpe[-1] = text
|
| 157 |
-
return image
|
|
|
|
| 158 |
|
| 159 |
# 事件处理函数
|
| 160 |
def update_index(change):
|
|
|
|
| 115 |
def process_image(model, tokenizer, transform, device, check_type, image, text):
|
| 116 |
global current_vis, current_bpe, current_index
|
| 117 |
src_size = image.size
|
|
|
|
|
|
|
| 118 |
|
| 119 |
+
# Convert PIL Image to Tensor and move to the appropriate device
|
| 120 |
if 'TokenOCR' in check_type:
|
| 121 |
+
# If dynamic preprocessing is required, handle differently
|
| 122 |
images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
|
| 123 |
image_size=model.config.force_image_size,
|
| 124 |
use_thumbnail=model.config.use_thumbnail,
|
| 125 |
return_ratio=True)
|
| 126 |
+
pixel_values = torch.stack([transform(img).to(device) for img in images])
|
| 127 |
else:
|
| 128 |
+
# Standard image processing for a single image
|
| 129 |
+
pixel_values = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to device
|
| 130 |
target_ratio = (1, 1)
|
| 131 |
|
| 132 |
text += ' '
|
| 133 |
+
input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device) # Ensure tokens are on the same device
|
|
|
|
| 134 |
|
| 135 |
with torch.no_grad():
|
| 136 |
if 'R50' in check_type:
|
|
|
|
| 147 |
resized_size = size1 if size1 is not None else size2
|
| 148 |
|
| 149 |
attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
|
| 150 |
+
current_vis = generate_similiarity_map([image], attn_map,
|
| 151 |
+
[tokenizer.decode([i]) for i in input_ids.squeeze()],
|
|
|
|
| 152 |
[], target_ratio, src_size)
|
| 153 |
|
| 154 |
+
current_bpe = [tokenizer.decode([i]) for i in input_ids.squeeze()]
|
| 155 |
current_bpe[-1] = text
|
| 156 |
+
return image, current_vis[0], current_bpe[0]
|
| 157 |
+
|
| 158 |
|
| 159 |
# 事件处理函数
|
| 160 |
def update_index(change):
|