|
|
import io |
|
|
from PIL import Image as PImage |
|
|
import numpy as np |
|
|
from collections import defaultdict |
|
|
import cv2 |
|
|
from typing import Tuple, List |
|
|
from scipy.spatial.distance import cdist |
|
|
|
|
|
from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary |
|
|
from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping |
|
|
|
|
|
from geom_solver import GeomSolver, my_empty_solution, fully_connected_solution |
|
|
|
|
|
|
|
|
def predict(entry, visualize=False) -> Tuple[object, np.ndarray, List]:
    """Predict a wireframe (vertices, edges) for one dataset entry.

    Runs ``GeomSolver().solve(entry)`` and falls back to simpler solutions
    when the solver fails or produces an unusable result:

    * if the solver raises, ``fully_connected_solution()`` is used instead;
    * if there are at least two vertices but no edges, the edges of
      ``my_empty_solution()`` are substituted;
    * if there are fewer than two vertices or still no edges, or the vertices
      do not have size 3 along their last axis, the whole
      ``my_empty_solution()`` result is returned.

    Args:
        entry: A sample mapping. This function itself reads
            ``entry['__key__']`` and, when ``visualize`` is true,
            ``entry['wf_vertices']`` and ``entry['wf_edges']``; the solver
            may read other keys.
        visualize: If true, plot the returned estimate against the
            ground-truth wireframe via ``hoho.viz3d.plot_estimate_and_gt``.

    Returns:
        A 3-tuple ``(entry['__key__'], vertices, edges)``.
    """
    vertices0, edges0 = my_empty_solution()

    try:
        vertices, edges = GeomSolver().solve(entry)
    except Exception as exc:  # deliberate best-effort: one bad sample must not abort the run
        # Keep the 'ERROR' marker but report what actually went wrong.
        print(f'ERROR: {exc!r}')
        vertices, edges = fully_connected_solution()

    if len(edges) < 1 and len(vertices) >= 2:
        # Enough vertices but no connectivity: borrow the fallback edges.
        # NOTE(review): assumes edges0 only reference low vertex indices that
        # exist here — confirm against my_empty_solution().
        edges = edges0

    if len(vertices) < 2 or len(edges) < 1:
        vertices, edges = vertices0, edges0

    # Validate the coordinate dimension *before* visualizing so the plot shows
    # what is actually returned. np.asarray also accepts list-of-lists input
    # (where a bare `.shape` access would raise AttributeError); `shape[-1:]`
    # avoids IndexError on 0-d input; ragged input that numpy rejects with
    # ValueError is treated as malformed.
    try:
        has_xyz = np.asarray(vertices).shape[-1:] == (3,)
    except ValueError:
        has_xyz = False
    if not has_xyz:
        print("Wrong size")
        vertices, edges = vertices0, edges0

    if visualize:
        # Imported lazily so the plotting dependency is only needed on demand.
        from hoho.viz3d import plot_estimate_and_gt
        plot_estimate_and_gt(vertices,
                             edges,
                             entry['wf_vertices'],
                             entry['wf_edges'])

    return entry['__key__'], vertices, edges
|
|
|