| import torch |
| import gradio as gr |
| import matplotlib.pyplot as plt |
|
|
| from gpt3d.model import GPT3D |
| from gpt3d.mesh import save_obj |
|
|
# Checkpoint file written by train_model() and read back by generate_mesh().
MODEL_PATH = "gpt3d_local.pt"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Single shared model instance used by both the training and generation tabs.
model = GPT3D().to(device)
|
|
|
|
def train_model(epochs):
    """Train the shared ``model`` to reproduce random input, then save it.

    Every iteration draws a fresh Gaussian tensor of shape (8, 1024, 3) and
    minimises the MSE between the model output and that same tensor. The
    resulting weights are written to ``MODEL_PATH``.

    Args:
        epochs: Number of optimisation steps. Coerced with ``int`` because a
            UI slider may deliver the value as a float, which ``range``
            rejects.

    Returns:
        A status string for display in the UI.
    """
    # generate_mesh() leaves the shared model in eval mode; switch back so any
    # train-only layers (dropout, batch norm) behave correctly here.
    model.train()

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.MSELoss()

    for _ in range(int(epochs)):
        # Allocate directly on the target device instead of CPU + copy.
        noise = torch.randn(8, 1024, 3, device=device)
        out = model(noise)
        loss = loss_fn(out, noise)

        opt.zero_grad()
        loss.backward()
        opt.step()

    torch.save(model.state_dict(), MODEL_PATH)
    return "✅ Trained locally"
|
|
|
|
def generate_mesh(steps):
    """Load the saved weights, iterate the model on random input, write ``mesh.obj``.

    Starts from a Gaussian tensor of shape (1, 1024, 3), feeds the model its
    own output ``steps`` times without gradients, and passes the result to
    ``save_obj``.

    Args:
        steps: Number of model applications. Coerced with ``int`` because a
            UI slider may deliver the value as a float, which ``range``
            rejects.

    Returns:
        A status string for display in the UI. If no checkpoint exists yet,
        a message asking the user to train first.
    """
    try:
        state = torch.load(MODEL_PATH, map_location=device)
    except FileNotFoundError:
        # Generating before training is an expected user mistake; report it
        # in the status box instead of raising inside the UI callback.
        return f"❌ {MODEL_PATH} not found. Train the model first."

    model.load_state_dict(state)
    model.eval()

    pts = torch.randn(1, 1024, 3, device=device)

    with torch.no_grad():
        for _ in range(int(steps)):
            pts = model(pts)

    # Drop only the leading batch axis; a bare squeeze() would also collapse
    # any other axis that happened to have size 1.
    pts = pts.squeeze(0).cpu().numpy()
    save_obj(pts, "mesh.obj")

    return "✅ mesh.obj generated"
|
|
|
|
def view_mesh():
    """Plot the vertices stored in ``mesh.obj`` as a 3D scatter.

    Only lines starting with ``"v "`` are read; every other line is ignored.

    Returns:
        A matplotlib figure containing the scatter plot (empty axes if the
        file holds no vertex lines).

    Raises:
        FileNotFoundError: If ``mesh.obj`` does not exist.
        ValueError: If a vertex line has fewer than three numeric fields.
    """
    verts = []
    with open("mesh.obj") as f:
        for line in f:
            if line.startswith("v "):
                # Keep x, y, z only: a vertex record may carry extra fields
                # (w or per-vertex colour) that would break a strict
                # four-way unpack.
                x, y, z = line.split()[1:4]
                verts.append([float(x), float(y), float(z)])

    # reshape keeps the tensor 2-D even with zero vertices, so the column
    # slices below do not raise IndexError on an empty file.
    pts = torch.tensor(verts).reshape(-1, 3)

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=1)
    return fig
|
|
|
|
# Gradio front end: one tab per stage (train, generate, view).
with gr.Blocks() as demo:
    gr.Markdown(value="# GPT-3D Local Generator")

    # Train tab: the slider value is passed to train_model as `epochs`.
    with gr.Tab(label="Train"):
        epochs = gr.Slider(minimum=1, maximum=50, value=5, step=1)
        btn = gr.Button(value="Train")
        out = gr.Textbox()
        btn.click(fn=train_model, inputs=epochs, outputs=out)

    # Generate tab: the slider value is passed to generate_mesh as `steps`.
    with gr.Tab(label="Generate OBJ"):
        steps = gr.Slider(minimum=10, maximum=200, value=50, step=10)
        btn2 = gr.Button(value="Generate Mesh")
        out2 = gr.Textbox()
        btn2.click(fn=generate_mesh, inputs=steps, outputs=out2)

    # View tab: view_mesh takes no inputs and returns a figure for the plot.
    with gr.Tab(label="View"):
        btn3 = gr.Button(value="View Mesh")
        plot = gr.Plot()
        btn3.click(fn=view_mesh, inputs=None, outputs=plot)

# Start the local web server for the interface.
demo.launch()
|
|