|
|
""" |
|
|
Visualize table-split predictions from a trained model on a folder of local images.
|
|
""" |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
import torchvision.transforms as transforms |
|
|
from split_model import SplitModel |
|
|
from PIL import Image, ImageDraw |
|
|
import numpy as np |
|
|
import glob |
|
|
import os |
|
|
|
|
|
class LocalImageDataset(Dataset):
    """Dataset over the PNG/JPEG images found directly inside a local folder.

    Each item is ``(image_tensor, original_image, image_path)`` where
    ``image_tensor`` is the image resized to 960x960, converted to a tensor and
    normalized with the ImageNet mean/std, ``original_image`` is the untouched
    RGB ``PIL.Image`` (useful for drawing), and ``image_path`` is its file path.
    """

    # File extensions treated as images; matched case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_folder):
        """Index the images in *image_folder* (non-recursive).

        Raises:
            ValueError: if the folder is missing or contains no image files.
        """
        self.image_paths = self._collect_image_paths(image_folder)

        if not self.image_paths:
            raise ValueError(f"No images found in {image_folder}")

        self.transform = transforms.Compose([
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"Found {len(self.image_paths)} images in {image_folder}")

    @classmethod
    def _collect_image_paths(cls, image_folder):
        """Return the sorted paths of image files directly inside *image_folder*.

        Uses one directory listing with a case-insensitive extension check, so
        ``scan.PNG`` is found and no file is listed twice, even on
        case-insensitive filesystems. Hidden entries are skipped and only
        regular files are kept. An unreadable or missing folder yields ``[]``.
        """
        try:
            names = os.listdir(image_folder)
        except OSError:
            return []

        paths = []
        for name in names:
            if name.startswith('.') or not name.lower().endswith(cls.IMAGE_EXTENSIONS):
                continue
            path = os.path.join(image_folder, name)
            if os.path.isfile(path):
                paths.append(path)
        return sorted(paths)

    def __len__(self):
        """Number of images found in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image *idx* and return ``(tensor, rgb_pil_image, path)``."""
        image_path = self.image_paths[idx]
        # The context manager releases the file handle right away. convert()
        # loads the pixels into a new image that does not depend on the file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        image_transformed = self.transform(image)
        return image_transformed, image, image_path
|
|
|
|
|
def get_middle_of_groups(binary_array):
    """Keep only the middle element of each run of consecutive 1's.

    For a run covering indices ``start..end`` (inclusive), the kept index is
    ``(start + end) // 2``, which is the lower middle for even-length runs. Every
    other position becomes 0.

    Args:
        binary_array: 1-D sequence or array of 0/1 values (numeric or bool).

    Returns:
        A new numpy array with the same shape and dtype as
        ``np.asarray(binary_array)``.

    Example:
        [0,0,1,1,1,1,1,0,0,1,1,0] -> [0,0,0,0,1,0,0,0,0,1,0,0]
    """
    values = np.asarray(binary_array)
    result = np.zeros_like(values)

    # Pad with a 0 on both sides so runs touching either end still produce a
    # rising (+1) and a falling (-1) edge in the difference signal.
    padded = np.concatenate(([0], (values == 1).astype(np.int8), [0]))
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)       # first index of each run
    ends = np.flatnonzero(edges == -1) - 1    # last index of each run

    result[(starts + ends) // 2] = 1
    return result
|
|
|
|
|
def _load_fonts():
    """Return ``(title, text, label)`` fonts, or ``(None, None, None)`` if unavailable.

    The DejaVu fonts live at a Debian/Ubuntu-specific path. When they (or
    FreeType support) are missing, ``None`` makes Pillow use its default font.
    """
    try:
        from PIL import ImageFont

        font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 32)
        font_text = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 24)
        font_label = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28)
    except (ImportError, OSError):
        return None, None, None
    return font_title, font_text, font_label


def _draw_split_lines(image, h_mask, v_mask, line_width):
    """Return a copy of *image* with red row splits and blue column splits.

    A 1 at ``h_mask[i]`` draws a horizontal line at mask position ``i``,
    rescaled to the image height (``v_mask`` does the same for columns and the
    width). Each mask's own length is used for the rescaling, so this works for
    any prediction length, not only 960.
    """
    W, H = image.size
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)

    h_mask = np.asarray(h_mask)
    h_len = len(h_mask)
    for y in np.flatnonzero(h_mask == 1):
        y_scaled = int(y * H / h_len)
        pen.line([(0, y_scaled), (W, y_scaled)], fill='#ff0000', width=line_width)

    v_mask = np.asarray(v_mask)
    v_len = len(v_mask)
    for x in np.flatnonzero(v_mask == 1):
        x_scaled = int(x * W / v_len)
        pen.line([(x_scaled, 0), (x_scaled, H)], fill='#0000ff', width=line_width)

    return canvas


def visualize_prediction(model, dataset, idx, device, output_folder, threshold=0.5):
    """Run *model* on ``dataset[idx]`` and save a three-panel visualization.

    The saved PNG (``prediction_<image stem>.png`` in *output_folder*) contains
    an info box with the predicted row/column counts, followed by the original
    image, the raw thresholded split predictions, and the cleaned predictions
    (one line per run of positive positions). No ground truth is shown.

    Args:
        model: called as ``model(batch)`` and must return ``(h_pred, v_pred)``.
            Presumably each is ``(1, L)`` scores at half the input resolution,
            since they are upsampled 2x along dim 1. TODO confirm against SplitModel.
        dataset: items are ``(image_tensor, original_pil_image, image_path)``.
        idx: index of the image in *dataset*.
        device: torch device the model lives on.
        output_folder: existing directory that receives the visualization.
        threshold: score above which a position counts as a split.

    Returns:
        ``(pred_rows, pred_cols)``: the number of predicted rows and columns.
    """
    model.eval()

    image_tensor, original_image, image_path = dataset[idx]

    with torch.no_grad():
        image_batch = image_tensor.unsqueeze(0).to(device)
        h_pred, v_pred = model(image_batch)

        # Repeat each score twice along dim 1, presumably to go from the model's
        # output resolution back to the 960-pixel input resolution.
        h_pred = h_pred.repeat_interleave(2, dim=1)
        v_pred = v_pred.repeat_interleave(2, dim=1)

        h_pred = h_pred.squeeze(0).cpu()
        v_pred = v_pred.squeeze(0).cpu()

    h_binary = (h_pred > threshold).float().numpy()
    v_binary = (v_pred > threshold).float().numpy()

    # Collapse each run of consecutive positives to its middle position, so a
    # thick band of detections counts as a single split.
    h_binary_clean = get_middle_of_groups(h_binary)
    v_binary_clean = get_middle_of_groups(v_binary)

    h_splits = h_binary_clean.sum()
    v_splits = v_binary_clean.sum()
    pred_rows = int(h_splits) + 1
    pred_cols = int(v_splits) + 1

    print(f"\nImage {idx}: {os.path.basename(image_path)}")
    print(f" H Splits: {h_splits:.0f} | Pred Rows: {pred_rows}")
    print(f" V Splits: {v_splits:.0f} | Pred Cols: {pred_cols}")
    print(f" Total Cells: {pred_rows * pred_cols}")
    print(f" H pred range: [{h_pred.min():.3f}, {h_pred.max():.3f}]")
    print(f" V pred range: [{v_pred.min():.3f}, {v_pred.max():.3f}]")

    W, H = original_image.size
    zoom_factor = 1.5
    W_zoomed = int(W * zoom_factor)
    H_zoomed = int(H * zoom_factor)

    info_height = 150
    label_height = 60
    padding = 40

    font_title, font_text, font_label = _load_fonts()

    # Layout: info box, then three (label bar + zoomed image) panels, with
    # `padding` around and between every section.
    total_width = W_zoomed + padding * 2
    total_height = info_height + (label_height + H_zoomed) * 3 + padding * 5
    vis_image = Image.new('RGB', (total_width, total_height), 'white')
    draw = ImageDraw.Draw(vis_image)

    info_x = padding
    info_y = padding
    draw.rectangle([info_x, info_y, total_width - padding, info_y + info_height],
                   fill='#e3f2fd', outline='#1565c0', width=4)
    draw.text((info_x + 20, info_y + 20), "PREDICTED SPLITS", fill='#0d47a1', font=font_title)
    draw.text((info_x + 20, info_y + 70), f"Rows: {pred_rows}", fill='black', font=font_text)
    draw.text((info_x + 20, info_y + 105), f"Columns: {pred_cols}", fill='black', font=font_text)
    draw.text((info_x + 300, info_y + 70), f"H Splits: {int(h_splits)}", fill='#c62828', font=font_text)
    draw.text((info_x + 300, info_y + 105), f"V Splits: {int(v_splits)}", fill='#1565c0', font=font_text)

    # (title, bar fill, bar outline, title colour, panel image)
    panels = [
        ("Original Image", '#f5f5f5', '#666666', '#333333', original_image),
        ("Raw Model Predictions (All 1's)", '#fff3e0', '#ff9800', '#e65100',
         _draw_split_lines(original_image, h_binary, v_binary, line_width=2)),
        ("Cleaned Predictions (Middle Only)", '#e3f2fd', '#1565c0', '#0d47a1',
         _draw_split_lines(original_image, h_binary_clean, v_binary_clean, line_width=3)),
    ]
    for i, (title, bar_fill, bar_outline, title_color, panel) in enumerate(panels):
        y_pos = info_height + padding * (2 + i) + (label_height + H_zoomed) * i
        draw.rectangle([padding, y_pos, total_width - padding, y_pos + label_height],
                       fill=bar_fill, outline=bar_outline, width=2)
        draw.text((padding + 20, y_pos + 15), title, fill=title_color, font=font_label)
        vis_image.paste(panel.resize((W_zoomed, H_zoomed), Image.LANCZOS),
                        (padding, y_pos + label_height))

    base_name = os.path.splitext(os.path.basename(image_path))[0]
    # NOTE(review): images that share a stem (e.g. a.png and a.jpg) map to the
    # same output file, so the later one overwrites the earlier one.
    output_path = os.path.join(output_folder, f'prediction_{base_name}.png')
    vis_image.save(output_path)
    print(f" Saved: {output_path}")

    return pred_rows, pred_cols
|
|
|
|
|
def main():
    """Command-line entry point: visualize split predictions for a folder of images.

    Loads the images from ``--image-folder`` and a trained ``SplitModel``
    checkpoint. The checkpoint comes from ``--model-path``, or from the first
    existing default location if that option is not given. Writes one
    visualization per image to ``--output-folder``. Exits with status 1 if no
    checkpoint can be found.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Visualize table split predictions on local images')
    parser.add_argument('--image-folder', type=str, required=True,
                        help='Path to folder containing images')
    parser.add_argument('--output-folder', type=str, default='predictions_output',
                        help='Path to folder for saving predictions (default: predictions_output)')
    parser.add_argument('--model-path', type=str, default=None,
                        help='Path to trained model checkpoint (if not specified, will search common locations)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binary predictions (default: 0.5)')
    parser.add_argument('--num-images', type=int, default=-1,
                        help='Number of images to process (-1 for all)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_folder, exist_ok=True)
    print(f"Output folder: {args.output_folder}")

    print(f"\nLoading images from: {args.image_folder}")
    dataset = LocalImageDataset(args.image_folder)

    print("\nLoading model...")
    model = SplitModel().to(device)

    if args.model_path:
        checkpoint_path = args.model_path
    else:
        # NOTE(review): machine-specific fallback locations. Prefer --model-path.
        possible_paths = [
            'best_split_model.pth',
            '/home/ng6309/datascience/santhosh/experiments/tablet/best_split_model.pth',
            '/home/ng6309/datascience/santhosh/experiments/tablet/runs/tablet_split_20251006_214335/best_split_model.pth',
        ]
        checkpoint_path = next((p for p in possible_paths if os.path.exists(p)), None)

    if checkpoint_path is None or not os.path.exists(checkpoint_path):
        # sys.exit(str) writes the message to stderr and exits with status 1,
        # so shell scripts and CI can detect the failure.
        sys.exit("ERROR: No trained model found! Please specify --model-path or ensure model exists in default locations")

    print(f"Loading checkpoint from: {checkpoint_path}")
    # torch.load may unpickle arbitrary objects. Only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Training metadata is optional. Print whatever the checkpoint provides
    # instead of failing with KeyError on checkpoints saved without it.
    print()
    if 'epoch' in checkpoint:
        print(f"Model trained for {checkpoint['epoch']} epochs")
    if 'val_loss' in checkpoint:
        print(f"Val loss: {checkpoint['val_loss']:.4f}")
    if 'val_h_f1' in checkpoint:
        print(f"Val H F1: {checkpoint['val_h_f1']:.3f}")
    if 'val_v_f1' in checkpoint:
        print(f"Val V F1: {checkpoint['val_v_f1']:.3f}")

    # Any negative --num-images means "all images". Otherwise cap at the dataset size.
    if args.num_images < 0:
        num_samples = len(dataset)
    else:
        num_samples = min(args.num_images, len(dataset))

    print("\n" + "="*60)
    print(f"Visualizing predictions for {num_samples} images")
    print("="*60)

    for idx in range(num_samples):
        visualize_prediction(model, dataset, idx, device, args.output_folder, threshold=args.threshold)

    print("\n" + "="*60)
    print(f"Completed processing {num_samples} images")
    print(f"All predictions saved to: {args.output_folder}")
    print("="*60)
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()
|
|
|