Image Classification
PyTorch
torch
resnet
diagrams
computer-vision
Ayamohamed commited on
Commit
5ff2677
·
verified ·
1 Parent(s): 96ffc7e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -29
README.md CHANGED
@@ -73,25 +73,12 @@ You can download the model and load it using `torch`.
73
  import torch
74
  from huggingface_hub import hf_hub_download
75
 
76
- # Download the model
77
- model_path = hf_hub_download("Ayamohamed/DiaNoneclassi", "diagram_classifier_full.pth")
78
 
79
- # Load the model
80
- from torchvision import models
81
- import torch.nn as nn
82
-
83
- class DiagramClassifier(nn.Module):
84
- def __init__(self):
85
- super().__init__()
86
- self.model = models.resnet18(pretrained=False)
87
- self.model.fc = nn.Linear(self.model.fc.in_features, 2)
88
-
89
- def forward(self, x):
90
- return self.model(x)
91
-
92
- model = DiagramClassifier()
93
- model.load_state_dict(torch.load(model_path, map_location="cpu"))
94
- model.eval() # Set to evaluation mode
95
 
96
  ```
97
  ### **2️⃣ Preprocess and Classify an Image**
@@ -105,20 +92,18 @@ transform = transforms.Compose([
105
  transforms.ToTensor(),
106
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
107
  ])
108
-
109
- def classify_image(image_path):
110
- image = Image.open(image_path).convert("RGB")
111
- image = transform(image).unsqueeze(0) # Add batch dimension
112
-
113
  with torch.no_grad():
114
- output = model(image)
115
- predicted_class = output.argmax(1).item()
 
 
116
 
117
- labels = ["diagram", "none"]
118
- return labels[predicted_class]
119
 
120
- # Test
121
- print(classify_image("test_image.jpg"))
122
 
123
  ```
124
 
 
73
  import torch
74
  from huggingface_hub import hf_hub_download
75
 
76
+ # Download model from Hugging Face Hub
77
+ model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="dia_none_classifier_full.pth")
78
 
79
+ # Load model
80
+ model_hg = torch.load(model_path)
81
+ model_hg.eval() # Set to evaluation mode
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  ```
84
  ### **2️⃣ Preprocess and Classify an Image**
 
92
  transforms.ToTensor(),
93
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
94
  ])
95
+ def predict(image_path):
96
+ image = Image.open(image_path).convert("RGB")
97
+ image = transform(image).unsqueeze(0)
 
 
98
  with torch.no_grad():
99
+ output = model_hg(image)
100
+ class_idx = torch.argmax(output, dim=1).item()
101
+
102
+ return "Diagram" if class_idx == 0 else "Not Diagram"
103
 
104
+ # Example usage
105
+ print(predict("my-diagram-classifier/31188_1536932698.jpg"))
106
 
 
 
107
 
108
  ```
109