Example wrong
#1
by
cmahnke
- opened
The example doesn't work; it fails with:
ImportError: cannot import name 'SimpleCNN' from 'transformers'
The following version seems to work:
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
import numpy as np
# Define the corrected SimpleCNN architecture. Layer attribute names and
# shapes must match the tensors stored in model.safetensors, because the
# checkpoint is loaded with load_state_dict (strict key matching by default).
class SimpleCNN(nn.Module):
    """Small three-stage CNN that maps a single-channel image to two logits.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height and width
    are preserved), a ReLU, and a 2x2 max-pool (height and width halved).
    ``fc1`` expects ``32 * 16 * 16`` flattened features, so the pooled feature
    map must hold 16 * 16 spatial positions -- e.g. a (batch, 1, 128, 128)
    input, since 128 / 2**3 == 16.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (conv1..conv3, fc1, fc2) are the state_dict keys the
        # checkpoint is matched against; keep names and creation order stable.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)   # 1 -> 16 channels
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 16 -> 32 channels
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)  # 32 -> 32 channels
        # A single parameter-free pool is reused after every conv stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 16 * 16, 32)  # flattened 32x16x16 map -> 32 hidden units
        self.fc2 = nn.Linear(32, 2)             # hidden units -> 2 class logits

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (batch, 2) for ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flatten (unlike view) also handles non-contiguous tensors and an
        # empty batch, where view(0, -1) cannot infer the -1 dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Load the model: build the architecture, restore the trained weights from the
# safetensors checkpoint, then switch the network to inference mode.
state_dict = load_file("model.safetensors")
model = SimpleCNN()
model.load_state_dict(state_dict)
model = model.eval()  # eval() returns the module itself
# Function to predict orientation
def predict_orientation(image_path, model):
    """Classify the image at ``image_path`` as "Rotated" or "Normal".

    The image is converted to grayscale ("L" mode), resized to 128x128,
    scaled from 0..255 to 0..1, and passed to ``model`` as a float32 tensor of
    shape (1, 1, 128, 128) under ``torch.no_grad()``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).
        model: Callable returning logits with class scores along dim 1 for a
            batch of one. This function does not call ``model.eval()`` itself.

    Returns:
        "Rotated" if the arg-max class index is 1, otherwise "Normal".
    """
    # NOTE(review): EXIF orientation tags are not applied here (no
    # ImageOps.exif_transpose) -- confirm this matches how the model was trained.
    # Use a context manager so the file handle is closed; convert() reads the
    # pixel data into a new image before the source file is released.
    with Image.open(image_path) as src:
        img = src.convert('L')  # Load image in grayscale
    img = img.resize((128, 128))  # Resize to 128x128
    # (H, W) uint8 -> float in [0, 1] -> (1, 1, H, W): prepend batch and channel dims
    img_tensor = torch.tensor(np.array(img) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(img_tensor)
    is_rotated = torch.argmax(output, dim=1).item() == 1
    return "Rotated" if is_rotated else "Normal"
# Example usage: classify a sample image and print the predicted label.
result = predict_orientation("example.jpg", model)
print("Image Orientation: {}".format(result))