"""Flask demo app for LayoutGAN++ panel-layout generation.

Loads a pretrained generator once at startup, then serves:
  * ``/``          – the HTML front end (templates/index.html)
  * ``/generate``  – POST endpoint that samples a layout, pushes
    overlapping boxes apart, renders the result to a PNG, and returns
    it base64-encoded as JSON.
"""
import base64
import io
import os

import numpy as np
import torch
from PIL import Image, ImageDraw
from flask import Flask, render_template, request, jsonify

# Model imports (project-local)
from model.layoutganpp import Generator
from util import set_seed

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "model_best.pth.tar"

app = Flask(__name__)


def load_model():
    """Load the pretrained LayoutGAN++ generator from ``MODEL_PATH``.

    Returns:
        tuple: ``(model, args)`` where ``model`` is the ``Generator`` in
        eval mode on ``DEVICE`` and ``args`` is the hyperparameter dict
        stored in the checkpoint.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    # SECURITY NOTE(review): torch.load unpickles arbitrary objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    args = checkpoint['args']

    # Infer the label-vocabulary size from the stored embedding weights
    # rather than trusting a separate config value.
    num_label = checkpoint['netG']['emb_label.weight'].shape[0]

    model = Generator(
        args['latent_size'],
        num_label,
        d_model=args['G_d_model'],
        nhead=args['G_nhead'],
        num_layers=args['G_num_layers'],
    ).to(DEVICE)
    model.load_state_dict(checkpoint['netG'])
    model.eval()
    return model, args


print("Loading model...")
model, model_args = load_model()
print("Model loaded successfully!")


def convert_layout_to_image(bbox, canvas_size=(512, 512)):
    """Render normalized [0, 1] ltrb boxes to a labelled RGB image.

    Args:
        bbox: iterable of ``[x1, y1, x2, y2]`` boxes in normalized
            coordinates (left-top-right-bottom).
        canvas_size: output ``(width, height)`` in pixels.

    Returns:
        PIL.Image.Image: white canvas with one semi-transparent colored
        rectangle per box, largest drawn first so small panels stay
        visible, each centered with a "Panel N" caption.
    """
    W, H = canvas_size
    img = Image.new('RGB', (W, H), color=(255, 255, 255))
    # 'RGBA' draw mode so the semi-transparent fills are alpha-blended.
    draw = ImageDraw.Draw(img, 'RGBA')

    colors = [
        (255, 100, 100, 180), (100, 255, 100, 180), (100, 100, 255, 180),
        (255, 255, 100, 180), (255, 100, 255, 180), (100, 255, 255, 180),
        (255, 150, 100, 180), (150, 100, 255, 180), (100, 255, 150, 180),
        (255, 100, 150, 180),
    ]

    # Draw big panels first so smaller ones end up on top.
    areas = [(b[2] - b[0]) * (b[3] - b[1]) for b in bbox]
    indices = sorted(range(len(areas)), key=lambda i: areas[i], reverse=True)

    for idx, i in enumerate(indices):
        x1, y1, x2, y2 = bbox[i]
        x1, y1, x2, y2 = int(x1 * W), int(y1 * H), int(x2 * W), int(y2 * H)
        color = colors[idx % len(colors)]
        draw.rectangle([x1, y1, x2, y2], fill=color, outline=(0, 0, 0), width=2)

        # Center the caption inside the panel.
        text = f"Panel {idx + 1}"
        text_bbox = draw.textbbox((0, 0), text)
        text_w = text_bbox[2] - text_bbox[0]
        text_h = text_bbox[3] - text_bbox[1]
        text_x = x1 + (x2 - x1 - text_w) // 2
        text_y = y1 + (y2 - y1 - text_h) // 2
        draw.text((text_x, text_y), text, fill=(0, 0, 0))

    return img


def xywh_to_ltrb(bbox):
    """Convert a box from center format ``[xc, yc, w, h]`` to corners
    ``[x1, y1, x2, y2]``."""
    xc, yc, w, h = bbox
    return [xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2]


def ltrb_to_xywh(ltrb):
    """Convert a box from corners ``[x1, y1, x2, y2]`` to center format
    ``[xc, yc, w, h]``."""
    x1, y1, x2, y2 = ltrb
    xc = (x1 + x2) / 2
    yc = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return [xc, yc, w, h]


def check_overlap(box1, box2, margin=0.01):
    """Return True if two center-format boxes overlap (or come within
    ``margin`` of touching) on both axes."""
    x1_1, y1_1, x2_1, y2_1 = xywh_to_ltrb(box1)
    x1_2, y1_2, x2_2, y2_2 = xywh_to_ltrb(box2)
    # Separated when one box lies fully to the left/right or above/below
    # the other by more than `margin`.
    if x2_1 + margin < x1_2 or x2_2 + margin < x1_1 or \
       y2_1 + margin < y1_2 or y2_2 + margin < y1_1:
        return False
    return True


def fix_overlaps(boxes, max_attempts=200, move_step=0.005, min_dist=0.015):
    """Iteratively push overlapping center-format boxes apart.

    Each pass moves every overlapping pair ``move_step`` apart along the
    line joining their centers, then clamps both boxes back inside the
    unit square. Sizes are never changed, only positions.

    Args:
        boxes: iterable of ``[xc, yc, w, h]`` boxes in [0, 1] coords.
        max_attempts: maximum number of relaxation passes.
        move_step: per-pass displacement applied to each box of a pair.
        min_dist: minimum gap (margin) required between boxes.

    Returns:
        tuple: ``(boxes, fixed_all)`` — the adjusted boxes as a list of
        lists, and True iff a pass completed with no overlaps found.
    """
    boxes = np.array(boxes, dtype=float).copy()
    n_boxes = len(boxes)
    fixed_all = False

    for attempt in range(max_attempts):
        overlap_found_in_iteration = False
        for i in range(n_boxes):
            for j in range(i + 1, n_boxes):
                if check_overlap(boxes[i], boxes[j], margin=min_dist):
                    overlap_found_in_iteration = True

                    # Unit vector from box j's center toward box i's.
                    c_i_x, c_i_y = boxes[i][0], boxes[i][1]
                    c_j_x, c_j_y = boxes[j][0], boxes[j][1]
                    vec_x, vec_y = c_i_x - c_j_x, c_i_y - c_j_y
                    dist = np.sqrt(vec_x**2 + vec_y**2)
                    if dist < 1e-6:
                        # Coincident centers: push along +x arbitrarily.
                        vec_x, dist = 0.01, 0.01
                    unit_vec_x, unit_vec_y = vec_x / dist, vec_y / dist

                    # Move the pair apart symmetrically.
                    boxes[i][0] += unit_vec_x * move_step
                    boxes[i][1] += unit_vec_y * move_step
                    boxes[j][0] -= unit_vec_x * move_step
                    boxes[j][1] -= unit_vec_y * move_step

                    # Clamp centers so each box stays within [0, 1].
                    boxes[i][0] = np.clip(boxes[i][0], boxes[i][2] / 2, 1 - boxes[i][2] / 2)
                    boxes[i][1] = np.clip(boxes[i][1], boxes[i][3] / 2, 1 - boxes[i][3] / 2)
                    boxes[j][0] = np.clip(boxes[j][0], boxes[j][2] / 2, 1 - boxes[j][2] / 2)
                    boxes[j][1] = np.clip(boxes[j][1], boxes[j][3] / 2, 1 - boxes[j][3] / 2)

        if not overlap_found_in_iteration:
            fixed_all = True
            break

    return boxes.tolist(), fixed_all


@app.route('/')
def index():
    """Serve the front-end page."""
    return render_template('index.html')


@app.route('/generate', methods=['POST'])
def generate():
    """Sample a layout and return it as a base64-encoded PNG.

    Expects a JSON body with an optional integer ``num_panels``
    (clamped to 1..10, default 3). Responds with
    ``{"success": true, "image": <base64 PNG>}`` or
    ``{"success": false, "error": <message>}``.
    """
    try:
        # silent=True tolerates a missing/mis-typed JSON body so the
        # documented default of 3 panels applies instead of erroring.
        payload = request.get_json(silent=True) or {}
        num_panels = int(payload.get('num_panels', 3))
        num_panels = max(1, min(10, num_panels))

        # Build generator inputs: random latents, all-zero labels,
        # no padding (every slot is a real panel).
        z = torch.randn(1, num_panels, model_args['latent_size'], device=DEVICE)
        label = torch.zeros(1, num_panels, dtype=torch.long, device=DEVICE)
        padding_mask = torch.zeros(1, num_panels, dtype=torch.bool, device=DEVICE)

        # Generate layout (center-format boxes, batch of 1).
        with torch.no_grad():
            bbox = model(z, label, padding_mask)
        bbox_np = bbox[0].cpu().numpy()

        # Push overlapping panels apart.
        fixed_boxes_xywh, overlap_fixed = fix_overlaps(bbox_np, max_attempts=200)

        # Convert to corner format, clip to [0, 1], and guarantee each
        # box keeps positive extent after clipping.
        bbox_ltrb = []
        for box_xywh in fixed_boxes_xywh:
            box_ltrb = xywh_to_ltrb(box_xywh)
            box_ltrb_clipped = [np.clip(coord, 0.0, 1.0) for coord in box_ltrb]
            if box_ltrb_clipped[0] >= box_ltrb_clipped[2]:
                box_ltrb_clipped[2] = box_ltrb_clipped[0] + 0.001
            if box_ltrb_clipped[1] >= box_ltrb_clipped[3]:
                box_ltrb_clipped[3] = box_ltrb_clipped[1] + 0.001
            box_ltrb_clipped[2] = np.clip(box_ltrb_clipped[2], box_ltrb_clipped[0], 1.0)
            box_ltrb_clipped[3] = np.clip(box_ltrb_clipped[3], box_ltrb_clipped[1], 1.0)
            bbox_ltrb.append(box_ltrb_clipped)

        # Render and encode the layout image.
        img = convert_layout_to_image(bbox_ltrb, canvas_size=(512, 512))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return jsonify({'success': True, 'image': img_str})
    except Exception as e:
        # Top-level route boundary: log the traceback, report the error
        # to the client in the JSON envelope the front end expects.
        app.logger.exception("Layout generation failed")
        return jsonify({'success': False, 'error': str(e)})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)