# gpt-3d / app.py
# Peeble's picture
# Create app.py
# 8bd3cef verified
import torch
import gradio as gr
import matplotlib.pyplot as plt
from gpt3d.model import GPT3D
from gpt3d.mesh import save_obj
# Checkpoint file written by train_model() and read back by generate_mesh().
MODEL_PATH = "gpt3d_local.pt"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Single shared model instance used by both the training and generation tabs.
model = GPT3D().to(device)
def train_model(epochs):
    """Train the shared model to reproduce random noise and save its weights.

    Every step draws a fresh ``(8, 1024, 3)`` batch of Gaussian noise, feeds
    it through the module-level ``model`` and minimises the MSE between the
    output and that same batch.  After the last step the state dict is
    written to ``MODEL_PATH``.

    Args:
        epochs: Number of optimisation steps to run.  Coerced with ``int()``
            because a UI slider may hand the value over as a float, which
            ``range()`` rejects.

    Returns:
        A fixed status string for display in the UI.
    """
    # generate_mesh() leaves the shared model in eval mode; switch back so
    # any dropout / normalisation layers behave as training layers again.
    model.train()

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.MSELoss()

    for _ in range(int(epochs)):
        # Allocate directly on the target device instead of CPU + copy.
        noise = torch.randn(8, 1024, 3, device=device)
        out = model(noise)
        loss = loss_fn(out, noise)

        opt.zero_grad()
        loss.backward()
        opt.step()

    # Persist only the weights; generate_mesh() reloads them from this path.
    torch.save(model.state_dict(), MODEL_PATH)
    return "✅ Trained locally"
def generate_mesh(steps):
    """Generate a point set with the saved model and write it to ``mesh.obj``.

    Loads the weights stored at ``MODEL_PATH``, starts from a single
    ``(1, 1024, 3)`` batch of Gaussian noise and repeatedly applies the model
    to its own output.  The result is handed to ``save_obj`` as a NumPy
    array with the leading batch axis removed.

    Args:
        steps: Number of times the model is applied.  Coerced with ``int()``
            because a UI slider may hand the value over as a float, which
            ``range()`` rejects.

    Returns:
        A status string for display in the UI: a success message, or a hint
        to train first when no checkpoint file exists yet.
    """
    try:
        state = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        # No checkpoint yet: report it in the UI instead of a raw traceback.
        return "❌ No trained model found - run training first"

    model.load_state_dict(state)
    model.eval()

    pts = torch.randn(1, 1024, 3, device=device)
    with torch.no_grad():
        for _ in range(int(steps)):
            pts = model(pts)

    # Drop only the batch axis; a bare squeeze() would also collapse any
    # other size-1 dimension.
    pts = pts.squeeze(0).cpu().numpy()
    save_obj(pts, "mesh.obj")
    return "✅ mesh.obj generated"
def view_mesh():
    """Plot the vertices found in ``mesh.obj`` as a 3-D scatter.

    Only lines starting with ``"v "`` are read; the first three values after
    the tag are taken as x, y and z.  Any further values on a vertex line
    are ignored.  A file without vertex lines produces an empty plot.

    Returns:
        The matplotlib figure containing the scatter plot.

    Raises:
        FileNotFoundError: If ``mesh.obj`` does not exist.
        ValueError: If a vertex line has fewer than three coordinates or a
            coordinate is not a valid float.
    """
    xs, ys, zs = [], [], []
    with open("mesh.obj") as f:
        for line in f:
            if not line.startswith("v "):
                continue
            # Slice instead of a 4-way unpack so extra trailing fields on a
            # vertex line do not abort the whole read.
            x, y, z = line.split()[1:4]
            xs.append(float(x))
            ys.append(float(y))
            zs.append(float(z))

    # NOTE(review): figures created through pyplot stay registered with it
    # until closed; repeated clicks may accumulate figures - confirm whether
    # the UI layer releases them.
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # Plain coordinate lists plot fine even when empty, unlike column
    # indexing on an empty (1-D) tensor.
    ax.scatter(xs, ys, zs, s=1)
    return fig
# Three-tab UI: train the model, generate an OBJ file, and preview it.
with gr.Blocks() as demo:
    gr.Markdown("# GPT-3D Local Generator")

    # Tab 1: run train_model() for the chosen number of steps.
    with gr.Tab("Train"):
        epochs = gr.Slider(minimum=1, maximum=50, value=5, step=1)
        btn = gr.Button("Train")
        out = gr.Textbox()
        btn.click(fn=train_model, inputs=epochs, outputs=out)

    # Tab 2: run generate_mesh() for the chosen number of model passes.
    with gr.Tab("Generate OBJ"):
        steps = gr.Slider(minimum=10, maximum=200, value=50, step=10)
        btn2 = gr.Button("Generate Mesh")
        out2 = gr.Textbox()
        btn2.click(fn=generate_mesh, inputs=steps, outputs=out2)

    # Tab 3: render the most recently written mesh.obj via view_mesh().
    with gr.Tab("View"):
        btn3 = gr.Button("View Mesh")
        plot = gr.Plot()
        btn3.click(fn=view_mesh, inputs=None, outputs=plot)

# Start the web server for the interface defined above.
demo.launch()