# Flask demo app: generates panel layouts with a pretrained LayoutGAN++ model.
import os
import io
import base64

import torch
import numpy as np
from PIL import Image, ImageDraw
from flask import Flask, render_template, request, jsonify

# Model imports (project-local)
from model.layoutganpp import Generator
from util import set_seed

# Configuration
# Use the GPU when available; every tensor below is created on this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# LayoutGAN++ checkpoint; must contain 'args' and the 'netG' state dict.
MODEL_PATH = "model_best.pth.tar"

app = Flask(__name__)
# Load model
def load_model():
    """Restore the pretrained LayoutGAN++ generator from MODEL_PATH.

    Returns:
        tuple: (generator module in eval mode on DEVICE, the training
        args dict stored in the checkpoint).

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    train_args = checkpoint['args']

    # The label-embedding table's row count equals the number of classes.
    state = checkpoint['netG']
    num_label = state['emb_label.weight'].shape[0]

    netG = Generator(
        train_args['latent_size'],
        num_label,
        d_model=train_args['G_d_model'],
        nhead=train_args['G_nhead'],
        num_layers=train_args['G_num_layers'],
    ).to(DEVICE)
    netG.load_state_dict(state)
    netG.eval()
    return netG, train_args
| print("Loading model...") | |
| model, model_args = load_model() | |
| print("Model loaded successfully!") | |
def convert_layout_to_image(bbox, canvas_size=(512, 512)):
    """Render normalized [x1, y1, x2, y2] boxes onto a white canvas.

    Args:
        bbox: iterable of corner-format boxes with coordinates in [0, 1].
        canvas_size: (width, height) of the output image in pixels.

    Returns:
        PIL.Image.Image: RGB image; largest boxes are drawn first so
        smaller panels remain visible on top.
    """
    width, height = canvas_size
    canvas = Image.new('RGB', (width, height), color=(255, 255, 255))
    painter = ImageDraw.Draw(canvas, 'RGBA')

    palette = [
        (255, 100, 100, 180), (100, 255, 100, 180), (100, 100, 255, 180),
        (255, 255, 100, 180), (255, 100, 255, 180), (100, 255, 255, 180),
        (255, 150, 100, 180), (150, 100, 255, 180), (100, 255, 150, 180),
        (255, 100, 150, 180),
    ]

    # Order by area, biggest first, so small panels are not hidden.
    order = sorted(
        range(len(bbox)),
        key=lambda k: (bbox[k][2] - bbox[k][0]) * (bbox[k][3] - bbox[k][1]),
        reverse=True,
    )

    for draw_idx, box_idx in enumerate(order):
        left = int(bbox[box_idx][0] * width)
        top = int(bbox[box_idx][1] * height)
        right = int(bbox[box_idx][2] * width)
        bottom = int(bbox[box_idx][3] * height)

        fill = palette[draw_idx % len(palette)]
        painter.rectangle([left, top, right, bottom],
                          fill=fill, outline=(0, 0, 0), width=2)

        # Center the label inside the panel.
        label = f"Panel {draw_idx + 1}"
        lb = painter.textbbox((0, 0), label)
        label_w = lb[2] - lb[0]
        label_h = lb[3] - lb[1]
        painter.text(
            (left + (right - left - label_w) // 2,
             top + (bottom - top - label_h) // 2),
            label, fill=(0, 0, 0),
        )
    return canvas
def xywh_to_ltrb(bbox):
    """Convert a [xc, yc, w, h] box into [x1, y1, x2, y2] corner form."""
    xc, yc, w, h = bbox
    half_w = w / 2
    half_h = h / 2
    return [xc - half_w, yc - half_h, xc + half_w, yc + half_h]
def ltrb_to_xywh(ltrb):
    """Convert a [x1, y1, x2, y2] box into [xc, yc, w, h] center form."""
    left, top, right, bottom = ltrb
    return [
        (left + right) / 2,   # center x
        (top + bottom) / 2,   # center y
        right - left,         # width
        bottom - top,         # height
    ]
def check_overlap(box1, box2, margin=0.01):
    """Return True when two [xc, yc, w, h] boxes overlap.

    Boxes closer than `margin` along either axis count as overlapping.
    """
    l1, t1, r1, b1 = xywh_to_ltrb(box1)
    l2, t2, r2, b2 = xywh_to_ltrb(box2)
    # Separating-axis test: disjoint iff one box is strictly beyond the
    # other (plus margin) horizontally or vertically.
    separated = (r1 + margin < l2 or r2 + margin < l1
                 or b1 + margin < t2 or b2 + margin < t1)
    return not separated
def fix_overlaps(boxes, max_attempts=200, move_step=0.005, min_dist=0.015):
    """Iteratively push overlapping [xc, yc, w, h] boxes apart.

    Each pass nudges every overlapping pair along the line joining their
    centers by `move_step`, clamping centers so boxes stay inside [0, 1].

    Args:
        boxes: sequence of [xc, yc, w, h] boxes in normalized coordinates.
        max_attempts: maximum number of full passes over all pairs.
        move_step: distance each box center moves per nudge.
        min_dist: required clearance between boxes (overlap margin).

    Returns:
        tuple: (list of adjusted boxes, True when no overlaps remain).
    """
    layout = np.array(boxes, dtype=float).copy()
    count = len(layout)
    resolved = False

    for _ in range(max_attempts):
        any_overlap = False
        for a in range(count):
            for b in range(a + 1, count):
                if not check_overlap(layout[a], layout[b], margin=min_dist):
                    continue
                any_overlap = True

                # Direction from box b's center toward box a's center.
                dx = layout[a][0] - layout[b][0]
                dy = layout[a][1] - layout[b][1]
                length = np.sqrt(dx ** 2 + dy ** 2)
                if length < 1e-6:
                    # Centers (nearly) coincide: push along x by a
                    # nominal amount instead of dividing by ~0.
                    dx, length = 0.01, 0.01
                ux, uy = dx / length, dy / length

                # Nudge the pair apart symmetrically.
                layout[a][0] += ux * move_step
                layout[a][1] += uy * move_step
                layout[b][0] -= ux * move_step
                layout[b][1] -= uy * move_step

                # Clamp each center so its box stays fully inside [0, 1].
                for k in (a, b):
                    layout[k][0] = np.clip(layout[k][0], layout[k][2] / 2,
                                           1 - layout[k][2] / 2)
                    layout[k][1] = np.clip(layout[k][1], layout[k][3] / 2,
                                           1 - layout[k][3] / 2)

        if not any_overlap:
            resolved = True
            break
    return layout.tolist(), resolved
@app.route('/')
def index():
    """Serve the single-page UI.

    Fix: the view was never registered with Flask — the `@app.route`
    decorator was missing, so requests to '/' returned 404.
    """
    return render_template('index.html')
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a panel layout and return it as a base64-encoded PNG.

    Fix: the endpoint was never registered with Flask — the `@app.route`
    decorator was missing, so POSTs to '/generate' returned 404.

    Expects a JSON body {"num_panels": int}; the value is clamped to
    [1, 10]. Responds with {"success": true, "image": <base64 PNG>} on
    success or {"success": false, "error": <message>} on failure.
    """
    try:
        num_panels = int(request.json.get('num_panels', 3))
        num_panels = max(1, min(10, num_panels))

        # Generator inputs: random latents, one (zero) label per panel,
        # and an all-False padding mask since every slot is a real panel.
        z = torch.randn(1, num_panels, model_args['latent_size'], device=DEVICE)
        label = torch.zeros(1, num_panels, dtype=torch.long, device=DEVICE)
        padding_mask = torch.zeros(1, num_panels, dtype=torch.bool, device=DEVICE)

        # Sample a layout; bbox[0] holds one [xc, yc, w, h] row per panel
        # (presumably — inferred from the xywh conversion below).
        with torch.no_grad():
            bbox = model(z, label, padding_mask)
        bbox_np = bbox[0].cpu().numpy()

        # Push overlapping panels apart before rendering.
        fixed_boxes_xywh, overlap_fixed = fix_overlaps(bbox_np, max_attempts=200)

        # Convert to corner format, clip into [0, 1], and force positive
        # width/height so drawing never receives a degenerate box.
        bbox_ltrb = []
        for box_xywh in fixed_boxes_xywh:
            box_ltrb = xywh_to_ltrb(box_xywh)
            clipped = [np.clip(coord, 0.0, 1.0) for coord in box_ltrb]
            if clipped[0] >= clipped[2]:
                clipped[2] = clipped[0] + 0.001
            if clipped[1] >= clipped[3]:
                clipped[3] = clipped[1] + 0.001
            # Re-clip in case the +0.001 pushed a corner past 1.0.
            clipped[2] = np.clip(clipped[2], clipped[0], 1.0)
            clipped[3] = np.clip(clipped[3], clipped[1], 1.0)
            bbox_ltrb.append(clipped)

        # Render and encode as base64 PNG for the JSON response.
        img = convert_layout_to_image(bbox_ltrb, canvas_size=(512, 512))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return jsonify({'success': True, 'image': img_str})
    except Exception as e:
        # Request boundary: report the failure to the client as JSON
        # instead of letting Flask return a bare 500.
        return jsonify({'success': False, 'error': str(e)})
# Entrypoint: run the development server on all interfaces, port 7860.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)