Dataset: scikit-learn/iris
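INet is a simple fully connected neural network trained on the Iris dataset using PyTorch. It predicts the species of an iris flower (the dataset has three: setosa, versicolor, and virginica) from four features: sepal length, sepal width, petal length, and petal width. The final layer, as listed below, has 4 output units.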
Architecture flow: Input(4) → Linear(64) → ReLU → Linear(32) → ReLU → Linear(16) → ReLU → Linear(8) → ReLU → Linear(4)
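The usage code imports `INet` from model.py, which this card does not include. Below is a minimal sketch of what that file could contain, reconstructed only from the architecture flow above. The attribute name `net` and the use of `nn.Sequential` are assumptions. `load_state_dict` matches parameters by key (here `net.0.weight`, `net.0.bias`, and so on), so if the keys stored in inet.pth differ, the layer names must be changed to match them. The output size is kept at 4 to mirror the listed final layer.

```python
import torch
from torch import nn


class INet(nn.Module):
    """Hypothetical reconstruction of INet from the listed architecture flow."""

    def __init__(self, in_features: int = 4, out_features: int = 4) -> None:
        super().__init__()
        # Assumed layout: a single Sequential named `net`; the real checkpoint
        # may use different attribute names (e.g. one attribute per layer).
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, out_features),  # raw logits, no softmax
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 4) feature tensor -> (batch, out_features) logits
        return self.net(x)
```

To check which parameter names the checkpoint actually uses, print `torch.load("inet.pth", map_location="cpu").keys()` and compare them with `INet().state_dict().keys()`.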
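Install the only dependency:

```bash
pip install torch
```

Load the trained weights and classify a sample:

```python
import torch

from model import INet  # the INet class must be defined in model.py

model = INet()
# Load the saved state dict; map_location lets GPU-saved weights load on CPU.
state_dict = torch.load("inet.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode

# Example usage: one sample with sepal length, sepal width,
# petal length, and petal width (in cm).
sample_input = torch.tensor([[5.1, 3.5, 1.4, 0.2]])
with torch.no_grad():  # no gradients needed for inference
    pred = model(sample_input)  # raw logits, shape (1, 4)
pred_class = pred.argmax(dim=1).item()
print(pred_class)  # index of the highest-scoring class
```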
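The usage example prints only a class index. A small helper such as the hypothetical `decode_prediction` below turns logits into a label and a softmax probability. The index-to-species order used here (scikit-learn's `load_iris` order: setosa, versicolor, virginica) is an assumption, since the card does not state the label encoding used in training. Because the listed output layer has 4 units, any index beyond the three species is reported generically rather than guessed.

```python
import torch

# Assumption: label order follows scikit-learn's load_iris() target_names.
# The model card does not state the encoding used during training.
IRIS_SPECIES = ["setosa", "versicolor", "virginica"]


def decode_prediction(logits: torch.Tensor) -> list[tuple[str, float]]:
    """Map a (batch, n_classes) logits tensor to (label, probability) pairs."""
    probs = torch.softmax(logits, dim=1)  # normalize each row to sum to 1
    conf, idx = probs.max(dim=1)  # top probability and its index per row
    results = []
    for i, p in zip(idx.tolist(), conf.tolist()):
        # Indices with no known species (e.g. 3 with 4 outputs) stay generic.
        label = IRIS_SPECIES[i] if i < len(IRIS_SPECIES) else f"class_{i}"
        results.append((label, p))
    return results
```

For example, calling `decode_prediction(model(sample_input))` inside a `torch.no_grad()` block returns one `(label, probability)` pair per input row.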