Upload 6 files
Browse files- app.py +74 -0
- best_model.pth +3 -0
- plot.py +105 -0
- requirements.txt +15 -0
- test.py +118 -0
- train.py +314 -0
app.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
import cv2
|
| 5 |
+
from test import predict_one
|
| 6 |
+
from plot import (
|
| 7 |
+
autocrop, get_json_corners, extract_points_from_xml,
|
| 8 |
+
draw_feature_matching, stack_images_side_by_side
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
# Hard-coded model checkpoint path
|
| 12 |
+
MODEL_CKPT = "best_model.pth"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# --------------------
|
| 16 |
+
# Pipeline
|
| 17 |
+
# --------------------
|
| 18 |
+
def run_pipeline(flat_img, pers_img, mockup_json, xml_gt):
    """Run keypoint prediction for one sample and build a comparison image.

    Args:
        flat_img: path to the flat (mockup background) image.
        pers_img: path to the perspective (visual) image.
        mockup_json: path to mockup.json describing the print area.
        xml_gt: path to the ground-truth visual XML.

    Returns:
        (result_path, xml_pred_path): the stacked comparison PNG and the
        predicted XML, both inside a fresh temporary directory.

    Raises:
        FileNotFoundError: if either input image cannot be read.
    """
    # Temp dir for prediction + result. Not deleted here: the returned
    # paths must stay readable for Gradio to serve them.
    tmpdir = tempfile.mkdtemp()
    xml_pred_path = os.path.join(tmpdir, "pred.xml")
    result_path = os.path.join(tmpdir, "result.png")

    # Run prediction (writes xml_pred_path)
    predict_one(mockup_json, pers_img, MODEL_CKPT, out_path=xml_pred_path)

    # --- Visualization ---
    # cv2.imread returns None on failure; fail loudly instead of letting
    # cvtColor raise a cryptic assertion.
    flat_bgr = cv2.imread(flat_img)
    if flat_bgr is None:
        raise FileNotFoundError(f"Could not read flat image: {flat_img}")
    pers_bgr = cv2.imread(pers_img)
    if pers_bgr is None:
        raise FileNotFoundError(f"Could not read perspective image: {pers_img}")

    img_json = autocrop(cv2.cvtColor(flat_bgr, cv2.COLOR_BGR2RGB))
    img_xml = autocrop(cv2.cvtColor(pers_bgr, cv2.COLOR_BGR2RGB))

    json_pts = get_json_corners(mockup_json)
    gt_pts = extract_points_from_xml(xml_gt)
    pred_pts = extract_points_from_xml(xml_pred_path)

    # Left panel: JSON vs ground truth. Right panel: JSON vs prediction.
    match_json_gt = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), gt_pts, draw_boxes=True)
    match_json_pred = draw_feature_matching(img_json.copy(), json_pts, img_xml.copy(), pred_pts, draw_boxes=True)

    stacked = stack_images_side_by_side(match_json_gt, match_json_pred)

    # Save result (convert back to BGR for cv2.imwrite)
    cv2.imwrite(result_path, cv2.cvtColor(stacked, cv2.COLOR_RGB2BGR))

    return result_path, xml_pred_path
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --------------------
|
| 47 |
+
# Gradio UI
|
| 48 |
+
# --------------------
|
| 49 |
+
with gr.Blocks() as demo:
    gr.Markdown("## Mesh Key Point Transformer Demo")

    # Inputs are handed to run_pipeline as file paths (type="filepath").
    with gr.Row():
        flat_in = gr.Image(type="filepath", label="Flat Image", width=300, height=300)
        pers_in = gr.Image(type="filepath", label="Perspective Image", width=300, height=300)

    with gr.Row():
        mockup_json_in = gr.File(type="filepath", label="Mockup JSON")
        xml_gt_in = gr.File(type="filepath", label="Ground Truth XML")

    run_btn = gr.Button("Run Prediction + Visualization")

    with gr.Row():
        out_img = gr.Image(type="filepath", label="Comparison Output", width=800, height=600)
        out_xml = gr.File(type="filepath", label="Predicted XML")

    # Wire the button: inputs/outputs match run_pipeline's parameters and
    # its (result_path, xml_pred_path) return values positionally.
    run_btn.click(
        fn=run_pipeline,
        inputs=[flat_in, pers_in, mockup_json_in, xml_gt_in],
        outputs=[out_img, out_xml]
    )
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    # share=True additionally exposes a public Gradio link beyond the
    # local server.
    demo.launch(share=True)
|
best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d1e28e3d30f64b129be39c8a6f3a2e88f042f8b24e0a1526e77cdd4c27b20f7
|
| 3 |
+
size 14292417
|
plot.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from lxml import etree
|
| 6 |
+
|
| 7 |
+
# ========= Crop black borders =========
|
| 8 |
+
def autocrop(image, tol=0):
    """Trim border rows/columns whose values are all <= tol.

    Handles 2-D grayscale and 3-D color arrays. If every pixel is
    <= tol the input is returned unchanged.
    """
    bright = image > tol
    if bright.ndim == 3:
        # A color pixel counts as non-background if any channel exceeds tol.
        bright = bright.any(axis=2)
    if not bright.any():
        return image
    rows = np.flatnonzero(bright.any(axis=1))
    cols = np.flatnonzero(bright.any(axis=0))
    return image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
|
| 20 |
+
|
| 21 |
+
# ========= Stack horizontally =========
|
| 22 |
+
def stack_images_side_by_side(img1, img2):
    """Scale both images to the taller one's height and join them horizontally."""
    common_h = max(img1.shape[0], img2.shape[0])

    def _fit(img):
        # cv2.resize takes (width, height); preserve the aspect ratio.
        new_w = int(img.shape[1] * (common_h / img.shape[0]))
        return cv2.resize(img, (new_w, common_h))

    return np.hstack([_fit(img1), _fit(img2)])
|
| 32 |
+
|
| 33 |
+
# ========= Extract rectangle from JSON =========
|
| 34 |
+
def get_json_corners(json_file):
    """Return the print area's 4 rotated corners from mockup.json.

    Corners are ordered TL, TR, BR, BL (pre-rotation) and rotated about
    the area's center by its `rotation` angle (degrees). Returned as an
    int array of shape (4, 2) in background-pixel coordinates.
    """
    with open(json_file, 'r') as fh:
        area = json.load(fh)['printAreas'][0]

    w, h = area['width'], area['height']
    angle = area['rotation']
    cx = area['position']['x'] + w / 2
    cy = area['position']['y'] + h / 2

    half_w, half_h = w / 2, h / 2
    local = np.array([
        [-half_w, -half_h],
        [half_w, -half_h],
        [half_w, half_h],
        [-half_w, half_h],
    ])
    theta = np.radians(angle)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    world = local @ rot.T + np.array([cx, cy])
    return world.astype(int)
|
| 51 |
+
|
| 52 |
+
# ========= Extract polygon from XML =========
|
| 53 |
+
def extract_points_from_xml(xml_file):
    """Extract the 4 corner points from a visual.xml file.

    Reads the first <transform> element's <point> children and returns
    them ordered TL, TR, BR, BL as a float32 array of shape (4, 2).

    Raises:
        KeyError: if any of the four expected point types is missing.
    """
    # Use the stdlib parser (consistent with train.py / test.py) instead
    # of lxml; these are trusted, locally generated documents, and this
    # drops a third-party dependency.
    import xml.etree.ElementTree as ET

    root = ET.parse(xml_file).getroot()
    transform = root.find('.//transform')
    points = {}
    for pt in transform.findall('.//point'):
        points[pt.attrib['type']] = (float(pt.attrib['x']), float(pt.attrib['y']))
    order = ['TopLeft', 'TopRight', 'BottomRight', 'BottomLeft']
    return np.array([points[p] for p in order], dtype=np.float32)
|
| 63 |
+
|
| 64 |
+
# ========= Draw correspondences and (optional) boxes =========
|
| 65 |
+
def draw_feature_matching(img1, pts1, img2, pts2, draw_boxes=True):
    """
    Draws feature correspondences between two images, handling different sizes.

    Both images are resized to a common height and concatenated
    horizontally; each point pair is joined by a randomly coloured line.
    Points are expected in pixel coordinates of the *original*
    (pre-resize) images. Returns the combined annotated image.
    """
    # Resize images to a common height to avoid black padding bars
    target_h = max(img1.shape[0], img2.shape[0])

    # Calculate scaling factors and new widths
    scale1 = target_h / img1.shape[0]
    w1_new = int(img1.shape[1] * scale1)

    scale2 = target_h / img2.shape[0]
    w2_new = int(img2.shape[1] * scale2)

    # Resize images
    img1_resized = cv2.resize(img1, (w1_new, target_h))
    img2_resized = cv2.resize(img2, (w2_new, target_h))

    # Scale points to match the resized images
    pts1_scaled = (pts1 * scale1).astype(int)
    pts2_scaled = (pts2 * scale2).astype(int)

    # Create the combined image canvas
    h, w1, w2 = target_h, w1_new, w2_new
    new_img = np.concatenate([img1_resized, img2_resized], axis=1)

    # Optional: draw the two quads. Right-image points are shifted by w1
    # (left image width) into the combined canvas.
    if draw_boxes:
        cv2.polylines(new_img, [pts1_scaled.reshape((-1,1,2))], True, (0,255,0), 3)
        cv2.polylines(new_img, [pts2_scaled.reshape((-1,1,2)) + np.array([w1,0])], True, (0,255,0), 3)

    # Draw correspondences, one random colour per point pair.
    # NOTE(review): colours come from the global np.random state, so the
    # output is not deterministic unless the caller seeds it.
    for (x1, y1), (x2, y2) in zip(pts1_scaled, pts2_scaled):
        color = tuple(np.random.randint(0, 255, 3).tolist())
        cv2.circle(new_img, (x1, y1), 6, color, -1)
        cv2.circle(new_img, (x2 + w1, y2), 6, color, -1)
        cv2.line(new_img, (x1, y1), (x2 + w1, y2), color, 2)

    return new_img
|
| 104 |
+
|
| 105 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core ML / DL
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
Pillow
|
| 5 |
+
numpy
|
| 6 |
+
opencv-python
|
| 7 |
+
shapely
|
| 8 |
+
pathlib
|
| 9 |
+
gradio
|
| 10 |
+
fastapi
|
| 11 |
+
starlette
|
| 12 |
+
pydantic
|
| 13 |
+
uvicorn
|
| 14 |
+
matplotlib
|
| 15 |
+
lxml
|
test.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import xml.etree.ElementTree as ET
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
from train import SimpleTransformer, flat_corners_from_mockup
|
| 9 |
+
|
| 10 |
+
# --------------------
|
| 11 |
+
# Utility: order 4 points (same as old)
|
| 12 |
+
# --------------------
|
| 13 |
+
def order_points_clockwise(pts):
    """Order 4 points clockwise as TL, TR, BR, BL.

    The two points with the smallest y form the top edge, the other two
    the bottom edge; within each edge the point with strictly smaller x
    is taken as the left one.
    """
    pts = np.array(pts, dtype="float32")
    by_y = pts[np.argsort(pts[:, 1]), :]

    def left_right(a, b):
        # Mirrors the original tie-break: on equal x, `a` becomes the
        # right-hand point.
        return (a, b) if a[0] < b[0] else (b, a)

    tl, tr = left_right(by_y[0], by_y[1])
    bl, br = left_right(by_y[2], by_y[3])

    return np.array([tl, tr, br, bl], dtype="float32")
|
| 31 |
+
|
| 32 |
+
# --------------------
|
| 33 |
+
# Utility: save XML prediction
|
| 34 |
+
# --------------------
|
| 35 |
+
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Write a predicted quad to a visual.xml-style file.

    Points are ordered TL, TR, BR, BL first; the output mirrors the
    ground-truth visualization layout (effects, background, a FourPoint
    transform with four typed points, overlay points, and a ruler).
    """
    ordered = order_points_clockwise(pred_pts)
    TL, TR, BR, BL = ordered

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(
        root, "background",
        width=str(img_w), height=str(img_h),
        color1="#C4CDE4", color2="", color3="",
    )

    # Fixed metadata copied from the ground-truth file format.
    transform_attrs = {
        "type": "FourPoint", "offsetX": "0", "offsetY": "0", "offsetZ": "0.0",
        "rotationX": "0.0", "rotationY": "0.0", "rotationZ": "0.0",
        "name": "Region", "posCode": "REGION", "posName": "Region",
        "posDef": "0", "techCode": "EMBF03", "techName": "Embroidery Fixed",
        "techDef": "0", "areaWidth": "100", "areaHeight": "100",
        "maxColors": "12", "defaultLogoSize": "100", "sizeX": "100", "sizeY": "100",
    }
    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(transforms_node, "transform", transform_attrs)

    corner_types = ("TopLeft", "TopRight", "BottomRight", "BottomLeft")
    for ptype, (x, y) in zip(corner_types, ordered):
        ET.SubElement(transform, "point",
                      type=ptype, x=str(float(x)), y=str(float(y)),
                      z="0.0", warp="0", warpShift="0")

    overlay = ET.SubElement(ET.SubElement(root, "overlays"), "overlay")
    for (x, y) in ordered:
        ET.SubElement(overlay, "point", type="Next", x=str(float(x)), y=str(float(y)), z="0.0")

    ET.SubElement(root, "ruler",
                  startX=str(TL[0]), startY=str(TL[1]),
                  stopX=str(BR[0]), stopY=str(BR[1]), value="100")

    ET.ElementTree(root).write(out_path, encoding="utf-8", xml_declaration=True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# --------------------
|
| 74 |
+
# Predict one sample
|
| 75 |
+
# --------------------
|
| 76 |
+
def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Predict the perspective quad for one sample and write it as XML.

    Args:
        mockup_json: path to mockup.json (source of the flat corners).
        pers_img_path: path to the perspective image; only its size is
            used, to de-normalize the prediction into pixels.
        model_ckpt: path to either a resume checkpoint (dict containing
            "model_state") or a plain state_dict file.
        out_path: destination path for the prediction XML.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load perspective image (its size converts normalized -> pixel coords)
    pers_img = Image.open(pers_img_path).convert("RGB")
    orig_w, orig_h = pers_img.size

    # Load flat points from mockup.json (normalized to [0,1])
    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(flat_norm.flatten(), dtype=torch.float32).unsqueeze(0).to(device)  # (1,8)

    # Load model
    model = SimpleTransformer().to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    if "model_state" in state:  # resume checkpoint format
        model.load_state_dict(state["model_state"])
    else:  # final model (plain state_dict)
        model.load_state_dict(state)
    model.eval()

    # Predict
    with torch.no_grad():
        pred = model(flat_in)  # (1,8)
        pred = pred.view(4, 2).cpu().numpy()

    # Convert normalized coords to pixel coords
    pred_px = pred.copy()
    pred_px[:, 0] *= orig_w
    pred_px[:, 1] *= orig_h

    # Save prediction
    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# --------------------
|
| 112 |
+
# Example usage
|
| 113 |
+
# --------------------
|
| 114 |
+
if __name__ == "__main__":
    # Hard-coded demo paths; point these at a local sample before running
    # this module directly.
    mockup_json = "Transformer/test/100847_TD/front/LAS02/mockup.json"
    pers_img = "Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg"
    model_ckpt = "Transformer/transformer_model.pth"
    predict_one(mockup_json, pers_img, model_ckpt, out_path="Transformer/Prediction/pred3.xml")
|
train.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import glob
|
| 4 |
+
import xml.etree.ElementTree as ET
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import torchvision.transforms as T
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from shapely.geometry import Polygon
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# =====================
|
| 16 |
+
# Data Utils
|
| 17 |
+
# # =====================
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import json
|
| 21 |
+
|
| 22 |
+
def flat_corners_from_mockup(mockup_path):
    """
    Return the print area's 4 corners from mockup.json.

    Corners are ordered TL, TR, BR, BL. Returns a tuple
    (pixel_coords, normalized_coords); the latter is scaled to [0,1] by
    the background width/height. Both are float32 arrays of shape (4, 2).
    """
    data = json.loads(Path(mockup_path).read_text())
    bg = data["background"]
    area = data["printAreas"][0]

    w, h = area["width"], area["height"]
    cx = area["position"]["x"] + w / 2.0
    cy = area["position"]["y"] + h / 2.0

    # Axis-aligned corners about the origin (TL, TR, BR, BL)...
    dx, dy = w / 2.0, h / 2.0
    local = np.array([[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]], dtype=np.float32)

    # ...rotated by the area's angle and shifted to its center.
    theta = np.deg2rad(area["rotation"])
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]], dtype=np.float32)
    pixels = local @ R.T + np.array([cx, cy], dtype=np.float32)

    # Normalize by the background size.
    scale = np.array([bg["width"], bg["height"]], dtype=np.float32)
    normalized = pixels / scale

    return pixels.astype(np.float32), normalized.astype(np.float32)
|
| 48 |
+
|
| 49 |
+
def parse_xml_points(xml_path):
    """
    Parse the 4 corner points of the first FourPoint transform in an XML.

    Coordinates are normalized by the background width/height. Returns a
    float32 array of shape (4, 2) ordered TL, TR, BR, BL (empty when no
    FourPoint transform exists).
    """
    root = ET.parse(xml_path).getroot()

    bg = root.find("background")
    bg_w = int(bg.get("width"))
    bg_h = int(bg.get("height"))

    corner_order = ("TopLeft", "TopRight", "BottomRight", "BottomLeft")
    points = []
    for transform in root.findall(".//transform"):
        if transform.get("type") != "FourPoint":
            continue
        for corner in corner_order:
            node = transform.find(f".//point[@type='{corner}']")
            if node is not None:
                points.append([float(node.get("x")) / bg_w,
                               float(node.get("y")) / bg_h])
        break  # only first transform

    return np.array(points, dtype=np.float32)  # (4,2)
|
| 72 |
+
|
| 73 |
+
class KP4Dataset(Dataset):
    """Paired flat/perspective keypoint dataset.

    Recursively scans `root` for `*_visual*.xml` annotation files and
    keeps only those that also have a same-stem perspective image, a
    `*_background.(png|jpg)` flat image, and a sibling mockup.json.
    Samples missing any piece are silently skipped.
    """
    def __init__(self, root, img_size=512):
        self.root = Path(root)
        self.img_size = img_size
        self.samples = []

        # Transform pipeline (resize + tensor + normalize to roughly [-1, 1])
        self.transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # Walk recursively collecting annotation XMLs
        for xml_file in self.root.rglob("*.xml"):
            if "_visual" not in xml_file.stem:
                continue

            # Find matching perspective image (same stem, image extension)
            base = xml_file.stem
            img_file = None
            for ext in [".png", ".jpg", ".jpeg"]:
                cand = xml_file.with_suffix(ext)
                if cand.exists():
                    img_file = cand
                    break
            if img_file is None:
                continue

            # Flat image (background): try .png first, then fall back to .jpg
            flat_img = xml_file.parent / (base.replace("_visual", "_background") + ".png")
            if not flat_img.exists():
                flat_img = xml_file.parent / (base.replace("_visual", "_background") + ".jpg")
                if not flat_img.exists():
                    continue

            # mockup.json (flat-corner source) must sit next to the XML
            json_file = xml_file.parent / "mockup.json"
            if not json_file.exists():
                continue

            self.samples.append((img_file, xml_file, flat_img, json_file))

        if not self.samples:
            raise RuntimeError(f"No valid samples found under {root}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_file, xml_file, flat_img, json_file = self.samples[idx]

        img = self.transform(Image.open(img_file).convert("RGB"))
        flat = self.transform(Image.open(flat_img).convert("RGB"))

        # flat points (normalized TL, TR, BR, BL from mockup.json)
        _, flat_norm = flat_corners_from_mockup(json_file)
        flat_pts = torch.tensor(flat_norm, dtype=torch.float32)

        # perspective points (normalized, from the visual XML)
        persp_norm = parse_xml_points(xml_file)
        persp_pts = torch.tensor(persp_norm, dtype=torch.float32)

        return {
            "persp_img": img,
            "flat_img": flat,
            "flat_pts": flat_pts,
            "persp_pts": persp_pts,
            "xml": str(xml_file),
            "json": str(json_file),
        }
|
| 144 |
+
|
| 145 |
+
# =====================
|
| 146 |
+
# Model
|
| 147 |
+
# =====================
|
| 148 |
+
class SimpleTransformer(nn.Module):
    """Tiny transformer mapping 4 flat corners to 4 perspective corners.

    Input and output are flattened (x, y) pairs of shape (B, 8). The
    8-vector is embedded as a single token, run through a transformer
    encoder, and projected back to 8 coordinates.
    """
    def __init__(self, d_model=128, nhead=4, num_layers=2):
        super().__init__()
        self.fc_in = nn.Linear(8, d_model)  # 4 corners * 2 coords -> embedding
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 8)  # predict 4 corners (x,y)*4

    def forward(self, x):
        # (B, 8) -> (B, 1, d_model): the whole quad is a single token
        x = self.fc_in(x).unsqueeze(1)
        x = self.transformer(x)
        # (B, 1, d_model) -> (B, 8)
        x = self.fc_out(x).squeeze(1)
        return x
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# =====================
|
| 164 |
+
# Metrics
|
| 165 |
+
# =====================
|
| 166 |
+
def mse_loss(pred, gt):
    """Mean squared error over all elements of `pred` versus `gt`."""
    diff = pred - gt
    return (diff * diff).mean()
|
| 168 |
+
|
| 169 |
+
def mean_corner_error(pred, gt, img_w, img_h):
    """Mean Euclidean corner distance in pixels.

    `pred` and `gt` carry normalized (x, y) corners in their last dim;
    both are scaled by (img_w, img_h) before measuring distances.
    """
    dims = [img_w, img_h]
    delta = pred * torch.tensor(dims, device=pred.device) - gt * torch.tensor(dims, device=gt.device)
    return delta.norm(dim=-1).mean().item()
|
| 174 |
+
|
| 175 |
+
def iou_quad(pred, gt):
    """Intersection-over-union of two quads given as (4, 2) point tensors.

    Returns 0.0 when either polygon is invalid (self-intersecting) or the
    union area is zero.
    """
    quad_a = Polygon(pred.tolist())
    quad_b = Polygon(gt.tolist())
    if not (quad_a.is_valid and quad_b.is_valid):
        return 0.0
    union_area = quad_a.union(quad_b).area
    if union_area <= 0:
        return 0.0
    return quad_a.intersection(quad_b).area / union_area
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# =====================
|
| 186 |
+
# Training
|
| 187 |
+
# =====================
|
| 188 |
+
def train_model(
    train_root,
    test_root,
    epochs=20,
    batch_size=8,
    lr=1e-3,
    img_size=256,
    save_dir="Transformer/checkpoints",
    resume_path=None
):
    """Train SimpleTransformer on flat -> perspective corner regression.

    Args:
        train_root / test_root: dataset roots scanned by KP4Dataset.
        epochs: total epoch count (a resumed run continues toward this total).
        batch_size: training batch size (validation always uses batch 1).
        lr: Adam learning rate.
        img_size: square resize applied to the images by KP4Dataset.
        save_dir: directory for periodic checkpoints, best and final weights.
        resume_path: optional checkpoint (model/optimizer/epoch) to resume from.

    Returns:
        The trained model. Best (by val IoU) and final weights are saved
        under save_dir.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_ds = KP4Dataset(train_root, img_size=img_size)
    val_ds = KP4Dataset(test_root, img_size=img_size)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    model = SimpleTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    start_epoch = 0

    os.makedirs(save_dir, exist_ok=True)

    # Resume Training
    if resume_path is not None and os.path.exists(resume_path):
        print(f"Loading checkpoint from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"]
        print(f"Resumed from epoch {start_epoch}")

    # ===================== Track Best Model =====================
    # NOTE(review): best_iou restarts at -1 on resume, so an early epoch of
    # a resumed run can overwrite a better pre-resume best_model.pth.
    best_iou = -1.0
    best_model_path = os.path.join(save_dir, "best_model.pth")

    for epoch in range(start_epoch, epochs):
        # -------- Training --------
        model.train()
        total_loss = 0
        for batch in train_loader:
            flat_pts = batch["flat_pts"].to(device)
            persp_pts = batch["persp_pts"].to(device)

            # Model consumes/produces flattened (B, 8) corner vectors.
            flat_pts_in = flat_pts.view(flat_pts.size(0), -1)
            target = persp_pts.view(persp_pts.size(0), -1)

            pred = model(flat_pts_in)
            loss = mse_loss(pred, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {total_loss/len(train_loader):.6f}")

        # -------- Validation --------
        model.eval()
        mse_all, ce_all, iou_all = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                flat_pts = batch["flat_pts"].to(device)
                persp_pts = batch["persp_pts"].to(device)

                flat_pts_in = flat_pts.view(1, -1)
                target = persp_pts.view(1, -1)

                pred = model(flat_pts_in)
                mse_all.append(mse_loss(pred, target).item())

                pred_quad = pred.view(4, 2).cpu()
                gt_quad = persp_pts.view(4, 2).cpu()

                # persp_img is (B, C, H, W): width is dim -1, height dim -2.
                # Bug fix: previously read shape[2]/shape[1], i.e. height and
                # the CHANNEL COUNT, scaling the corner error wrongly.
                w, h = batch["persp_img"].shape[-1], batch["persp_img"].shape[-2]
                ce_all.append(mean_corner_error(pred_quad, gt_quad, w, h))
                iou_all.append(iou_quad(pred_quad, gt_quad))

        val_mse = np.mean(mse_all)
        val_ce = np.mean(ce_all)
        val_iou = np.mean(iou_all)

        print(f" Val MSE: {val_mse:.6f}, CornerErr(px): {val_ce:.2f}, IoU: {val_iou:.3f}")
        if (epoch + 1) % 100 == 0:
            # -------- Save Epoch Checkpoint --------
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_iou": val_iou,
            }, checkpoint_path)
            print(f"Checkpoint saved: {checkpoint_path}")

        # -------- Save Best Model --------
        if val_iou > best_iou:
            best_iou = val_iou
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "best_iou": best_iou,
            }, best_model_path)
            print(f"Best model updated at epoch {epoch+1} (IoU={val_iou:.3f})")

    # Save final model weights
    final_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_path)
    print(f"Final model saved at {final_path}")
    print(f"Best model saved at {best_model_path} with IoU={best_iou:.3f}")

    return model
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# =====================
|
| 303 |
+
# Main
|
| 304 |
+
# =====================
|
| 305 |
+
if __name__ == "__main__":
    # Long run: 3000 epochs over the Transformer/ data layout. Periodic
    # checkpoints are written every 100 epochs plus a running
    # best_model.pth selected by validation IoU.
    model = train_model(
        train_root="Transformer/train",
        test_root="Transformer/test",
        epochs=3000,
        batch_size=4,
        lr=1e-3,
        img_size=256,
        resume_path=None
    )
|