# mvi-ai-engine/language/models.py
# Author: Musombi (commit 34897f3, "Create models.py")
import torch
import torch.nn as nn
# ===============================
# Soil Model (matches training)
# ===============================
class SoilModel(nn.Module):
    """Fully connected classifier: input_size -> 64 -> 32 -> num_classes.

    The header above this class says it "matches training", so the layer
    layout (and therefore the ``net.0`` / ``net.2`` / ``net.4`` state_dict
    keys) presumably has to stay in sync with the training script — TODO
    confirm before changing widths.

    Args:
        input_size: Size of the last dimension of the input tensor, i.e. the
            number of features the first ``nn.Linear`` consumes.
        num_classes: Number of raw logits produced per sample.
    """

    def __init__(self, input_size=8, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # Successive feature widths; each consecutive pair becomes one Linear.
        widths = (input_size, 64, 32, num_classes)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the
            # first, none after the last), giving indices 0, 2, 4 for Linear.
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class scores (no softmax is applied)."""
        logits = self.net(x)
        return logits
# ===============================
# Vision Model (matches training CNN)
# ===============================
class VisionModel(nn.Module):
    """Three-stage CNN classifier for 3-channel square images.

    Each stage is ``Conv2d(k=3, padding=1) -> ReLU -> MaxPool2d(2)``. The
    convolutions keep the spatial size and each pool halves it (rounding
    down), so an ``image_size`` x ``image_size`` input reaches the classifier
    as a ``128 x (image_size // 8) x (image_size // 8)`` feature map.

    The header above this class says it "matches training CNN"; the default
    ``image_size=128`` reproduces the original hard-coded
    ``128 * 16 * 16`` classifier width, so existing checkpoints still load.

    Args:
        num_classes: Number of raw logits produced per sample.
        image_size: Height and width (in pixels) of the square inputs the
            model will receive. Must be at least 8 so the pooled map is
            non-empty.

    Raises:
        ValueError: If ``image_size`` is smaller than 8.
    """

    def __init__(self, num_classes=10, image_size=128):
        super().__init__()
        if image_size < 8:
            raise ValueError(
                f"image_size must be >= 8 (three 2x poolings), got {image_size}"
            )
        self.num_classes = num_classes
        self.image_size = image_size

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Three floor-halvings equal one floor division by 8; this replaces
        # the former hard-coded 16 (= 128 // 8) without changing the default.
        pooled = image_size // 8
        self.fc = nn.Sequential(
            nn.Linear(128 * pooled * pooled, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores for a batch of images.

        Args:
            x: Image batch; ``nn.Conv2d(3, ...)`` requires 3 input channels,
                and the classifier requires spatial size ``image_size`` x
                ``image_size`` (assumed NCHW — TODO confirm with caller).
        """
        x = self.conv(x)
        # Keep the batch dimension, flatten channels and spatial dims together.
        # torch.flatten, unlike .view, does not require a contiguous tensor.
        x = torch.flatten(x, 1)
        return self.fc(x)