| | from pathlib import Path |
| | from num2words import num2words |
| | import numpy as np |
| | import os |
| | import random |
| | import re |
| | import torch |
| | import json |
| | from shapely.geometry.polygon import Polygon |
| | from shapely.affinity import scale |
| | from PIL import Image, ImageDraw, ImageOps, ImageFilter, ImageFont, ImageColor |
| |
|
# Install the pinned Gradio release (2.7.5) at start-up, then import it.
# The install goes through ``sys.executable -m pip`` so the package lands in
# the interpreter that is running this script; a bare ``pip`` found on PATH
# may belong to a different Python installation.
import subprocess
import sys

# check=False keeps the install best-effort: a non-zero pip exit status is
# not fatal here, and the import below raises if gradio is unavailable.
subprocess.run([sys.executable, '-m', 'pip', 'install', 'gradio==2.7.5'], check=False)
import gradio as gr
| |
|
| | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM |
| |
|
# Load the fine-tuned causal language model from the local 'model' directory
# or model id, together with the stock 'gpt2' tokenizer.
finetuned = AutoModelForCausalLM.from_pretrained('model')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

# Move the model weights onto the selected device.
finetuned = finetuned.to(device)
| |
|
| | |
| |
|
def containsNumber(value):
    """Return True when at least one character of ``value`` is a digit.

    An empty ``value`` yields False.
    """
    return any(char.isdigit() for char in value)
| | |
def creativity(intensity):
    """Translate a creativity level into sampling parameters.

    Args:
        intensity: One of ``'Low'``, ``'Medium'`` or ``'High'``.

    Returns:
        A ``(top_p, top_k)`` tuple for the requested level.

    Raises:
        ValueError: If ``intensity`` is not one of the three known levels.
    """
    presets = {
        'Low': (0.95, 10),
        'Medium': (0.9, 50),
        'High': (0.85, 100),
    }
    try:
        return presets[intensity]
    except (KeyError, TypeError):
        # An unknown (or unhashable) level has no preset; report it clearly
        # instead of failing on an unassigned local variable.
        raise ValueError(
            "Unknown creativity level {!r}; expected one of {}".format(
                intensity, ', '.join(presets))
        ) from None
| |
|
# Room-type name -> integer label. The label is also used as the index into
# ``architext_colors`` when a layout is rendered.
housegan_labels = {
    "living_room": 1,
    "kitchen": 2,
    "bedroom": 3,
    "bathroom": 4,
    "missing": 5,
    "closet": 6,
    "balcony": 7,
    "corridor": 8,
    "dining_room": 9,
    "laundry_room": 10,
}

# RGB fill colours, indexed by room label. Indices 0 and 11 are not referenced
# by any entry of ``housegan_labels``.
architext_colors = [
    [0, 0, 0],        # 0
    [249, 222, 182],  # 1: living_room
    [195, 209, 217],  # 2: kitchen
    [250, 120, 128],  # 3: bedroom
    [126, 202, 234],  # 4: bathroom
    [190, 0, 198],    # 5: missing
    [255, 255, 255],  # 6: closet
    [6, 53, 17],      # 7: balcony
    [17, 33, 58],     # 8: corridor
    [132, 151, 246],  # 9: dining_room
    [197, 203, 159],  # 10: laundry_room
    [6, 53, 17],      # 11
]
| |
|
# Captures the text inside each pair of parentheses,
# e.g. "(1,2)(3,4)" -> ["1,2", "3,4"].
# Written as a raw string: "\(" in a plain literal is an invalid escape
# sequence, which newer Python versions warn about.
regex = re.compile(r".*?\((.*?)\)")
| |
|
def draw_polygons(polygons, colors, im_size=(512, 512), b_color="white", fpath=None):
    """Render polygons as colour-filled shapes with a black outline.

    Each polygon is painted solid black first and then over-painted with its
    colour after being shrunk by one unit (``buffer(-1)``), which leaves a
    black border around the coloured interior.

    Args:
        polygons: Iterable of shapely polygons (each must expose ``exterior``
            and ``buffer``).
        colors: Iterable of colour sequences, paired with ``polygons`` by
            position; each is converted to a tuple before use as a fill.
        im_size: ``(width, height)`` of the RGBA canvas.
        b_color: Background colour of the canvas.
        fpath: Optional output path; when truthy the final image is saved
            there.

    Returns:
        A ``(draw, image)`` tuple. ``image`` is the rendered canvas flipped
        top-to-bottom. ``draw`` is the ``ImageDraw`` object bound to the
        canvas *before* the flip; ``transpose`` returns a new image, so
        drawing through ``draw`` afterwards does not change ``image``.
    """

    def _flat_yx(ring):
        # Flatten a ring into [y0, x0, y1, x1, ...]: the shapely x and y
        # columns are swapped before being handed to Pillow.
        xs, ys = ring.xy
        return list(np.dstack((ys, xs)).flatten())

    # Honour the caller's background colour (default is still white).
    canvas = Image.new("RGBA", im_size, color=b_color)
    draw = ImageDraw.Draw(canvas)

    for poly, color in zip(polygons, colors):
        fill = tuple(color)

        # Full footprint in black; the uncovered rim becomes the outline.
        draw.polygon(_flat_yx(poly.exterior), fill=(0, 0, 0))

        inset = poly.buffer(-1, resolution=32, cap_style=2, join_style=2, mitre_limit=5.0)
        if inset.is_empty:
            # The polygon is too small to survive the inset; it stays black.
            continue

        # Decide on the *inset* geometry's type: shrinking can split one
        # polygon into several parts.
        if inset.geom_type == 'MultiPolygon':
            # ``.geoms`` is available in Shapely 1.x and 2.x, whereas
            # iterating a MultiPolygon directly was removed in Shapely 2.
            parts = inset.geoms
        elif inset.geom_type == 'Polygon':
            parts = [inset]
        else:
            parts = []

        for part in parts:
            draw.polygon(_flat_yx(part.exterior), fill=fill)

    image = canvas.transpose(Image.FLIP_TOP_BOTTOM)
    if fpath:
        image.save(fpath, quality=100, subsampling=0)
    return draw, image
| |
|
def prompt_to_layout(user_prompt, intensity, fpath=None):
    """Generate a floor-plan image and its room coordinates from a prompt.

    The prompt is wrapped as ``'[User prompt] <text> [Layout]'``, completed by
    the fine-tuned model, and the completion is parsed as a ``', '``-separated
    list of ``room: (x,y)(x,y)...`` entries.

    Args:
        user_prompt: Free-text description of the apartment.
        intensity: Creativity level understood by ``creativity``.
        fpath: Optional path forwarded to ``draw_polygons`` for saving the
            rendered plan.

    Returns:
        A ``(image, layout_json)`` tuple: the rendered plan stacked above the
        legend read from ``labels.png``, and a JSON string mapping each room
        name (numbered when a name occurs more than once) to its list of
        coordinates divided by 14.2.

    NOTE(review): malformed model output (a missing ``:``, an unknown room
    name, too few points for a polygon) is not handled and raises from the
    parsing/rendering steps below.
    """
    # Spell out stand-alone numbers ("2" -> "two") and make sure the converted
    # text, not the raw prompt, is what reaches the model.
    if containsNumber(user_prompt):
        # isdecimal() rather than isdigit(): the latter also accepts
        # characters (e.g. superscripts) that int() cannot parse.
        prompt_text = ' '.join(
            num2words(int(word)).lower() if word.isdecimal() else word
            for word in user_prompt.split(' ')
        )
    else:
        prompt_text = user_prompt
    model_prompt = '[User prompt] {} [Layout]'.format(prompt_text)

    top_p, top_k = creativity(intensity)
    input_ids = tokenizer(model_prompt, return_tensors='pt').to(device)
    output = finetuned.generate(**input_ids, do_sample=True, top_p=top_p, top_k=top_k,
                                eos_token_id=50256, max_length=400)
    decoded = tokenizer.batch_decode(output, skip_special_tokens=True)

    # Keep only the text after the '[Layout] ' marker and split it per room.
    rooms = decoded[0].split('[User prompt]')[1].split('[Layout] ')[1].split(', ')
    spaces = [room.split(':')[0] for room in rooms]
    coordinates = [regex.findall(room.split(':')[1].rstrip()) for room in rooms]

    # Parse every "x,y" pair once; both the JSON output and the polygons are
    # derived from these integer points.
    points = [
        [[int(num) for num in xy.split(',')] for xy in room_xy]
        for room_xy in coordinates
    ]

    num_coords = [
        [tuple(value / 14.2 for value in point) for point in room]
        for room in points
    ]

    # Disambiguate repeated room names by appending a 1-based running index.
    new_spaces = [
        name + str(spaces[:i].count(name) + 1) if spaces.count(name) > 1 else name
        for i, name in enumerate(spaces)
    ]
    out_dict = json.dumps(dict(zip(new_spaces, num_coords)))

    # Polygons are scaled by 2 about the origin before rendering.
    geom = [
        scale(Polygon(np.array(room, dtype=int)), xfact=2, yfact=2, origin=(0, 0))
        for room in points
    ]
    colors = [architext_colors[housegan_labels[space]] for space in spaces]
    _, im = draw_polygons(geom, colors, fpath=fpath)

    # Stack the legend under the plan; the context manager closes the file
    # once its pixels have been copied into the combined array.
    with Image.open("labels.png") as legend:
        imgs_comb = Image.fromarray(np.vstack([im, legend]))
    return imgs_comb, out_dict
| | |
| | |
| | |
| |
|
# Stylesheet passed to ``gr.Interface`` through its ``css`` argument.
# It overrides rules on Gradio's own class names (``.gradio_wrapper``,
# ``.gradio_interface``, ``.panel`` ...): imports a Typekit stylesheet,
# restyles the text input, radio items, submit button and output image,
# hides the page footer and the third component of the second panel, and
# relaxes margins / image height for viewports up to 1000px wide.
custom_css="""
@import url("https://use.typekit.net/nid3pfr.css");
.gradio_wrapper .gradio_bg[is_embedded=false] {
min-height: 80%;
}

.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page {
display: flex;
width: 100vw;
min-height: 50vh;
flex-direction: column;
justify-content: center;
align-items: center;
margin: 0px;
max-width: 100vw;
background: #FFFFFF;
}

.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page .content {
padding: 0px;
margin: 0px;
}

.gradio_interface {
width: 100vw;
max-width: 1500px;
}

.gradio_interface .panel:nth-child(2) .component:nth-child(3) {
display:none
}

.gradio_wrapper .gradio_bg[theme=default] .panel_buttons {
justify-content: flex-end;
}

.gradio_wrapper .gradio_bg[theme=default] .panel_button {
flex: 0 0 0;
min-width: 150px;
}

.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_button.submit {
background: #11213A;
border-radius: 5px;
color: #FFFFFF;
text-transform: uppercase;
min-width: 150px;
height: 4em;
letter-spacing: 0.15em;
flex: 0 0 0;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_button.submit:hover {
background: #000000;
}

.input_text:focus {
border-color: #FA7880;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_text input,
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_text textarea {
font: 200 45px garamond-premier-pro-display, serif;
line-height: 110%;
color: #11213A;
border-radius: 5px;
padding: 15px;
border: none;
background: #F2F4F4;
}
.input_text textarea:focus-visible {
outline: none;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_radio .radio_item.selected {
background-color: #11213A;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .input_radio .selected .radio_circle {
border-color: #4365c4;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image {
width: 100%;
height: 40vw;
max-height: 630px;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image .image_preview_holder {
background: transparent;
}
.panel:nth-child(1) {
margin-left: 50px;
margin-right: 50px;
margin-bottom: 80px;
max-width: 750px;
}
.panel {
background: transparent;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .component_set {
background: transparent;
box-shadow: none;
}
.panel:nth-child(2) .gradio_wrapper .gradio_bg[theme=default] .gradio_interface .panel_header {
display: none;
}

.gradio_wrapper .gradio_bg[is_embedded=false] .gradio_page .footer {
visibility: hidden;
}

.labels {
height: 20px;
width: auto;
}

@media (max-width: 1000px){
.panel:nth-child(1) {
margin-left: 0px;
margin-right: 0px;
}
.gradio_wrapper .gradio_bg[theme=default] .gradio_interface .output_image {
height: auto;
}
}
"""
# Input widgets: a three-level creativity selector and the free-text
# apartment description.
creative_slider = gr.inputs.Radio(["Low", "Medium", "High"], default="Low", label='Creativity')
# ``lines`` is passed as an integer, the type Gradio's Textbox documents for
# it (it was previously the string "3").
textbox = gr.inputs.Textbox(placeholder='An apartment with two bedrooms and one bathroom', lines=3,
                            label="DESCRIBE YOUR IDEAL APARTMENT")

# Output widgets: the rendered plan and the JSON coordinate listing, in the
# order ``prompt_to_layout`` returns them.
generated = gr.outputs.Image(label='Generated Layout')
layout = gr.outputs.Textbox(label='Layout Coordinates')

iface = gr.Interface(fn=prompt_to_layout, inputs=[textbox, creative_slider],
                     outputs=[generated, layout],
                     css=custom_css,
                     theme="default",
                     allow_flagging='never',
                     allow_screenshot=False,
                     thumbnail="thumbnail_gradio.PNG")

# Start the web app with request queueing enabled and a public share link.
iface.launch(enable_queue=True, share=True)