sam-brause commited on
Commit
edd3afe
·
1 Parent(s): 85c4bdf

Updated handler with EndpointHandler class

Browse files
Files changed (1) hide show
  1. handler.py +22 -14
handler.py CHANGED
@@ -5,7 +5,7 @@ import torchvision.transforms as transforms
5
  from PIL import Image
6
  import json
7
 
8
- # Hugging Face model URL
9
  HF_REPO = "your-username/your-model-name"
10
  MODEL_FILENAME = "model_scripted_efficientnet.pt"
11
  MODEL_PATH = f"/repository/{MODEL_FILENAME}"
@@ -54,20 +54,28 @@ supported_issues = [
54
  "Brow Asymmetry"
55
  ]
56
 
57
- def predict(image_path):
58
- image = Image.open(image_path).convert("RGB")
59
- image = transform(image).unsqueeze(0).to(device)
 
 
 
 
 
60
 
61
- with torch.no_grad():
62
- outputs = model(image)
 
 
63
 
64
- predictions = outputs.squeeze().tolist()
65
- output_labels = [label for label, prob in zip(supported_issues, predictions) if prob > 0.5]
 
66
 
67
- return {"predictions": output_labels}
 
68
 
69
- # Hugging Face API Inference Handler
70
- def handle_request(data):
71
- image_path = data["image_path"]
72
- results = predict(image_path)
73
- return json.dumps(results)
 
5
  from PIL import Image
6
  import json
7
 
8
+ # Hugging Face Model Information
9
  HF_REPO = "your-username/your-model-name"
10
  MODEL_FILENAME = "model_scripted_efficientnet.pt"
11
  MODEL_PATH = f"/repository/{MODEL_FILENAME}"
 
54
  "Brow Asymmetry"
55
  ]
56
 
57
+ class EndpointHandler:
58
+ def __init__(self, model_dir=None):
59
+ """Initialize the inference model."""
60
+ self.model = model
61
+ self.device = device
62
+ self.model.to(self.device)
63
+ self.model.eval()
64
+ print("Model loaded successfully.")
65
 
66
+ def __call__(self, data):
67
+ """Perform inference on an image."""
68
+ if "image" not in data:
69
+ return {"error": "No image provided"}
70
 
71
+ image_data = data["image"]
72
+ image = Image.open(image_data).convert("RGB")
73
+ image = transform(image).unsqueeze(0).to(self.device)
74
 
75
+ with torch.no_grad():
76
+ outputs = self.model(image)
77
 
78
+ predictions = outputs.squeeze().tolist()
79
+ output_labels = [label for label, prob in zip(supported_issues, predictions) if prob > 0.5]
80
+
81
+ return {"predictions": output_labels}