File size: 4,709 Bytes
33803b5
 
 
5645b6a
33803b5
 
 
 
 
 
 
 
 
 
 
 
 
 
5645b6a
33803b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5645b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33803b5
 
 
 
5645b6a
33803b5
 
5645b6a
 
 
33803b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import gradio as gr
import plotly.graph_objects as go
import trimesh

# Device the model is loaded onto.
device = torch.device("cpu")
# map_location remaps every storage in the checkpoint onto `device` while
# loading, so a model that was saved on another device (e.g. a GPU) still loads
# on a machine that lacks it; calling .to(device) afterwards cannot fix a load
# that already failed.
model = torch.jit.load('model_scripted.pt', map_location=device)
# The module is only used for inference in this script.
model.eval()

def normalize_vertices(verts):
    """Center vertices on their mean and scale them so coordinates lie in [-1, 1].

    Args:
        verts: Float tensor whose rows are vertex positions (presumably shape
            ``(V, 3)`` -- confirm against callers).

    Returns:
        A new tensor of the same shape: ``verts`` minus its per-column mean,
        divided by the largest absolute coordinate of the centered points.
        If every vertex coincides with the mean (e.g. a single vertex), the
        centered (all-zero) tensor is returned unscaled instead of NaNs.
    """
    centered = verts - verts.mean(0)
    # Largest absolute coordinate over all axes at once; same value as taking
    # the max of the per-axis maxima.
    scale = centered.abs().max()
    if scale == 0:
        # Degenerate input: dividing by zero would produce NaN everywhere.
        return centered
    return centered / scale

def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Render a translucent mesh with its predicted UV-seam edges drawn on top.

    Args:
        verts: Tensor of vertex positions; each row is unpacked as (x, y, z).
        faces: Tensor of triangle vertex indices; columns 0, 1 and 2 become
            the Mesh3d ``i``, ``j`` and ``k`` arrays.
        uv_seam_edge_indices: Iterable of edges; items 0 and 1 of each edge
            are row indices into ``verts``.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the mesh and a red line
        trace for the given edges.
    """
    points = verts.cpu().numpy()
    tris = faces.cpu().numpy()

    surface = go.Mesh3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        i=tris[:, 0], j=tris[:, 1], k=tris[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # Every edge goes into a single Scatter3d trace; the None after each
    # segment breaks the polyline so consecutive edges are not joined.
    xs, ys, zs = [], [], []
    for edge in uv_seam_edge_indices:
        (ax, ay, az), (bx, by, bz) = points[edge[0]], points[edge[1]]
        xs += [ax, bx, None]
        ys += [ay, by, None]
        zs += [az, bz, None]

    seam_lines = go.Scatter3d(
        x=xs, y=ys, z=zs, mode='lines',
        line=dict(color='red', width=2), name='Predicted Edges',
    )

    def axis_style(background):
        # The three axes share every setting except the backdrop colour.
        return dict(nticks=4, backgroundcolor=background, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    scene = dict(
        xaxis=axis_style("rgb(200, 200, 230)"),
        yaxis=axis_style("rgb(230, 200,230)"),
        zaxis=axis_style("rgb(230, 230,200)"),
        camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
    )

    figure = go.Figure(data=[surface, seam_lines])
    figure.update_layout(scene=scene, title_text='Predicted Edges')
    return figure


def _weld_vertices(vertices, face_list):
    """Merge vertices with bit-identical coordinates and re-index the faces.

    Args:
        vertices: Float tensor of vertex positions, one row per vertex.
        face_list: Sequence of faces, each a sequence of integer row indices
            into ``vertices``.

    Returns:
        ``(verts, new_faces)``: ``verts`` has one row per distinct coordinate
        tuple, ordered by first use in ``face_list`` (vertices no face uses
        are dropped); ``new_faces`` is ``face_list`` rewritten to index into
        ``verts``.
    """
    # Convert every row to a hashable tuple once, rather than indexing the
    # tensor and calling .tolist() again for every face corner.
    coords = [tuple(row) for row in vertices.tolist()]
    index_of = {}
    kept = []  # original row index of the first vertex seen for each coordinate
    new_faces = []
    for face in face_list:
        new_face = []
        for orig_index in face:
            key = coords[orig_index]
            if key not in index_of:
                index_of[key] = len(kept)
                kept.append(orig_index)
            new_face.append(index_of[key])
        new_faces.append(new_face)
    return vertices[kept], new_faces


def _unique_edges(face_list):
    """Return the undirected edges of triangular faces as an ``(E, 2)`` long tensor.

    Each edge appears once, with its smaller vertex index first. Edges are
    added per face in the order (v1, v2), (v2, v3), (v1, v3).
    """
    edge_set = set()
    for v1, v2, v3 in face_list:
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))
    # reshape keeps the (E, 2) shape even when there are no edges.
    return torch.tensor(list(edge_set), dtype=torch.long).reshape(-1, 2)


def generate_prediction(file_input, treshold_value=0.5):
    """Predict UV-seam edges for a mesh file and return a plot of them.

    Args:
        file_input: Anything ``trimesh.load_mesh`` accepts; in this app it is
            the value of the Gradio ``File`` component.
        treshold_value: Probability cut-off; edges whose sigmoid score is
            strictly greater than this are reported as seams. (Spelling kept
            for keyword-argument compatibility.)

    Returns:
        The Plotly figure built by ``plot_3d_results`` for the welded mesh and
        the predicted seam edges.

    Raises:
        ValueError: If the loaded object has no faces to process.
    """
    mesh = trimesh.load_mesh(file_input)
    mesh_faces = getattr(mesh, "faces", None)
    if mesh_faces is None or len(mesh_faces) == 0:
        raise ValueError("The uploaded file does not contain a triangle mesh with faces.")

    # NOTE: welding is still a pure-Python pass; a vectorized version would be
    # faster for large meshes but must keep the same vertex/edge ordering.
    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
    verts, new_faces = _weld_vertices(vertices, mesh_faces.tolist())
    edges = _unique_edges(new_faces)
    faces = torch.tensor(new_faces, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        # Inputs follow the model onto `device`.
        logits = model(verts.to(device), edges.to(device))
        probabilities = torch.sigmoid(logits)
        # Presumably one score per edge. reshape(-1) (rather than squeeze())
        # guarantees a 1-D mask, so boolean indexing below always selects rows.
        seam_mask = (probabilities > treshold_value).reshape(-1).cpu()

    uv_seam_edges = edges[seam_mask].tolist()
    return plot_3d_results(verts, faces, uv_seam_edges)


def run_gradio():
    """Build the Gradio demo UI and start it with ``demo.launch()``.

    The UI has a file input (pre-filled with ``sphere.obj``), a plot output and
    a threshold slider; clicking "Predict" calls ``generate_prediction`` with
    the file and the slider value and shows the returned figure.
    """
    with gr.Blocks() as demo:
        # Static description text: Markdown is the text component; Label is
        # meant for classification outputs.
        gr.Markdown("Proof of concept demo. Predict UV seams on 3D sphere meshes.")

        with gr.Row():
            mesh_file = gr.File(label="Sphere Prototype Model", value="sphere.obj")
            with gr.Column():
                plot_output = gr.Plot()
                threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.5, label="Threshold")

        predict_button = gr.Button("Predict")
        predict_button.click(
            generate_prediction,
            inputs=[mesh_file, threshold_slider],
            outputs=plot_output,
        )

    demo.launch()


# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    run_gradio()