# Ystar124 — "Add App" (commit 74e8a31)
from pathlib import Path

import gradio as gr
import torch
from torch_geometric.data import Data

from model import GraphGenerator, PropertyPredictor, adj_to_edge_index  # use correct import
# Load or initialize models.
# generate_graphs() feeds the generator's node features straight into the
# predictor, so the generator's feature width must equal the predictor's
# in_channels. out_channels=3 is meant to line up with the three properties
# offered in the UI slider (0: Conductivity, 1: Porosity, 2: Surface Area).
NUM_NODE_FEATURES = 1
NUM_PROPERTIES = 3

generator = GraphGenerator(latent_dim=16, num_node_features=NUM_NODE_FEATURES, num_nodes=5)
predictor = PropertyPredictor(in_channels=NUM_NODE_FEATURES, out_channels=NUM_PROPERTIES)


def _load_weights_if_available(module, path):
    """Load a state dict from *path* into *module* if the file exists.

    Returns True when weights were loaded, False when the checkpoint is absent
    (the module then keeps its random initialization).
    """
    checkpoint = Path(path)
    if not checkpoint.is_file():
        return False
    # map_location="cpu": checkpoints saved on a GPU still load on CPU-only hosts.
    # weights_only=True: only tensors/state-dict containers are unpickled, not
    # arbitrary objects (load_state_dict expects a state dict anyway).
    state = torch.load(checkpoint, map_location="cpu", weights_only=True)
    module.load_state_dict(state)
    return True


# Load pretrained weights if available; otherwise run with untrained models.
_load_weights_if_available(generator, "generator.pth")
_load_weights_if_available(predictor, "predictor.pth")
def generate_graphs(property_index=0, top_k=5, num_samples=100):
    """Sample graphs from the generator and rank them by one predicted property.

    Draws ``num_samples`` latent vectors, decodes them into node features and
    adjacency matrices, scores every decoded graph that has at least one edge
    with the property predictor, and reports the ``top_k`` best graphs for the
    selected property.

    Args:
        property_index: Column of the predictor output used for ranking (the
            UI labels 0/1/2 as Conductivity/Porosity/Surface Area). UI sliders
            may deliver a float, so it is cast to int before tensor indexing.
        top_k: Maximum number of candidates to report (cast to int for the
            same reason, since it is used as a slice bound).
        num_samples: Number of latent vectors sampled per call.

    Returns:
        One line per candidate,
        ``"Score: ..., Properties: ..., Nodes: ..., Edges: ..."``, in
        descending score order, or a short notice if no sampled graph had any
        edges.
    """
    property_index = int(property_index)
    top_k = int(top_k)

    generator.eval()
    predictor.eval()

    candidates = []
    # Inference only: without no_grad every UI call would build and hold an
    # autograd graph for the whole sampled batch.
    with torch.no_grad():
        z = torch.randn(num_samples, generator.latent_dim)
        node_feats, adj = generator(z)
        for n, a in zip(node_feats, adj):
            edge_index = adj_to_edge_index(a)
            # Edgeless graphs are not scored.
            if edge_index.size(1) == 0:
                continue
            data = Data(x=n, edge_index=edge_index)
            # All-zeros vector: every node is assigned to graph 0 (presumably the
            # predictor's graph-assignment / pooling vector — confirm in model.py).
            batch = torch.zeros(data.x.size(0), dtype=torch.long)
            score_vec = predictor(data.x, data.edge_index, batch)
            # Row 0 holds the property values used for ranking; report that row
            # rather than the nested score_vec.tolist().
            properties = score_vec[0]
            candidates.append(
                (
                    properties[property_index].item(),
                    properties.tolist(),
                    data.x.size(0),
                    data.edge_index.size(1),
                )
            )

    if not candidates:
        return "No generated graph had any edges; try again."

    # reverse=True keeps the sort stable, matching the original key=-score order.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return "\n".join(
        f"Score: {score:.3f}, Properties: {props}, Nodes: {nodes}, Edges: {edges}"
        for score, props, nodes, edges in candidates[:top_k]
    )
# Gradio UI: the two sliders map positionally onto
# generate_graphs(property_index, top_k).
iface = gr.Interface(
    fn=generate_graphs,
    inputs=[
        gr.Slider(
            0,
            2,
            value=0,
            step=1,
            label="Property Index (0: Conductivity, 1: Porosity, 2: Surface Area)",
        ),
        # Without an explicit value Gradio starts a slider at its minimum (1);
        # start at 5 to match generate_graphs' default top_k.
        gr.Slider(1, 20, value=5, step=1, label="Top K Results"),
    ],
    outputs="text",
    title="GraphGAN + GNN Property Predictor",
    description="Generates molecules and predicts multi-properties, filtering top candidates by the selected property.",
)

if __name__ == "__main__":
    iface.launch()