| | import os |
| | import json |
| | import glob |
| | import xml.etree.ElementTree as ET |
| | import numpy as np |
| | from PIL import Image |
| | from torch.utils.data import Dataset, DataLoader |
| | import torchvision.transforms as T |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from shapely.geometry import Polygon |
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import json |
| |
|
def flat_corners_from_mockup(mockup_path):
    """Compute the print-area corner points described by a mockup.json file.

    Reads the first entry of ``printAreas``, builds axis-aligned corner
    offsets around the rectangle centre, applies the stored rotation, and
    translates the result into background-pixel coordinates.

    Returns:
        (pixel, normalized): two float32 arrays of shape (4, 2), ordered
        TL, TR, BR, BL. ``normalized`` is ``pixel`` scaled into [0, 1] by
        the background width/height.
    """
    spec = json.loads(Path(mockup_path).read_text())
    bg_width = spec["background"]["width"]
    bg_height = spec["background"]["height"]

    area = spec["printAreas"][0]
    left = area["position"]["x"]
    top = area["position"]["y"]
    width, height = area["width"], area["height"]

    # Corner offsets relative to the rectangle centre, ordered TL, TR, BR, BL.
    half_w, half_h = width / 2.0, height / 2.0
    offsets = np.array(
        [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]],
        dtype=np.float32,
    )

    # Rotate about the centre, then translate back to absolute coordinates.
    theta = np.deg2rad(area["rotation"])
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    center = np.array([left + half_w, top + half_h], dtype=np.float32)
    pixel = offsets @ rotation.T + center

    # Scale each axis by the background size to get [0, 1] coordinates.
    normalized = pixel / np.array([bg_width, bg_height], dtype=np.float32)
    return pixel.astype(np.float32), normalized.astype(np.float32)
| |
|
def parse_xml_points(xml_path):
    """Parse the 4 corner points of the first FourPoint transform in an XML file.

    Coordinates are normalized by the background width/height.

    Args:
        xml_path: path to an annotation XML containing a ``background``
            element with ``width``/``height`` attributes and a ``transform``
            of type ``FourPoint`` with four ``point`` children
            (TopLeft, TopRight, BottomRight, BottomLeft).

    Returns:
        np.ndarray of shape (4, 2), float32, ordered TL, TR, BR, BL.

    Raises:
        ValueError: if no FourPoint transform exists or any of the four
            corner points is missing. Failing early here gives a clearer
            error than the shape mismatch that a short array would cause
            downstream (e.g. in ``tensor.view(4, 2)``).
    """
    root = ET.parse(xml_path).getroot()

    background = root.find("background")
    bg_w = int(background.get("width"))
    bg_h = int(background.get("height"))

    points = []
    for transform in root.findall(".//transform"):
        if transform.get("type") != "FourPoint":
            continue
        for pt in ("TopLeft", "TopRight", "BottomRight", "BottomLeft"):
            node = transform.find(f".//point[@type='{pt}']")
            if node is None:
                raise ValueError(f"Missing point '{pt}' in {xml_path}")
            points.append([float(node.get("x")) / bg_w,
                           float(node.get("y")) / bg_h])
        break  # only the first FourPoint transform is used

    if len(points) != 4:
        raise ValueError(f"No FourPoint transform found in {xml_path}")

    return np.array(points, dtype=np.float32)
| |
|
class KP4Dataset(Dataset):
    """Pairs of perspective mockup images and flat background images, each
    annotated with 4 print-area corner points.

    Expects, per sample directory:
      * ``<base>_visual.xml``        — FourPoint corner annotation
      * ``<base>_visual.<ext>``      — perspective render (.png/.jpg/.jpeg)
      * ``<base>_background.<ext>``  — flat render
      * ``mockup.json``              — print-area geometry for the flat image
    """

    # Extensions tried, in order, when resolving an image next to an XML file.
    IMG_EXTS = (".png", ".jpg", ".jpeg")

    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size
        self.samples = []

        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        for xml_file in self.root.rglob("*.xml"):
            if "_visual" not in xml_file.stem:
                continue

            # Perspective image: same stem as the XML, any known extension.
            img_file = self._find_image(xml_file.parent, xml_file.stem)
            if img_file is None:
                continue

            # Flat background image: swap the "_visual" suffix for "_background".
            # Fix: also try .jpeg, consistent with the perspective-image lookup
            # (the old code only checked .png and .jpg here).
            flat_base = xml_file.stem.replace("_visual", "_background")
            flat_img = self._find_image(xml_file.parent, flat_base)
            if flat_img is None:
                continue

            json_file = xml_file.parent / "mockup.json"
            if not json_file.exists():
                continue

            self.samples.append((img_file, xml_file, flat_img, json_file))

        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    @classmethod
    def _find_image(cls, folder, stem):
        """Return the first existing image ``folder / (stem + ext)``, or None."""
        for ext in cls.IMG_EXTS:
            cand = folder / (stem + ext)
            if cand.exists():
                return cand
        return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_file, xml_file, flat_img, json_file = self.samples[idx]

        img = self.transform(Image.open(img_file).convert("RGB"))
        flat = self.transform(Image.open(flat_img).convert("RGB"))

        # Corners of the flat print area, normalized to [0, 1].
        _, flat_norm = flat_corners_from_mockup(json_file)
        flat_pts = torch.tensor(flat_norm, dtype=torch.float32)

        # Annotated corners in the perspective image, normalized to [0, 1].
        persp_norm = parse_xml_points(xml_file)
        persp_pts = torch.tensor(persp_norm, dtype=torch.float32)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": flat_pts,
            "persp_pts": persp_pts,
            "xml": str(xml_file),
            "json": str(json_file),
        }
| |
|
| | |
| | |
| | |
class SimpleTransformer(nn.Module):
    """Maps 8 flat-corner coordinates to 8 perspective-corner coordinates.

    The 8 input values (4 points x (x, y)) are projected into ``d_model``
    dimensions, treated as a one-token sequence by a Transformer encoder,
    and projected back to 8 output values.
    """

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.fc_in = nn.Linear(8, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 8)

    def forward(self, x):
        # (B, 8) -> (B, 1, d_model): one token per sample.
        tokens = self.fc_in(x).unsqueeze(1)
        encoded = self.transformer(tokens)
        # (B, 1, d_model) -> (B, 8)
        return self.fc_out(encoded).squeeze(1)
| |
|
| |
|
| | |
| | |
| | |
def mse_loss(pred, gt):
    """Mean squared error between prediction and ground truth."""
    diff = pred - gt
    return (diff * diff).mean()
| |
|
def mean_corner_error(pred, gt, img_w, img_h):
    """Average Euclidean distance (in pixels) between predicted and
    ground-truth corner points, after denormalizing by the image size."""
    size = torch.tensor([img_w, img_h], device=pred.device)
    distances = torch.norm(pred * size - gt * size.to(gt.device), dim=-1)
    return distances.mean().item()
| |
|
def iou_quad(pred, gt):
    """Intersection-over-union of two quadrilaterals given as 4x2 corner arrays.

    Returns 0.0 for degenerate (e.g. self-intersecting) polygons or when
    the union has zero area.
    """
    poly_a = Polygon(pred.tolist())
    poly_b = Polygon(gt.tolist())
    if not (poly_a.is_valid and poly_b.is_valid):
        return 0.0
    union_area = poly_a.union(poly_b).area
    if union_area <= 0:
        return 0.0
    return poly_a.intersection(poly_b).area / union_area
| |
|
| |
|
| | |
| | |
| | |
def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None
):
    """Train SimpleTransformer to map flat corner points to perspective corners.

    Args:
        train_root: dataset root for training (KP4Dataset layout).
        test_root: dataset root for validation.
        epochs: total epoch count (resuming continues from the saved epoch).
        batch_size: training batch size; validation always uses batch size 1.
        lr: Adam learning rate.
        img_size: square resize applied by the dataset transforms.
        save_dir: directory for periodic / best / final checkpoints.
        resume_path: optional checkpoint path to restore model + optimizer.

    Returns:
        The trained model with last-epoch weights (best weights are on disk
        at ``save_dir/best_model.pth``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0
    best_iou = -1.0

    os.makedirs(save_dir, exist_ok=True)

    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"]
        # Fix: restore the best IoU so a resumed run cannot overwrite a
        # better previously-saved best model with a worse one.
        best_iou = checkpoint.get("best_iou", checkpoint.get("val_iou", -1.0))
        print(f"Resumed from epoch {start_epoch}")

    best_model_path = os.path.join(save_dir, "best_model.pth")

    for epoch in range(start_epoch, epochs):
        # ---- training ----
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # Flatten (B, 4, 2) corner tensors into (B, 8) model inputs.
            flat_pts_in = flat_pts.view(flat_pts.size(0), -1)
            target = persp_pts.view(persp_pts.size(0), -1)

            pred = model(flat_pts_in)
            loss = mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # ---- validation (batch size 1) ----
        model.eval()
        mse_all, ce_all, iou_all = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                flat_pts = batch["flat_pts"].to(device)
                persp_pts = batch["persp_pts"].to(device)

                flat_pts_in = flat_pts.view(1, -1)
                target = persp_pts.view(1, -1)

                pred = model(flat_pts_in)
                mse_all.append(mse_loss(pred, target).item())

                pred_quad = pred.view(4, 2).cpu()
                gt_quad = persp_pts.view(4, 2).cpu()

                # Fix: persp_img comes out of the DataLoader batched as
                # (B, C, H, W); the old code read shape[1] (=C, i.e. 3) as
                # the height and shape[2] (=H) as the width, corrupting the
                # pixel corner-error metric.
                h, w = batch["persp_img"].shape[2], batch["persp_img"].shape[3]
                ce_all.append(mean_corner_error(pred_quad, gt_quad, w, h))
                iou_all.append(iou_quad(pred_quad, gt_quad))

        val_mse = np.mean(mse_all)
        val_ce = np.mean(ce_all)
        val_iou = np.mean(iou_all)

        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        # Periodic checkpoint every 100 epochs.
        if (epoch + 1) % 100 == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

        # Track the best model by validation IoU.
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Script entry point: train the flat->perspective corner regressor.
    # NOTE(review): epochs=3000 while train_model's default is 20; periodic
    # checkpoints inside train_model are written every 100 epochs, so this
    # run produces epoch_100.pth ... plus best/final models.
    model = train_model(
        train_root="Transformer/train",
        test_root="Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None
    )
| |
|