saim1309 commited on
Commit
ec7f44e
·
verified ·
1 Parent(s): 420eb14

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +74 -0
  2. best_model.pth +3 -0
  3. plot.py +105 -0
  4. requirements.txt +15 -0
  5. test.py +118 -0
  6. train.py +314 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tempfile
4
+ import cv2
5
+ from test import predict_one
6
+ from plot import (
7
+ autocrop, get_json_corners, extract_points_from_xml,
8
+ draw_feature_matching, stack_images_side_by_side
9
+ )
10
+
11
+ # Hard-coded model checkpoint path
12
+ MODEL_CKPT = "best_model.pth"
13
+
14
+
15
+ # --------------------
16
+ # Pipeline
17
+ # --------------------
18
def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Run keypoint prediction on one sample and build a GT-vs-prediction image.

    Args:
        flat_img: path to the flat (mockup background) image.
        pers_img: path to the perspective (visual) image.
        mockup_json: path to mockup.json describing the print area.
        xml_gt: path to the ground-truth visual XML.

    Returns:
        (result_path, xml_pred_path): paths to the stacked comparison PNG and
        the predicted XML file, both inside a fresh temp directory.

    Raises:
        ValueError: if either input image cannot be read.
    """
    # Temp dir for prediction + result. Gradio serves these files back to the
    # browser, so they must outlive this call (no cleanup here on purpose).
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Run model inference; writes predicted corner points to xml_pred_path.
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)

    # --- Visualization ---
    flat_bgr = cv2.imread(flat_img)
    pers_bgr = cv2.imread(pers_img)
    if flat_bgr is None or pers_bgr is None:
        # cv2.imread returns None (no exception) for unreadable/missing files;
        # fail early with a clear message instead of crashing in cvtColor.
        raise ValueError("Could not read one of the input images.")
    img_json = autocrop(cv2.cvtColor(flat_bgr, cv2.COLOR_BGR2RGB))
    img_xml = autocrop(cv2.cvtColor(pers_bgr, cv2.COLOR_BGR2RGB))

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # Left panel: flat vs ground truth; right panel: flat vs prediction.
    match_json_gt = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), gt_pts, draw_boxes=True)
    match_json_pred = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), pred_pts, draw_boxes=True)

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)

    # Save result (convert back to BGR, which imwrite expects).
    cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))

    return result_path, xml_pred_path
44
+
45
+
46
+ # --------------------
47
+ # Gradio UI
48
+ # --------------------
49
# Declarative UI: two image inputs, two file inputs, one trigger button,
# and two outputs (comparison image + downloadable predicted XML).
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    with gr.Row():
        # type="filepath" hands run_pipeline paths on disk, not pixel arrays.
        flat_in = gr.Image(type="filepath", label="Flat Image", width=300, height=300)
        pers_in = gr.Image(type="filepath", label="Perspective Image", width=300, height=300)

    with gr.Row():
        mockup_json_in = gr.File(type="filepath", label="Mockup JSON")
        xml_gt_in = gr.File(type="filepath", label="Ground Truth XML")

    run_btn = gr.Button("Run Prediction + Visualization")

    with gr.Row():
        out_img = gr.Image(type="filepath", label="Comparison Output", width=800, height=600)
        out_xml = gr.File(type="filepath", label="Predicted XML")

    # Wire the button to the pipeline; outputs map 1:1 to run_pipeline's
    # (result_path, xml_pred_path) return tuple.
    run_btn.click(
        fn=run_pipeline,
        inputs=[flat_in, pers_in, mockup_json_in, xml_gt_in],
        outputs=[out_img, out_xml]
    )


if __name__ == "__main__":
    # share=True additionally creates a public Gradio link.
    demo.launch(share=True)
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d1e28e3d30f64b129be39c8a6f3a2e88f042f8b24e0a1526e77cdd4c27b20f7
3
+ size 14292417
plot.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from lxml import etree
6
+
7
+ # ========= Crop black borders =========
8
def autocrop(image, tol=0):
    """Trim surrounding rows/columns whose pixels are all <= tol (black borders)."""
    # Per-pixel "content" mask; collapse the channel axis for color images.
    content = image > tol
    if content.ndim == 3:
        content = content.any(axis=2)
    if not content.any():
        # Fully dark image: nothing to crop, return unchanged.
        return image
    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))
    return image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
20
+
21
+ # ========= Stack horizontally =========
22
def stack_images_side_by_side(img1, img2):
    """Hstack two images after rescaling each to their common (max) height."""
    common_h = max(img1.shape[0], img2.shape[0])

    def _to_height(img):
        # Preserve aspect ratio while forcing the shared height.
        new_w = int(img.shape[1] * (common_h / img.shape[0]))
        return cv2.resize(img, (new_w, common_h))

    return np.hstack([_to_height(img1), _to_height(img2)])
32
+
33
+ # ========= Extract rectangle from JSON =========
34
def get_json_corners(json_file):
    """Return the rotated print-area corners (TL, TR, BR, BL) as int pixels."""
    with open(json_file, 'r') as f:
        spec = json.load(f)

    area = spec['printAreas'][0]
    width, height = area['width'], area['height']
    theta = np.radians(area['rotation'])

    # Rectangle center in page coordinates.
    center = np.array([area['position']['x'] + width / 2,
                       area['position']['y'] + height / 2])

    # Half-extents -> corner offsets in TL, TR, BR, BL order.
    hx, hy = width / 2, height / 2
    offsets = np.array([[-hx, -hy], [hx, -hy], [hx, hy], [-hx, hy]])

    # Rotate the offsets about the center by the area's rotation angle.
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    return (offsets @ rot.T + center).astype(int)
51
+
52
+ # ========= Extract polygon from XML =========
53
def extract_points_from_xml(xml_file):
    """Extract the four corner points (TL, TR, BR, BL) from a visual.xml file.

    Returns:
        np.ndarray of shape (4, 2), dtype float32, in pixel coordinates.

    Raises:
        ValueError: if the file contains no <transform> element.
        KeyError: if one of the four named corner points is missing.
    """
    # Stdlib parser instead of lxml: keeps this module consistent with
    # test.py/train.py, which already use xml.etree.ElementTree for the
    # same file format, and drops a third-party dependency.
    import xml.etree.ElementTree as ET

    root = ET.parse(xml_file).getroot()
    transform = root.find('.//transform')
    if transform is None:
        # Previously this crashed with an opaque AttributeError on None.
        raise ValueError(f"No <transform> element found in {xml_file}")

    points = {}
    for pt in transform.findall('.//point'):
        points[pt.attrib['type']] = (float(pt.attrib['x']), float(pt.attrib['y']))

    order = ['TopLeft', 'TopRight', 'BottomRight', 'BottomLeft']
    return np.array([points[p] for p in order], dtype=np.float32)
63
+
64
+ # ========= Draw correspondences and (optional) boxes =========
65
def draw_feature_matching(img1, pts1, img2, pts2, draw_boxes=True):
    """Draw point correspondences between two images of possibly different sizes.

    Both images are rescaled to a shared height and concatenated horizontally;
    each point pair gets a randomly colored dot-dot-line correspondence.
    """
    # Bring both images to the same height so no padding bars are needed.
    target_h = max(img1.shape[0], img2.shape[0])
    scale1 = target_h / img1.shape[0]
    scale2 = target_h / img2.shape[0]
    left = cv2.resize(img1, (int(img1.shape[1] * scale1), target_h))
    right = cv2.resize(img2, (int(img2.shape[1] * scale2), target_h))

    # Points must be rescaled with the same factors as their images.
    pts1_scaled = (pts1 * scale1).astype(int)
    pts2_scaled = (pts2 * scale2).astype(int)

    # x-offset of the right image inside the combined canvas.
    x_off = left.shape[1]
    canvas = np.concatenate([left, right], axis=1)

    # Optionally outline both quadrilaterals in green.
    if draw_boxes:
        cv2.polylines(canvas, [pts1_scaled.reshape((-1, 1, 2))], True, (0, 255, 0), 3)
        cv2.polylines(canvas, [pts2_scaled.reshape((-1, 1, 2)) + np.array([x_off, 0])], True, (0, 255, 0), 3)

    # One random color per correspondence: a dot on each side plus a line.
    for (x1, y1), (x2, y2) in zip(pts1_scaled, pts2_scaled):
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.circle(canvas, (x1, y1), 6, color, -1)
        cv2.circle(canvas, (x2 + x_off, y2), 6, color, -1)
        cv2.line(canvas, (x1, y1), (x2 + x_off, y2), color, 2)

    return canvas
104
+
105
+
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML / DL
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ numpy
6
+ opencv-python
7
+ shapely
8
+ pathlib
9
+ gradio
10
+ fastapi
11
+ starlette
12
+ pydantic
13
+ uvicorn
14
+ matplotlib
15
+ lxml
test.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import xml.etree.ElementTree as ET
3
+ from PIL import Image
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import json
7
+
8
+ from train import SimpleTransformer, flat_corners_from_mockup
9
+
10
+ # --------------------
11
+ # Utility: order 4 points (same as old)
12
+ # --------------------
13
def order_points_clockwise(pts):
    """Order 4 points as TL, TR, BR, BL using a y-split then an x-comparison."""
    pts = np.array(pts, dtype="float32")

    # Split into the two uppermost and two lowermost points.
    by_y = pts[np.argsort(pts[:, 1]), :]
    top, bottom = by_y[:2, :], by_y[2:, :]

    # Within each pair, the smaller x is the left-hand point.
    tl, tr = (top[0], top[1]) if top[0][0] < top[1][0] else (top[1], top[0])
    bl, br = (bottom[0], bottom[1]) if bottom[0][0] < bottom[1][0] else (bottom[1], bottom[0])

    return np.array([tl, tr, br, bl], dtype="float32")
31
+
32
+ # --------------------
33
+ # Utility: save XML prediction
34
+ # --------------------
35
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Serialize 4 predicted corner points to the visual.xml format.

    Points are first ordered TL, TR, BR, BL; the written XML mirrors the
    structure of the ground-truth files (effects / background / transforms /
    overlays / ruler).
    """
    tl, tr, br, bl = order_points_clockwise(pred_pts)

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(root, "background",
                  width=str(img_w), height=str(img_h),
                  color1="#C4CDE4", color2="", color3="")

    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(transforms_node, "transform",
                              type="FourPoint", offsetX="0", offsetY="0", offsetZ="0.0",
                              rotationX="0.0", rotationY="0.0", rotationZ="0.0",
                              name="Region", posCode="REGION", posName="Region",
                              posDef="0", techCode="EMBF03", techName="Embroidery Fixed",
                              techDef="0", areaWidth="100", areaHeight="100",
                              maxColors="12", defaultLogoSize="100", sizeX="100", sizeY="100")

    # Named corner points of the FourPoint transform.
    for ptype, (x, y) in (("TopLeft", tl), ("TopRight", tr),
                          ("BottomRight", br), ("BottomLeft", bl)):
        ET.SubElement(transform, "point",
                      type=ptype, x=str(float(x)), y=str(float(y)),
                      z="0.0", warp="0", warpShift="0")

    # Overlay polygon repeats the ordered corners.
    overlay = ET.SubElement(ET.SubElement(root, "overlays"), "overlay")
    for x, y in (tl, tr, br, bl):
        ET.SubElement(overlay, "point", type="Next",
                      x=str(float(x)), y=str(float(y)), z="0.0")

    # Ruler spans the TL -> BR diagonal.
    ET.SubElement(root, "ruler",
                  startX=str(tl[0]), startY=str(tl[1]),
                  stopX=str(br[0]), stopY=str(br[1]), value="100")

    ET.ElementTree(root).write(out_path, encoding="utf-8", xml_declaration=True)
71
+
72
+
73
+ # --------------------
74
+ # Predict one sample
75
+ # --------------------
76
def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Predict perspective corner points for one sample and save them as XML.

    Args:
        mockup_json: path to mockup.json with the flat print-area geometry.
        pers_img_path: path to the perspective (visual) image.
        model_ckpt: checkpoint path -- either a resume-style dict containing
            "model_state" or a bare state_dict.
        out_path: where to write the predicted visual.xml.

    Returns:
        The path of the written XML file.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The perspective image is only needed for its pixel size (to de-normalize).
    pers_img = Image.open(pers_img_path).convert("RGB")
    orig_w, orig_h = pers_img.size

    # Normalized flat corners (TL, TR, BR, BL) -> flattened (1, 8) model input.
    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(flat_norm.flatten(), dtype=torch.float32).unsqueeze(0).to(device)

    # Load model weights. NOTE(security): weights_only=False unpickles
    # arbitrary objects -- only load checkpoints from trusted sources.
    model = SimpleTransformer().to(device)
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    if "model_state" in state:  # resume-checkpoint format
        model.load_state_dict(state["model_state"])
    else:  # plain state_dict (final model)
        model.load_state_dict(state)
    model.eval()

    # Inference: (1, 8) normalized prediction -> (4, 2) points.
    with torch.no_grad():
        pred = model(flat_in).view(4, 2).cpu().numpy()

    # Convert normalized coords to pixel coords of the perspective image.
    pred_px = pred.copy()
    pred_px[:, 0] *= orig_w
    pred_px[:, 1] *= orig_h

    # Save prediction and hand the path back to the caller (used by app.py).
    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
    return out_path
109
+
110
+
111
+ # --------------------
112
+ # Example usage
113
+ # --------------------
114
if __name__ == "__main__":
    # Example usage on one hard-coded sample from the test split.
    demo_json = "Transformer/test/100847_TD/front/LAS02/mockup.json"
    demo_img = "Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg"
    demo_ckpt = "Transformer/transformer_model.pth"
    predict_one(demo_json, demo_img, demo_ckpt, out_path="Transformer/Prediction/pred3.xml")
train.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import glob
4
+ import xml.etree.ElementTree as ET
5
+ import numpy as np
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import torchvision.transforms as T
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from shapely.geometry import Polygon
13
+ from pathlib import Path
14
+
15
+ # =====================
16
+ # Data Utils
17
+ # # =====================
18
+
19
+ import numpy as np
20
+ import json
21
+
22
def flat_corners_from_mockup(mockup_path):
    """Compute the print-area corners (TL, TR, BR, BL) from mockup.json.

    Returns:
        (pixels, normalized): two (4, 2) float32 arrays -- corner coordinates
        in background pixels, and the same corners normalized to [0, 1] by
        the background width/height.
    """
    spec = json.loads(Path(mockup_path).read_text())
    bg_w = spec["background"]["width"]
    bg_h = spec["background"]["height"]

    area = spec["printAreas"][0]
    w, h = area["width"], area["height"]
    cx = area["position"]["x"] + w / 2.0
    cy = area["position"]["y"] + h / 2.0

    # Axis-aligned half-extent offsets in TL, TR, BR, BL order.
    hx, hy = w / 2.0, h / 2.0
    offsets = np.array([[-hx, -hy], [hx, -hy], [hx, hy], [-hx, hy]], dtype=np.float32)

    # Rotate the offsets about the center by the area's rotation (degrees).
    theta = np.deg2rad(area["rotation"])
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], dtype=np.float32)
    pixels = offsets @ rot_mat.T + np.array([cx, cy], dtype=np.float32)

    # Normalize each axis independently by the background size.
    normalized = np.empty_like(pixels)
    normalized[:, 0] = pixels[:, 0] / bg_w
    normalized[:, 1] = pixels[:, 1] / bg_h
    return pixels.astype(np.float32), normalized.astype(np.float32)
48
+
49
def parse_xml_points(xml_path):
    """Read the FourPoint transform corners from a visual.xml.

    Returns:
        (4, 2) float32 array of (TL, TR, BR, BL) coordinates, normalized by
        the background width/height. Fewer rows if some corners are absent.
    """
    root = ET.parse(xml_path).getroot()
    bg = root.find("background")
    bg_w = int(bg.get("width"))
    bg_h = int(bg.get("height"))

    corner_order = ["TopLeft", "TopRight", "BottomRight", "BottomLeft"]
    points = []
    for transform in root.findall(".//transform"):
        if transform.get("type") == "FourPoint":
            for corner in corner_order:
                node = transform.find(f".//point[@type='{corner}']")
                if node is not None:
                    points.append([float(node.get("x")) / bg_w,
                                   float(node.get("y")) / bg_h])
        break  # only the first transform is considered

    return np.array(points, dtype=np.float32)
72
+
73
class KP4Dataset(Dataset):
    """Keypoint dataset pairing flat/perspective images with 4-corner labels.

    Walks `root` recursively for `*_visual*.xml` label files and collects,
    per sample: the perspective image, its XML, the flat background image,
    and the mockup.json describing the flat print area.
    """

    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size
        self.samples = []  # list of (img_file, xml_file, flat_img, json_file)

        # Transform pipeline (resize + tensor + normalize to roughly [-1, 1])
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # Walk recursively; every *_visual*.xml anchors one potential sample.
        for xml_file in self.root.rglob("*.xml"):
            if "_visual" not in xml_file.stem:
                continue

            # Find the matching perspective image (same stem, image extension).
            base = xml_file.stem
            img_file = None
            for ext in [".png", ".jpg", ".jpeg"]:
                cand = xml_file.with_suffix(ext)
                if cand.exists():
                    img_file = cand
                    break
            if img_file is None:
                continue

            # Flat image (background): *_background sibling, .png then .jpg.
            flat_img = xml_file.parent / (base.replace("_visual", "_background") + ".png")
            if not flat_img.exists():
                flat_img = xml_file.parent / (base.replace("_visual", "_background") + ".jpg")
            if not flat_img.exists():
                continue

            # mockup.json must sit in the same directory as the XML.
            json_file = xml_file.parent / "mockup.json"
            if not json_file.exists():
                continue

            self.samples.append((img_file, xml_file, flat_img, json_file))

        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    def __len__(self):
        # Number of discovered (image, xml, flat, json) sample tuples.
        return len(self.samples)

    def __getitem__(self, idx):
        img_file, xml_file, flat_img, json_file = self.samples[idx]

        img = self.transform(Image.open(img_file).convert("RGB"))
        flat = self.transform(Image.open(flat_img).convert("RGB"))

        # Flat (mockup) corner points, normalized to [0, 1].
        _, flat_norm = flat_corners_from_mockup(json_file)
        flat_pts = torch.tensor(flat_norm, dtype=torch.float32)

        # Perspective (ground-truth) corner points, normalized to [0, 1].
        persp_norm = parse_xml_points(xml_file)
        persp_pts = torch.tensor(persp_norm, dtype=torch.float32)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": flat_pts,
            "persp_pts": persp_pts,
            "xml": str(xml_file),
            "json": str(json_file),
        }
144
+
145
+ # =====================
146
+ # Model
147
+ # =====================
148
class SimpleTransformer(nn.Module):
    """Maps 4 flat corners (8 coords) to 4 perspective corners via a tiny transformer."""

    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        # Project the 8 input coordinates (4 corners * 2) into the model width.
        self.fc_in = nn.Linear(8, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Project back to 8 output coordinates ((x, y) * 4 corners).
        self.fc_out = nn.Linear(d_model, 8)

    def forward(self, x):
        # (B, 8) -> (B, 1, d_model): treat the sample as a single-token sequence.
        hidden = self.fc_in(x).unsqueeze(1)
        hidden = self.transformer(hidden)
        # (B, 1, d_model) -> (B, 8)
        return self.fc_out(hidden).squeeze(1)
161
+
162
+
163
+ # =====================
164
+ # Metrics
165
+ # =====================
166
def mse_loss(pred, gt):
    """Mean squared error between prediction and ground truth."""
    diff = pred - gt
    return (diff * diff).mean()
168
+
169
def mean_corner_error(pred, gt, img_w, img_h):
    """Mean Euclidean corner distance in pixels, given normalized coordinates."""
    # De-normalize both point sets before measuring distance.
    pred_px = pred * torch.tensor([img_w, img_h], device=pred.device)
    gt_px = gt * torch.tensor([img_w, img_h], device=gt.device)
    return torch.norm(pred_px - gt_px, dim=-1).mean().item()
174
+
175
def iou_quad(pred, gt):
    """Intersection-over-union of two quadrilaterals (0.0 on invalid/empty)."""
    quad_a = Polygon(pred.tolist())
    quad_b = Polygon(gt.tolist())
    # Self-intersecting (bow-tie) quads are invalid polygons in shapely.
    if not (quad_a.is_valid and quad_b.is_valid):
        return 0.0
    union_area = quad_a.union(quad_b).area
    if union_area <= 0:
        return 0.0
    return quad_a.intersection(quad_b).area / union_area
183
+
184
+
185
+ # =====================
186
+ # Training
187
+ # =====================
188
def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None
):
    """Train SimpleTransformer to map flat corners to perspective corners.

    Args:
        train_root: dataset root for training samples (see KP4Dataset).
        test_root: dataset root for validation samples.
        epochs: total number of epochs (absolute, including resumed ones).
        batch_size: training batch size; validation always uses batch size 1.
        lr: Adam learning rate.
        img_size: square resize applied to images by KP4Dataset.
        save_dir: where epoch/best/final checkpoints are written.
        resume_path: optional checkpoint dict to resume from.

    Returns:
        The trained model (last-epoch weights, not necessarily the best).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0
    best_iou = -1.0
    best_model_path = os.path.join(save_dir, "best_model.pth")

    os.makedirs(save_dir, exist_ok=True)

    # Resume training from a checkpoint, if provided.
    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"]
        # Carry best-so-far IoU forward so a resumed run cannot overwrite a
        # better earlier best_model.pth with a worse one.
        best_iou = checkpoint.get("best_iou", checkpoint.get("val_iou", -1.0))
        print(f"Resumed from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs):
        # -------- Training --------
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # (B, 4, 2) -> (B, 8): flattened corner input/target.
            flat_pts_in = flat_pts.view(flat_pts.size(0), -1)
            target = persp_pts.view(persp_pts.size(0), -1)

            pred = model(flat_pts_in)
            loss = mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # -------- Validation --------
        model.eval()
        mse_all, ce_all, iou_all = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                flat_pts = batch["flat_pts"].to(device)
                persp_pts = batch["persp_pts"].to(device)

                flat_pts_in = flat_pts.view(1, -1)
                target = persp_pts.view(1, -1)

                pred = model(flat_pts_in)
                mse_all.append(mse_loss(pred, target).item())

                pred_quad = pred.view(4, 2).cpu()
                gt_quad = persp_pts.view(4, 2).cpu()

                # persp_img is (B, C, H, W): width is dim 3, height is dim 2.
                # BUGFIX: dims 2/1 were used before, taking H as the width and
                # the channel count (3) as the height, which corrupted the
                # reported pixel corner error.
                w = batch["persp_img"].shape[3]
                h = batch["persp_img"].shape[2]
                ce_all.append(mean_corner_error(pred_quad, gt_quad, w, h))
                iou_all.append(iou_quad(pred_quad, gt_quad))

        val_mse = np.mean(mse_all)
        val_ce = np.mean(ce_all)
        val_iou = np.mean(iou_all)
        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")

        # -------- Periodic epoch checkpoint (every 100 epochs) --------
        if (epoch + 1) % 100 == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

        # -------- Save best model (by validation IoU) --------
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    # Save final weights as a plain state_dict (matches predict_one's loader).
    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model
300
+
301
+
302
+ # =====================
303
+ # Main
304
+ # =====================
305
if __name__ == "__main__":
    # Training configuration for the default run.
    config = dict(
        train_root="Transformer/train",
        test_root="Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None,
    )
    model = train_model(**config)