sam-brause committed on
Commit
878a499
·
1 Parent(s): ee43b95

try claude image solution test

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. handler.py +56 -30
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
handler.py CHANGED
@@ -1,9 +1,10 @@
1
- import base64
2
  import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import io
 
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
@@ -15,7 +16,6 @@ class EndpointHandler:
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  self.model = torch.jit.load(f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device)
17
  self.model.eval()
18
-
19
  self.transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
@@ -24,7 +24,6 @@ class EndpointHandler:
24
  std=[0.229, 0.224, 0.225]
25
  )
26
  ])
27
-
28
  self.supported_issues = [
29
  "Dark Spots",
30
  "Dry Lips",
@@ -40,45 +39,72 @@ class EndpointHandler:
40
 
41
  def __call__(self, data):
42
  logger.info(f"Received data: {type(data)}")
43
-
44
  image = None
45
-
46
- if isinstance(data, bytes):
47
- logger.info("Input is bytes. Attempting to load image.")
48
- image = Image.open(io.BytesIO(data)).convert("RGB")
49
- elif isinstance(data, dict) and "inputs" in data:
50
- logger.info("Input is a dictionary with 'inputs' key.")
51
- if isinstance(data["inputs"], bytes):
52
- logger.info("'inputs' value is bytes. Attempting to load image.")
53
- image = Image.open(io.BytesIO(data["inputs"])).convert("RGB")
54
- elif isinstance(data["inputs"], str):
55
- logger.info("'inputs' value is a string. Attempting to decode base64.")
56
- try:
57
- image_data = base64.b64decode(data["inputs"])
58
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
59
- except Exception as e:
60
- logger.error(f"Failed to decode base64 image: {e}")
61
- else:
62
- logger.warning("'inputs' value is not bytes or base64 string.")
63
- else:
64
- logger.error("Unsupported input format.")
65
- raise ValueError("Unsupported input format")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  if image is None:
68
- logger.error("Could not load image from input data.")
69
- raise ValueError("Could not load image from input data.")
70
 
71
  logger.info("Image loaded successfully. Applying transformations.")
72
  image_tensor = self.transform(image).unsqueeze(0).to(self.device)
73
-
74
  with torch.no_grad():
75
  logger.info("Running inference.")
76
  outputs = self.model(image_tensor)
77
 
78
- # Use raw outputs (removing softmax)
79
  predictions = outputs.squeeze().tolist()
80
  output = [issue for issue, prob in zip(self.supported_issues, predictions) if prob > 0.5]
81
-
82
  logger.info(f"Predictions: {output}")
83
  return {"predictions": output}
84
 
 
 
1
  import torch
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
  import io
5
+ import base64
6
  import logging
7
+ import json
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
 
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  self.model = torch.jit.load(f"{model_dir}/model_scripted_efficientnet.pt", map_location=self.device)
18
  self.model.eval()
 
19
  self.transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
 
24
  std=[0.229, 0.224, 0.225]
25
  )
26
  ])
 
27
  self.supported_issues = [
28
  "Dark Spots",
29
  "Dry Lips",
 
39
 
40
  def __call__(self, data):
41
  logger.info(f"Received data: {type(data)}")
 
42
  image = None
43
+
44
+ try:
45
+ # Handle string input (from Hugging Face interface)
46
+ if isinstance(data, str):
47
+ logger.info("Input is string. Attempting to parse as JSON.")
48
+ data = json.loads(data)
49
+
50
+ # Handle various input formats
51
+ if isinstance(data, dict):
52
+ if "inputs" in data:
53
+ input_data = data["inputs"]
54
+ logger.info(f"Input data type: {type(input_data)}")
55
+
56
+ # Handle base64 encoded string
57
+ if isinstance(input_data, str):
58
+ logger.info("Attempting to decode base64 string")
59
+ try:
60
+ # Remove potential base64 prefix
61
+ if "base64," in input_data:
62
+ input_data = input_data.split("base64,")[1]
63
+ image_bytes = base64.b64decode(input_data)
64
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
65
+ except Exception as e:
66
+ logger.error(f"Failed to decode base64: {str(e)}")
67
+
68
+ # Handle raw bytes
69
+ elif isinstance(input_data, bytes):
70
+ logger.info("Processing raw bytes input")
71
+ image = Image.open(io.BytesIO(input_data)).convert("RGB")
72
+
73
+ # Handle list input (from Hugging Face interface)
74
+ elif isinstance(input_data, list):
75
+ logger.info("Processing list input")
76
+ if len(input_data) > 0 and isinstance(input_data[0], str):
77
+ try:
78
+ # Remove potential base64 prefix
79
+ if "base64," in input_data[0]:
80
+ input_data[0] = input_data[0].split("base64,")[1]
81
+ image_bytes = base64.b64decode(input_data[0])
82
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
83
+ except Exception as e:
84
+ logger.error(f"Failed to decode base64 from list: {str(e)}")
85
+
86
+ # Handle direct bytes input
87
+ elif isinstance(data, bytes):
88
+ logger.info("Processing direct bytes input")
89
+ image = Image.open(io.BytesIO(data)).convert("RGB")
90
+
91
+ except Exception as e:
92
+ logger.error(f"Error processing input: {str(e)}")
93
+ raise ValueError(f"Error processing input: {str(e)}")
94
 
95
  if image is None:
96
+ logger.error("Could not load image from input data")
97
+ raise ValueError("Could not load image from input data")
98
 
99
  logger.info("Image loaded successfully. Applying transformations.")
100
  image_tensor = self.transform(image).unsqueeze(0).to(self.device)
101
+
102
  with torch.no_grad():
103
  logger.info("Running inference.")
104
  outputs = self.model(image_tensor)
105
 
 
106
  predictions = outputs.squeeze().tolist()
107
  output = [issue for issue, prob in zip(self.supported_issues, predictions) if prob > 0.5]
 
108
  logger.info(f"Predictions: {output}")
109
  return {"predictions": output}
110