tablet-split-model / test_split_by_images_folder.py
santhoshkammari's picture
Upload test_split_by_images_folder.py with huggingface_hub
226b8b6 verified
"""
Visualize predictions from trained model on local images folder
"""
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from split_model import SplitModel
from PIL import Image, ImageDraw
import numpy as np
import glob
import os
class LocalImageDataset(Dataset):
    """Dataset over the image files found directly inside one local folder.

    Each item is a tuple ``(image_tensor, original_image, image_path)``:

    * ``image_tensor`` -- the image resized to 960x960, converted with
      ``ToTensor`` and normalized with the mean/std below;
    * ``original_image`` -- the RGB-converted ``PIL.Image`` at its original
      size, so callers can draw on it at native resolution;
    * ``image_path`` -- the path the image was read from.

    Args:
        image_folder: Folder to scan (non-recursively) for images.
        extensions: File suffixes treated as images; matched case-insensitively,
            so ``photo.JPG`` is picked up as well as ``photo.jpg``.

    Raises:
        ValueError: If the folder is missing or contains no matching files.
    """

    # Default image suffixes (compared case-insensitively by _find_images).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_folder, extensions=IMAGE_EXTENSIONS):
        self.image_paths = self._find_images(image_folder, extensions)
        if not self.image_paths:
            raise ValueError(f"No images found in {image_folder}")
        self.transform = transforms.Compose([
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            # The commonly used ImageNet channel statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        print(f"Found {len(self.image_paths)} images in {image_folder}")

    @staticmethod
    def _find_images(image_folder, extensions=IMAGE_EXTENSIONS):
        """Return sorted paths of the visible regular files in *image_folder*
        whose name ends with one of *extensions* (case-insensitive).

        The folder is listed directly rather than through ``glob`` so that folder
        names containing glob metacharacters (``[``, ``*``, ``?``) are taken
        literally. A missing folder yields an empty list.
        """
        if not os.path.isdir(image_folder):
            return []
        suffixes = tuple(ext.lower() for ext in extensions)
        with os.scandir(image_folder) as entries:
            return sorted(
                os.path.join(image_folder, entry.name)
                for entry in entries
                # Skip hidden dotfiles, matching glob's behaviour for '*' patterns;
                # is_file() excludes directories such as "scans.png/".
                if not entry.name.startswith('.')
                and entry.is_file()
                and entry.name.lower().endswith(suffixes)
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns a new, fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_transformed = self.transform(image)
        return image_transformed, image, image_path
def get_middle_of_groups(binary_array):
    """Keep only the middle element of each run of consecutive 1's.

    Every maximal run of elements equal to 1 is reduced to a single 1 at the
    run's middle index ``(start + end) // 2`` (the left-of-centre element for
    even-length runs); every other position becomes 0.

    Example: [0,0,1,1,1,1,1,0,0,1,1,0] -> [0,0,0,0,1,0,0,0,0,1,0,0]

    Args:
        binary_array: 1-D sequence or array; only elements equal to 1 count
            as "on".

    Returns:
        ``np.ndarray`` of the same length and dtype (built with
        ``np.zeros_like``).
    """
    values = np.asarray(binary_array)
    result = np.zeros_like(binary_array)
    # Pad with an "off" element on both sides so every run has exactly one
    # rising edge (+1) and one falling edge (-1) in the difference signal.
    on = np.concatenate(([0], (values == 1).astype(np.int8), [0]))
    edges = np.diff(on)
    starts = np.flatnonzero(edges == 1)       # first index of each run
    ends = np.flatnonzero(edges == -1) - 1    # last index of each run
    result[(starts + ends) // 2] = 1
    return result
def _load_fonts():
    """Return the (title, text, label) fonts used by the visualization.

    Falls back to ``(None, None, None)`` if the DejaVu fonts cannot be loaded;
    passing ``None`` to ``ImageDraw.text`` uses Pillow's default font.
    """
    try:
        from PIL import ImageFont
        bold = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
        regular = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
        return (
            ImageFont.truetype(bold, 32),
            ImageFont.truetype(regular, 24),
            ImageFont.truetype(bold, 28),
        )
    except (OSError, ImportError):
        # OSError: font files not present at these paths. ImportError: presumably
        # a Pillow build without FreeType support. Degrade, never crash.
        return None, None, None


def _predict_split_probs(model, image_tensor, device):
    """Run *model* on one image tensor and return its (h, v) outputs.

    Both outputs are stretched with ``repeat_interleave`` toward the input
    tensor's height (h) and width (v), e.g. 480 -> 960. They are returned as
    1-D CPU tensors, so entry ``i`` lines up with pixel row/column ``i`` of
    the resized input.
    """
    model.eval()
    with torch.no_grad():
        image_batch = image_tensor.unsqueeze(0).to(device)
        h_pred, v_pred = model(image_batch)  # expected [1, 480] each for a 960x960 input
        # Integer stretch factor; max(1, ...) keeps outputs that are already at
        # (or above) the input resolution unchanged.
        h_factor = max(1, image_tensor.shape[-2] // h_pred.shape[1])
        v_factor = max(1, image_tensor.shape[-1] // v_pred.shape[1])
        h_pred = h_pred.repeat_interleave(h_factor, dim=1)
        v_pred = v_pred.repeat_interleave(v_factor, dim=1)
    return h_pred.squeeze(0).cpu(), v_pred.squeeze(0).cpu()


def _draw_split_lines(image, h_mask, v_mask, line_width):
    """Return a copy of *image* with split lines drawn on it.

    A red horizontal line is drawn for every 1 in *h_mask* and a blue vertical
    line for every 1 in *v_mask*. Mask index ``i`` of a length-``n`` mask maps
    to pixel ``int(i * size / n)``, where size is the image height (h) or
    width (v).
    """
    W, H = image.size
    canvas = image.copy()
    draw = ImageDraw.Draw(canvas)
    n_h = len(h_mask)
    for y in np.flatnonzero(h_mask == 1):
        y_scaled = int(int(y) * H / n_h)
        draw.line([(0, y_scaled), (W, y_scaled)], fill='#ff0000', width=line_width)
    n_v = len(v_mask)
    for x in np.flatnonzero(v_mask == 1):
        x_scaled = int(int(x) * W / n_v)
        draw.line([(x_scaled, 0), (x_scaled, H)], fill='#0000ff', width=line_width)
    return canvas


def visualize_prediction(model, dataset, idx, device, output_folder, threshold=0.5, zoom_factor=1.5):
    """Predict table splits for one dataset image and save a visualization.

    The saved PNG stacks, top to bottom:
    1. an info panel with the predicted row and column counts;
    2. the original image;
    3. every above-threshold prediction ("raw");
    4. one line per run of positives ("cleaned").
    No ground truth is involved.

    Args:
        model: Callable returning ``(h_pred, v_pred)`` for a ``[1, C, H, W]`` batch.
        dataset: Indexable returning ``(image_tensor, original_pil_image, image_path)``.
        idx: Index of the image in *dataset*.
        device: Device to move the input batch to.
        output_folder: Existing folder that ``prediction_<name>.png`` is written to.
        threshold: Outputs strictly greater than this count as a split.
        zoom_factor: Scale applied to each image panel in the saved figure.

    Returns:
        ``(pred_rows, pred_cols)`` -- the number of cleaned splits plus one, per axis.
    """
    image_tensor, original_image, image_path = dataset[idx]

    h_pred, v_pred = _predict_split_probs(model, image_tensor, device)
    h_binary = (h_pred > threshold).float().numpy()
    v_binary = (v_pred > threshold).float().numpy()
    # Adjacent positives are treated as one split: keep only each run's middle.
    h_binary_clean = get_middle_of_groups(h_binary)
    v_binary_clean = get_middle_of_groups(v_binary)

    h_splits = h_binary_clean.sum()
    v_splits = v_binary_clean.sum()
    pred_rows = int(h_splits) + 1
    pred_cols = int(v_splits) + 1

    print(f"\nImage {idx}: {os.path.basename(image_path)}")
    print(f" H Splits: {h_splits:.0f} | Pred Rows: {pred_rows}")
    print(f" V Splits: {v_splits:.0f} | Pred Cols: {pred_cols}")
    print(f" Total Cells: {pred_rows * pred_cols}")
    print(f" H pred range: [{h_pred.min():.3f}, {h_pred.max():.3f}]")
    print(f" V pred range: [{v_pred.min():.3f}, {v_pred.max():.3f}]")

    # Figure geometry (pixels).
    W, H = original_image.size
    W_zoomed = int(W * zoom_factor)
    H_zoomed = int(H * zoom_factor)
    info_height = 150
    label_height = 60
    padding = 40
    font_title, font_text, font_label = _load_fonts()

    # Info panel followed by three (header, image) sections stacked vertically.
    total_width = W_zoomed + padding * 2
    total_height = info_height + (label_height + H_zoomed) * 3 + padding * 5
    vis_image = Image.new('RGB', (total_width, total_height), 'white')
    draw = ImageDraw.Draw(vis_image)

    # Info panel at the top.
    info_x = padding
    info_y = padding
    draw.rectangle([info_x, info_y, total_width - padding, info_y + info_height],
                   fill='#e3f2fd', outline='#1565c0', width=4)
    draw.text((info_x + 20, info_y + 20), "PREDICTED SPLITS", fill='#0d47a1', font=font_title)
    draw.text((info_x + 20, info_y + 70), f"Rows: {pred_rows}", fill='black', font=font_text)
    draw.text((info_x + 20, info_y + 105), f"Columns: {pred_cols}", fill='black', font=font_text)
    draw.text((info_x + 300, info_y + 70), f"H Splits: {int(h_splits)}", fill='#c62828', font=font_text)
    draw.text((info_x + 300, info_y + 105), f"V Splits: {int(v_splits)}", fill='#1565c0', font=font_text)

    # (title, header fill, header outline, title colour, panel image)
    sections = [
        ("Original Image", '#f5f5f5', '#666666', '#333333', original_image),
        ("Raw Model Predictions (All 1's)", '#fff3e0', '#ff9800', '#e65100',
         _draw_split_lines(original_image, h_binary, v_binary, line_width=2)),
        ("Cleaned Predictions (Middle Only)", '#e3f2fd', '#1565c0', '#0d47a1',
         _draw_split_lines(original_image, h_binary_clean, v_binary_clean, line_width=3)),
    ]
    for i, (title, fill, outline, text_color, panel) in enumerate(sections):
        y_pos = info_height + padding * (i + 2) + (label_height + H_zoomed) * i
        draw.rectangle([padding, y_pos, total_width - padding, y_pos + label_height],
                       fill=fill, outline=outline, width=2)
        draw.text((padding + 20, y_pos + 15), title, fill=text_color, font=font_label)
        vis_image.paste(panel.resize((W_zoomed, H_zoomed), Image.LANCZOS),
                        (padding, y_pos + label_height))

    # NOTE(review): the output name drops the source extension, so "a.png" and
    # "a.jpg" in the same folder both map to prediction_a.png and the later
    # image overwrites the earlier one.
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    output_path = os.path.join(output_folder, f'prediction_{base_name}.png')
    vis_image.save(output_path)
    print(f" Saved: {output_path}")
    return pred_rows, pred_cols
def _resolve_checkpoint_path(model_path):
    """Return the checkpoint file to load, or ``None`` if none exists.

    When *model_path* is given it is the only candidate. Otherwise the first
    existing path from a fixed list of default locations is used.
    """
    if model_path:
        candidates = [model_path]
    else:
        # NOTE(review): the absolute defaults below are machine-specific.
        candidates = [
            'best_split_model.pth',
            '/home/ng6309/datascience/santhosh/experiments/tablet/best_split_model.pth',
            '/home/ng6309/datascience/santhosh/experiments/tablet/runs/tablet_split_20251006_214335/best_split_model.pth',
        ]
    return next((path for path in candidates if os.path.exists(path)), None)


def _print_checkpoint_summary(checkpoint):
    """Print whichever training-metadata entries *checkpoint* contains."""
    if 'epoch' in checkpoint:
        print(f"\nModel trained for {checkpoint['epoch']} epochs")
    if 'val_loss' in checkpoint:
        print(f"Val loss: {checkpoint['val_loss']:.4f}")
    if 'val_h_f1' in checkpoint:
        print(f"Val H F1: {checkpoint['val_h_f1']:.3f}")
    if 'val_v_f1' in checkpoint:
        print(f"Val V F1: {checkpoint['val_v_f1']:.3f}")


def main():
    """Command-line entry point: visualize split predictions for a folder of images.

    Parses arguments and loads the images and a SplitModel checkpoint. It then
    writes one visualization per image into ``--output-folder``. Exits with
    status 1 (message on stderr) when no checkpoint can be found.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Visualize table split predictions on local images')
    parser.add_argument('--image-folder', type=str, required=True,
                        help='Path to folder containing images')
    parser.add_argument('--output-folder', type=str, default='predictions_output',
                        help='Path to folder for saving predictions (default: predictions_output)')
    parser.add_argument('--model-path', type=str, default=None,
                        help='Path to trained model checkpoint (if not specified, will search common locations)')
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='Threshold for binary predictions (default: 0.5)')
    parser.add_argument('--num-images', type=int, default=-1,
                        help='Number of images to process (-1 for all)')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    os.makedirs(args.output_folder, exist_ok=True)
    print(f"Output folder: {args.output_folder}")

    print(f"\nLoading images from: {args.image_folder}")
    dataset = LocalImageDataset(args.image_folder)

    print("\nLoading model...")
    model = SplitModel().to(device)
    checkpoint_path = _resolve_checkpoint_path(args.model_path)
    if checkpoint_path is None:
        # sys.exit(str) prints to stderr and exits with status 1, so shell
        # scripts and CI can detect the failure.
        sys.exit("ERROR: No trained model found! Please specify --model-path or ensure model exists in default locations")
    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load is pickle-based (restricted only when
    # weights_only=True, the default from PyTorch 2.6); only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept either a training checkpoint dict or a bare state_dict.
    model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
    _print_checkpoint_summary(checkpoint)

    # Any negative --num-images (not only -1) means "all images".
    num_samples = len(dataset) if args.num_images < 0 else min(args.num_images, len(dataset))

    print("\n" + "="*60)
    print(f"Visualizing predictions for {num_samples} images")
    print("="*60)
    for idx in range(num_samples):
        visualize_prediction(model, dataset, idx, device, args.output_folder, threshold=args.threshold)

    print("\n" + "="*60)
    print(f"Completed processing {num_samples} images")
    print(f"All predictions saved to: {args.output_folder}")
    print("="*60)


if __name__ == "__main__":
    main()