Spaces:
Runtime error
Runtime error
| from g2p.model import Model | |
| from g2p.box_utils import centers_to_extents | |
| from g2p.retrieval import DataRetriever | |
| from g2p.floorplan import FloorPlan | |
| from g2p.align import align_fp_refine | |
| from g2p.add_archs import add_door_window | |
| import numpy as np | |
| import pickle | |
| import torch | |
class App:
    """Bundle a trained ``Model`` with an optional retrieval database.

    The public methods form a pipeline that ``generate`` runs end to end:
    ``retrieve`` -> ``transfer`` -> ``forward`` -> ``align`` -> ``decorate``.
    """

    def __init__(self, model_path, device='cpu', data_path=None, tf_path=None,
                 centroid_path=None, cluster_path=None):
        """Load the model and, when ``data_path`` is given, the database.

        Args:
            model_path: Checkpoint file holding the model's ``state_dict``.
            device: Device the model is moved to and inputs are sent to.
            data_path: Pickle file with a ``'data'`` entry; ``None`` skips
                loading the retrieval database entirely.
            tf_path: ``.npy`` file passed to ``DataRetriever`` (first arg).
            centroid_path: ``.npy`` file passed to ``DataRetriever`` (second arg).
            cluster_path: ``.npy`` file passed to ``DataRetriever`` (third arg).

        Raises:
            ValueError: If ``data_path`` is given but one of the three other
                database paths is ``None``.
        """
        super().__init__()
        self.load_model(model_path, device=device)
        if data_path is not None:
            self.load_database(data_path, tf_path, centroid_path, cluster_path)

    def load_model(self, model_path, device='cpu'):
        """Build a ``Model``, load its weights and put it in eval mode.

        Sets ``self.model`` and ``self.device``.
        """
        model = Model()
        # Remap the stored tensors onto the requested device while loading;
        # without map_location a checkpoint saved on a GPU cannot be
        # deserialized on a CPU-only host even though device='cpu' was asked.
        state_dict = torch.load(model_path, map_location=torch.device(device))
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        self.model = model
        self.device = device

    def load_database(self, data_path, tf_path, centroid_path, cluster_path):
        """Load the retrieval database and build ``self.retriever``.

        Sets ``self.data`` to the ``'data'`` entry of the pickle at
        ``data_path`` and ``self.retriever`` to a ``DataRetriever`` built from
        the three ``.npy`` arrays.

        Raises:
            ValueError: If ``tf_path``, ``centroid_path`` or ``cluster_path``
                is ``None``. (Previously an ``assert``, which is stripped
                when Python runs with ``-O``.)
        """
        required = (
            ('tf_path', tf_path),
            ('centroid_path', centroid_path),
            ('cluster_path', cluster_path),
        )
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                'load_database requires non-None values for: '
                + ', '.join(missing)
            )

        # NOTE: unpickling can execute arbitrary code -- only load database
        # files from a trusted source.
        with open(data_path, 'rb') as handle:
            self.data = pickle.load(handle)['data']

        tf_train = np.load(tf_path)
        centroids = np.load(centroid_path)
        clusters = np.load(cluster_path)
        self.retriever = DataRetriever(tf_train, centroids, clusters)

    def retrieve(self, data_query, k=10, multi_clusters=False):
        """Return the entries of ``self.data`` selected by the retriever.

        ``self.retriever.retrieve_cluster`` yields the index (or indices)
        used to subscript ``self.data``. Requires ``load_database`` to have
        been called, since that is where both attributes are set.
        """
        index = self.retriever.retrieve_cluster(
            data_query, k=k, multi_clusters=multi_clusters
        )
        return self.data[index]

    def transfer(self, data_boundary, data_graph):
        """Adapt the graph of ``data_graph`` to the boundary of ``data_boundary``.

        Returns the ``data`` object of the adapted ``FloorPlan`` with its
        ``rType`` and ``rEdge`` fields refreshed from the adjusted graph.
        """
        fp_boundary = FloorPlan(data_boundary)
        fp_graph = FloorPlan(data_graph)
        fp_transfer = fp_boundary.adapt_graph(fp_graph)
        fp_transfer.adjust_graph()
        fp_transfer.data.rType = fp_transfer.get_rooms(tensor=False)
        # Triples are stored with their 2nd and 3rd columns swapped.
        fp_transfer.data.rEdge = fp_transfer.get_triples(tensor=False)[:, [0, 2, 1]]
        return fp_transfer.data

    def forward(self, data, network_data=False):
        """Run the model on ``data`` and attach its predictions.

        Args:
            data: Floor-plan data object accepted by ``FloorPlan``.
            network_data: When true, ``data.box`` and ``data.edge`` are first
                (re)built in place from ``data.gtBoxNew``, ``data.rType`` and
                ``data.rEdge``.

        Returns:
            ``fp.data`` with ``predBox``, ``refineBox`` and ``gene`` set.
        """
        if network_data:
            # Append the room type as the last column of each box row.
            data.box = np.concatenate(
                [data.gtBoxNew, data.rType.reshape(-1, 1)], axis=-1
            )
            data.edge = data.rEdge

        fp = FloorPlan(data)
        boundary, inside_box, rooms, attrs, triples = fp.get_test_data()
        # Only the boundary gets an extra leading dimension.
        boundary = boundary.unsqueeze(0).to(self.device)
        inside_box = inside_box.to(self.device)
        rooms = rooms.to(self.device)
        attrs = attrs.to(self.device)
        triples = triples.to(self.device)

        with torch.no_grad():
            model_out = self.model(
                rooms,
                triples,
                boundary,
                obj_to_img=None,
                attributes=attrs,
                boxes_gt=None,
                generate=True,
                refine=True,
                relative=True,
                inside_box=inside_box,
            )
        boxes_pred, gene_layout, boxes_refine = model_out

        # Presumably rescales normalized coordinates to a 0..255 pixel grid
        # -- TODO confirm against centers_to_extents / the model output.
        boxes_pred = centers_to_extents(boxes_pred) * 255
        boxes_pred = boxes_pred.squeeze().cpu().numpy().astype(int)
        boxes_refine = centers_to_extents(boxes_refine) * 255
        boxes_refine = boxes_refine.squeeze().cpu().numpy().astype(int)
        # NOTE(review): squeeze() drops every size-1 dimension, so a plan
        # with a single box would lose its row axis -- confirm callers cope.

        # Per-position label = arg-max over dimension 1 of the layout scores.
        gene_preds = torch.argmax(gene_layout.softmax(1).detach(), dim=1).squeeze()
        # Positions where the first boundary channel is 0 are forced to 13
        # (presumably the "outside the boundary" label -- TODO confirm).
        gene_preds[boundary[0, 0] == 0] = 13
        gene_preds = gene_preds.cpu().numpy().astype(int)

        fp.data.predBox = boxes_pred
        fp.data.refineBox = boxes_refine
        fp.data.gene = gene_preds
        return fp.data

    def align(self, data):
        """Align the refined boxes via ``align_fp_refine``.

        Stores the results on ``data`` as ``newBox``, ``order`` and
        ``rBoundary`` and returns the same object.
        """
        boxes_aligned, order, room_boundaries = align_fp_refine(
            data.boundary,
            data.refineBox,
            data.rType,
            data.rEdge,
            data.gene,
        )
        data.newBox = boxes_aligned
        data.order = order
        data.rBoundary = room_boundaries
        return data

    def decorate(self, data):
        """Attach the doors and windows computed by ``add_door_window``.

        Stores them on ``data`` as ``doors`` and ``windows`` and returns the
        same object.
        """
        doors, windows = add_door_window(data)
        data.doors = doors
        data.windows = windows
        return data

    def generate(self, data_boundary):
        """Run the full pipeline for one boundary.

        Takes the first retrieved database entry as the graph source, then
        applies ``transfer``, ``forward``, ``align`` and ``decorate`` in that
        order. Requires the retrieval database to be loaded.
        """
        data = self.retrieve(data_boundary)[0]
        data = self.transfer(data_boundary, data)
        data = self.forward(data)
        data = self.align(data)
        data = self.decorate(data)
        return data