Example wrong

#1
by cmahnke - opened

The example doesn't work; it fails with:

ImportError: cannot import name 'SimpleCNN' from 'transformers' 

This seems to work:

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from PIL import Image
import numpy as np

# Define the corrected SimpleCNN architecture
class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel images.

    Three 3x3 convolutions (1 -> 16 -> 32 -> 32 channels), each followed by
    ReLU and a 2x2 max pool, feed a two-layer fully connected head that emits
    two scores per image. The head takes 32 * 16 * 16 flattened features,
    which is what a 128x128 input yields after the three pooling steps.
    """

    def __init__(self):
        super().__init__()
        # Every convolution preserves the spatial size (3x3, stride 1, pad 1);
        # only the pooling layer shrinks the feature map.
        same_size = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, **same_size)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, **same_size)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, **same_size)
        # A single pooling module is reused after each convolution.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 16 * 16, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=2)

    def forward(self, x):
        """Return the two raw class scores for each image in the batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Collapse channels and spatial dims into one feature vector per image.
        features = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(features))
        return self.fc2(hidden)

# Read the trained weights from disk, build the network, copy the weights in,
# and switch to inference mode.
state_dict = load_file("model.safetensors")
model = SimpleCNN()
model.load_state_dict(state_dict)
model.eval()

# Function to predict orientation
def predict_orientation(image_path, model):
    """Classify the image at ``image_path`` as rotated or normal.

    The image is converted to 8-bit grayscale, resized to 128x128, scaled to
    the range [0, 1], and passed to ``model`` as a float32 tensor of shape
    (1, 1, 128, 128).

    Args:
        image_path: Anything ``PIL.Image.open`` accepts as an image source.
        model: Callable that maps the input tensor to per-class scores with
            the class dimension at index 1.

    Returns:
        ``"Rotated"`` if the highest-scoring class index is 1, otherwise
        ``"Normal"``.
    """
    # The context manager releases the file handle as soon as the pixels have
    # been copied; convert() returns an independent image, so it stays usable
    # after the source file is closed.
    with Image.open(image_path) as source:
        gray = source.convert('L').resize((128, 128))

    # Divide in float64 first and cast afterwards so the input values match
    # the ones the model was previously fed.
    pixels = np.array(gray) / 255.0
    # Add batch and channel dimensions: (128, 128) -> (1, 1, 128, 128).
    img_tensor = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output = model(img_tensor)

    is_rotated = torch.argmax(output, dim=1).item() == 1
    return "Rotated" if is_rotated else "Normal"

# Example usage: classify one image on disk and print the verdict
result = predict_orientation(image_path="example.jpg", model=model)
print(f"Image Orientation: {result}")

Sign up or log in to comment