PulinduVR commited on
Commit
1032eea
·
1 Parent(s): b9fca8a

Add custom handler for inference widget

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. code/handler.py +53 -0
README.md CHANGED
@@ -1,6 +1,7 @@
1
  ---
2
  language: en
3
  license: mit
 
4
  tags:
5
  - agricultural-ai
6
  - maize
 
1
  ---
2
  language: en
3
  license: mit
4
+ pipeline_tag: image-classification
5
  tags:
6
  - agricultural-ai
7
  - maize
code/handler.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import io
6
+
7
+ class EndpointHandler():
8
+ def __init__(self, path=""):
9
+ # 1. Define device
10
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ # 2. Define class names (Matches alphabetical order used in training)
13
+ self.class_names = ['Gray Leaf Spot', 'Healthy']
14
+
15
+ # 3. Initialize Model Architecture (Update if using EfficientNet)
16
+ # Note: You can make this dynamic or hardcode it to your best model
17
+ self.model = models.resnet50(weights=None)
18
+ self.model.fc = nn.Linear(self.model.fc.in_features, len(self.class_names))
19
+
20
+ # 4. Load weights (Hugging Face passes the folder path in 'path')
21
+ # Ensure 'model.pth' is the name of your file in the root
22
+ state_dict = torch.load(f"{path}/model.pth", map_location=self.device)
23
+ self.model.load_state_dict(state_dict)
24
+ self.model.to(self.device)
25
+ self.model.eval()
26
+
27
+ # 5. Define Preprocessing
28
+ self.transform = transforms.Compose([
29
+ transforms.Resize((256, 256)),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
32
+ ])
33
+
34
+ def __call__(self, data):
35
+ # Data is a dictionary containing the image bytes
36
+ inputs = data.pop("inputs", data)
37
+
38
+ # Convert bytes to PIL Image
39
+ image = Image.open(io.BytesIO(inputs)).convert("RGB")
40
+
41
+ # Preprocess
42
+ tensor = self.transform(image).unsqueeze(0).to(self.device)
43
+
44
+ # Inference
45
+ with torch.no_grad():
46
+ outputs = self.model(tensor)
47
+ probs = torch.nn.functional.softmax(outputs, dim=1)
48
+ conf, pred_idx = torch.max(probs, 1)
49
+
50
+ # Return formatted result for the widget
51
+ return [
52
+ {"label": self.class_names[pred_idx.item()], "score": conf.item()}
53
+ ]