saim1309's picture
Upload 6 files
ec7f44e verified
raw
history blame
4.26 kB
import torch
import xml.etree.ElementTree as ET
from PIL import Image
import numpy as np
from pathlib import Path
import json
from train import SimpleTransformer, flat_corners_from_mockup
# --------------------
# Utility: order 4 points (same as old)
# --------------------
def order_points_clockwise(pts):
    """Return four 2-D points as [top-left, top-right, bottom-right, bottom-left].

    The points are sorted by their y coordinate; the two with the smallest y
    form the top pair and the remaining two the bottom pair. Inside each pair
    the point with the strictly smaller x is taken as the left one.

    Args:
        pts: Four (x, y) pairs, as any array-like convertible to a float32
            array with two columns.

    Returns:
        A float32 ``numpy`` array of shape (4, 2) in TL, TR, BR, BL order.
    """

    def _left_then_right(pair):
        # Strict "<" on x: when both x values are equal the second row of
        # the pair is treated as the left point.
        if pair[0][0] < pair[1][0]:
            left, right = pair
        else:
            right, left = pair
        return left, right

    points = np.asarray(pts, dtype=np.float32)
    by_y = points[np.argsort(points[:, 1])]

    tl, tr = _left_then_right(by_y[:2])
    bl, br = _left_then_right(by_y[2:])

    return np.array([tl, tr, br, bl], dtype=np.float32)
# --------------------
# Utility: save XML prediction
# --------------------
def save_prediction_xml(pred_pts, out_path, img_w, img_h):
    """Write a four-point prediction to a "visualization" XML file.

    The points are ordered (top-left, top-right, bottom-right, bottom-left)
    with ``order_points_clockwise`` and emitted three times: as the
    ``<point>`` children of a ``FourPoint`` ``<transform>``, as the points of
    an ``<overlay>``, and as the TL -> BR endpoints of the ``<ruler>``.

    Args:
        pred_pts: Four (x, y) pairs; anything ``order_points_clockwise``
            accepts. The coordinates are written out unchanged.
        out_path: Destination passed to ``ElementTree.write``. When it is a
            ``str`` or ``Path`` its parent directory is created if missing.
        img_w: Written (stringified) as the background ``width`` attribute.
        img_h: Written (stringified) as the background ``height`` attribute.
    """
    ordered = order_points_clockwise(pred_pts)
    TL, TR, BR, BL = ordered

    root = ET.Element("visualization", version="1.0")
    ET.SubElement(root, "effects", surfacecolor="", iswood="0")
    ET.SubElement(root, "background",
                  width=str(img_w), height=str(img_h),
                  color1="#C4CDE4", color2="", color3="")

    transforms_node = ET.SubElement(root, "transforms")
    transform = ET.SubElement(
        transforms_node, "transform",
        type="FourPoint", offsetX="0", offsetY="0", offsetZ="0.0",
        rotationX="0.0", rotationY="0.0", rotationZ="0.0",
        name="Region", posCode="REGION", posName="Region",
        posDef="0", techCode="EMBF03", techName="Embroidery Fixed",
        techDef="0", areaWidth="100", areaHeight="100",
        maxColors="12", defaultLogoSize="100", sizeX="100", sizeY="100")

    corners = {"TopLeft": TL, "TopRight": TR, "BottomRight": BR, "BottomLeft": BL}
    for ptype, (x, y) in corners.items():
        ET.SubElement(transform, "point",
                      type=ptype, x=str(float(x)), y=str(float(y)),
                      z="0.0", warp="0", warpShift="0")

    overlays = ET.SubElement(root, "overlays")
    overlay = ET.SubElement(overlays, "overlay")
    for (x, y) in ordered:
        ET.SubElement(overlay, "point", type="Next",
                      x=str(float(x)), y=str(float(y)), z="0.0")

    # NOTE: the ruler stringifies the numpy scalars directly, whereas the
    # point coordinates above go through float() first. Kept as-is so the
    # emitted attribute text does not change.
    ET.SubElement(root, "ruler",
                  startX=str(TL[0]), startY=str(TL[1]),
                  stopX=str(BR[0]), stopY=str(BR[1]), value="100")

    # Writing into a directory that does not exist yet would raise
    # FileNotFoundError only after all the work above; create it up front.
    # Non-path targets are handed to ElementTree.write untouched.
    if isinstance(out_path, (str, Path)):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    tree = ET.ElementTree(root)
    tree.write(out_path, encoding="utf-8", xml_declaration=True)
# --------------------
# Predict one sample
# --------------------
def predict_one(mockup_json, pers_img_path, model_ckpt, out_path="prediction.xml"):
    """Run the model on one sample and save the predicted corners as XML.

    Steps: read the perspective image's dimensions, obtain the flat corner
    points via ``flat_corners_from_mockup``, run ``SimpleTransformer`` on
    them, scale the output by the image width/height, and write the result
    with ``save_prediction_xml``.

    Args:
        mockup_json: Passed to ``flat_corners_from_mockup``.
        pers_img_path: Path of the perspective image; only its width and
            height are used.
        model_ckpt: Checkpoint path for ``torch.load``. Either a mapping with
            a ``"model_state"`` entry (resume checkpoint) or a bare state
            dict (final model).
        out_path: Destination XML file, forwarded to ``save_prediction_xml``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only the dimensions are needed. PIL exposes ``size`` right after
    # ``open`` without decoding pixel data, and the context manager closes
    # the file handle deterministically.
    with Image.open(pers_img_path) as pers_img:
        orig_w, orig_h = pers_img.size

    # Load flat points from mockup.json
    _, flat_norm = flat_corners_from_mockup(mockup_json)
    flat_in = torch.tensor(
        flat_norm.flatten(), dtype=torch.float32
    ).unsqueeze(0).to(device)  # (1,8)

    # Load model
    model = SimpleTransformer().to(device)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only use checkpoints from a trusted
    # source; switch to weights_only=True if the checkpoint format allows it.
    state = torch.load(model_ckpt, map_location=device, weights_only=False)
    if "model_state" in state:  # resume checkpoint format
        model.load_state_dict(state["model_state"])
    else:  # final model
        model.load_state_dict(state)
    model.eval()

    # Predict
    with torch.no_grad():
        pred = model(flat_in)  # (1,8)
        pred = pred.view(4, 2).cpu().numpy()

    # Convert normalized coords to pixel coords
    pred_px = pred.copy()
    pred_px[:, 0] *= orig_w
    pred_px[:, 1] *= orig_h

    # Save prediction
    save_prediction_xml(pred_px, out_path, orig_w, orig_h)
    print(f"Saved prediction -> {out_path}")
# --------------------
# Example usage
# --------------------
if __name__ == "__main__":
    # Demo run on one hard-coded test sample; edit the paths to try another.
    predict_one(
        mockup_json="Transformer/test/100847_TD/front/LAS02/mockup.json",
        pers_img_path="Transformer/test/100847_TD/front/LAS02/4BC13E58-1D8A-4E5D-8A40-C1F4B1248893_visual.jpg",
        model_ckpt="Transformer/transformer_model.pth",
        out_path="Transformer/Prediction/pred3.xml",
    )