# NOTE: the following header ("koesan's picture / Update app.py /
# 612e06a verified") is Hugging Face file-page residue from the hosting
# site, converted to a comment so the file parses as Python.
import os
import io
import base64
import torch
import numpy as np
from PIL import Image, ImageDraw
from flask import Flask, render_template, request, jsonify
# Model imports
from model.layoutganpp import Generator
from util import set_seed
# Configuration
# Prefer GPU when available; otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Path to the pretrained LayoutGAN++ checkpoint, relative to the app root.
MODEL_PATH = "model_best.pth.tar"
app = Flask(__name__)
# Load model
def load_model():
    """Restore the pretrained LayoutGAN++ generator from MODEL_PATH.

    Returns:
        (model, args): the generator in eval mode on DEVICE, and the
        training-time argument dict stored inside the checkpoint.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
    train_args = ckpt['args']

    # The number of rows in the label-embedding weight tells us how many
    # label classes the generator was trained with.
    n_labels = ckpt['netG']['emb_label.weight'].shape[0]

    netG = Generator(
        train_args['latent_size'],
        n_labels,
        d_model=train_args['G_d_model'],
        nhead=train_args['G_nhead'],
        num_layers=train_args['G_num_layers'],
    ).to(DEVICE)
    netG.load_state_dict(ckpt['netG'])
    netG.eval()
    return netG, train_args
# Load the generator once at import time so every request reuses it.
print("Loading model...")
model, model_args = load_model()
print("Model loaded successfully!")
def convert_layout_to_image(bbox, canvas_size=(512, 512)):
    """Render normalized [x1, y1, x2, y2] boxes as a labeled RGB image.

    bbox: sequence of corner-format boxes with coordinates in [0, 1].
    canvas_size: (width, height) of the output image in pixels.
    Returns a PIL.Image with semi-transparent colored panels.
    """
    width, height = canvas_size
    canvas = Image.new('RGB', (width, height), color=(255, 255, 255))
    painter = ImageDraw.Draw(canvas, 'RGBA')

    palette = [
        (255, 100, 100, 180), (100, 255, 100, 180), (100, 100, 255, 180),
        (255, 255, 100, 180), (255, 100, 255, 180), (100, 255, 255, 180),
        (255, 150, 100, 180), (150, 100, 255, 180), (100, 255, 150, 180),
        (255, 100, 150, 180),
    ]

    # Paint larger panels first so smaller ones remain visible on top.
    order = sorted(
        range(len(bbox)),
        key=lambda k: (bbox[k][2] - bbox[k][0]) * (bbox[k][3] - bbox[k][1]),
        reverse=True,
    )

    for rank, k in enumerate(order):
        left, top, right, bottom = bbox[k]
        left, top = int(left * width), int(top * height)
        right, bottom = int(right * width), int(bottom * height)

        fill = palette[rank % len(palette)]
        painter.rectangle([left, top, right, bottom],
                          fill=fill, outline=(0, 0, 0), width=2)

        caption = f"Panel {rank + 1}"
        cb = painter.textbbox((0, 0), caption)
        cw, ch = cb[2] - cb[0], cb[3] - cb[1]
        # Center the caption inside its panel.
        painter.text((left + (right - left - cw) // 2,
                      top + (bottom - top - ch) // 2),
                     caption, fill=(0, 0, 0))

    return canvas
def xywh_to_ltrb(bbox):
    """Convert a [xc, yc, w, h] center-format box to [x1, y1, x2, y2] corners."""
    xc, yc, w, h = bbox
    half_w, half_h = w / 2, h / 2
    return [xc - half_w, yc - half_h, xc + half_w, yc + half_h]
def ltrb_to_xywh(ltrb):
    """Convert [x1, y1, x2, y2] corners to a [xc, yc, w, h] center-format box."""
    left, top, right, bottom = ltrb
    return [
        (left + right) / 2,   # center x
        (top + bottom) / 2,   # center y
        right - left,         # width
        bottom - top,         # height
    ]
def check_overlap(box1, box2, margin=0.01):
    """Return True when two center-format boxes overlap (within *margin*)."""
    # Expand each [xc, yc, w, h] box into corner coordinates inline.
    xc1, yc1, w1, h1 = box1
    xc2, yc2, w2, h2 = box2
    l1, t1, r1, b1 = xc1 - w1 / 2, yc1 - h1 / 2, xc1 + w1 / 2, yc1 + h1 / 2
    l2, t2, r2, b2 = xc2 - w2 / 2, yc2 - h2 / 2, xc2 + w2 / 2, yc2 + h2 / 2
    # The boxes are separated only when one lies strictly beyond the other
    # (plus margin) along either axis; otherwise they overlap.
    separated = (r1 + margin < l2 or r2 + margin < l1 or
                 b1 + margin < t2 or b2 + margin < t1)
    return not separated
def fix_overlaps(boxes, max_attempts=200, move_step=0.005, min_dist=0.015):
    """Iteratively nudge overlapping center-format boxes apart.

    Args:
        boxes: sequence of [xc, yc, w, h] boxes in normalized coordinates.
        max_attempts: maximum number of full pairwise sweeps.
        move_step: distance each box center moves per separation nudge.
        min_dist: required clearance, passed to check_overlap as margin.

    Returns:
        (boxes, fixed_all): the adjusted boxes as a list of lists, and a
        flag that is True only when a sweep finished with no overlaps.
    """
    boxes = np.array(boxes, dtype=float).copy()
    n_boxes = len(boxes)
    fixed_all = False
    for _ in range(max_attempts):
        overlap_found = False
        for i in range(n_boxes):
            for j in range(i + 1, n_boxes):
                if not check_overlap(boxes[i], boxes[j], margin=min_dist):
                    continue
                overlap_found = True
                # Push the pair apart along the line joining their centers.
                vec_x = boxes[i][0] - boxes[j][0]
                vec_y = boxes[i][1] - boxes[j][1]
                dist = np.hypot(vec_x, vec_y)
                if dist < 1e-6:
                    # Centers coincide, so the direction is undefined: pick a
                    # fixed unit direction.  (The previous code reset only
                    # vec_x and dist, leaving a stale vec_y and producing a
                    # non-normalized displacement vector.)
                    unit_x, unit_y = 1.0, 0.0
                else:
                    unit_x, unit_y = vec_x / dist, vec_y / dist
                boxes[i][0] += unit_x * move_step
                boxes[i][1] += unit_y * move_step
                boxes[j][0] -= unit_x * move_step
                boxes[j][1] -= unit_y * move_step
                # Clamp both centers so each box stays fully inside [0, 1].
                for k in (i, j):
                    boxes[k][0] = np.clip(boxes[k][0], boxes[k][2] / 2, 1 - boxes[k][2] / 2)
                    boxes[k][1] = np.clip(boxes[k][1], boxes[k][3] / 2, 1 - boxes[k][3] / 2)
        if not overlap_found:
            fixed_all = True
            break
    return boxes.tolist(), fixed_all
@app.route('/')
def index():
    """Serve the single-page UI (templates/index.html)."""
    return render_template('index.html')
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a random panel layout and return it as a base64 PNG.

    Expects a JSON body {"num_panels": int}; the count is clamped to
    [1, 10] and defaults to 3 when missing.  Responds with
    {"success": True, "image": <base64 PNG>} on success or
    {"success": False, "error": <message>} on failure.
    """
    try:
        # Fix: request.json aborts/raises when the body is absent or not
        # valid JSON; get_json(silent=True) returns None instead, so a
        # bad body falls back to the defaults rather than a 400 page.
        payload = request.get_json(silent=True) or {}
        num_panels = int(payload.get('num_panels', 3))
        num_panels = max(1, min(10, num_panels))

        # Generator inputs: random latents, one dummy label per panel,
        # and an all-False padding mask (every slot is a real panel).
        z = torch.randn(1, num_panels, model_args['latent_size'], device=DEVICE)
        label = torch.zeros(1, num_panels, dtype=torch.long, device=DEVICE)
        padding_mask = torch.zeros(1, num_panels, dtype=torch.bool, device=DEVICE)

        # Generate layout (inference only, no gradients).
        with torch.no_grad():
            bbox = model(z, label, padding_mask)
        bbox_np = bbox[0].cpu().numpy()

        # Separate overlapping panels, then convert each box to clipped
        # corner coordinates in [0, 1].
        fixed_boxes_xywh, _ = fix_overlaps(bbox_np, max_attempts=200)
        bbox_ltrb = []
        for box_xywh in fixed_boxes_xywh:
            box = [float(np.clip(c, 0.0, 1.0)) for c in xywh_to_ltrb(box_xywh)]
            # Guarantee a strictly positive extent (x2 > x1, y2 > y1).
            if box[0] >= box[2]:
                box[2] = box[0] + 0.001
            if box[1] >= box[3]:
                box[3] = box[1] + 0.001
            box[2] = float(np.clip(box[2], box[0], 1.0))
            box[3] = float(np.clip(box[3], box[1], 1.0))
            bbox_ltrb.append(box)

        # Render and encode the layout as base64 PNG for JSON transport.
        img = convert_layout_to_image(bbox_ltrb, canvas_size=(512, 512))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return jsonify({'success': True, 'image': img_str})
    except Exception as e:
        # Surface the failure to the client instead of a bare 500 page.
        return jsonify({'success': False, 'error': str(e)})
if __name__ == '__main__':
    # Listen on all interfaces; 7860 is the conventional Hugging Face
    # Spaces port (presumably this app is deployed as a Space — see the
    # page header above).
    app.run(host='0.0.0.0', port=7860)