sam-brause committed on
Commit
5d6739d
·
1 Parent(s): edd3afe

revise handler based on deepseek attempt

Browse files
Files changed (1) hide show
  1. handler.py +57 -74
handler.py CHANGED
@@ -1,81 +1,64 @@
1
- import os
2
  import torch
3
- import requests
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
- import json
7
 
8
- # Hugging Face Model Information
9
- HF_REPO = "your-username/your-model-name"
10
- MODEL_FILENAME = "model_scripted_efficientnet.pt"
11
- MODEL_PATH = f"/repository/{MODEL_FILENAME}"
 
 
12
 
13
- def download_model():
14
- """Download model file from Hugging Face if it does not exist."""
15
- if not os.path.exists(MODEL_PATH):
16
- print("Model file not found! Downloading from Hugging Face...")
17
- model_url = f"https://huggingface.co/{HF_REPO}/resolve/main/{MODEL_FILENAME}"
18
- response = requests.get(model_url, stream=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- if response.status_code == 200:
21
- with open(MODEL_PATH, "wb") as f:
22
- for chunk in response.iter_content(chunk_size=8192):
23
- f.write(chunk)
24
- print("Model downloaded successfully!")
25
- else:
26
- raise RuntimeError(f"Failed to download model, status code: {response.status_code}")
 
27
 
28
- # Ensure model is downloaded before loading
29
- download_model()
30
-
31
- # Load Model
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
- model = torch.jit.load(MODEL_PATH, map_location=device)
34
- model.eval()
35
-
36
- # Define Transform
37
- transform = transforms.Compose([
38
- transforms.Resize((224, 224)),
39
- transforms.ToTensor(),
40
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
41
- ])
42
-
43
- # Supported labels
44
- supported_issues = [
45
- "Dark Spots",
46
- "Dry Lips",
47
- "Forehead Wrinkles",
48
- "Jowls",
49
- "Nasolabial Folds",
50
- "Prejowl Sulcus",
51
- "Thin Lips",
52
- "Under Eye Hollow",
53
- "Under Eye Wrinkles",
54
- "Brow Asymmetry"
55
- ]
56
-
57
- class EndpointHandler:
58
- def __init__(self, model_dir=None):
59
- """Initialize the inference model."""
60
- self.model = model
61
- self.device = device
62
- self.model.to(self.device)
63
- self.model.eval()
64
- print("Model loaded successfully.")
65
-
66
- def __call__(self, data):
67
- """Perform inference on an image."""
68
- if "image" not in data:
69
- return {"error": "No image provided"}
70
-
71
- image_data = data["image"]
72
- image = Image.open(image_data).convert("RGB")
73
- image = transform(image).unsqueeze(0).to(self.device)
74
-
75
- with torch.no_grad():
76
- outputs = self.model(image)
77
-
78
- predictions = outputs.squeeze().tolist()
79
- output_labels = [label for label, prob in zip(supported_issues, predictions) if prob > 0.5]
80
-
81
- return {"predictions": output_labels}
 
 
1
  import torch
 
2
  import torchvision.transforms as transforms
3
  from PIL import Image
4
+ import io
5
 
6
# Load the model
def model_fn(model_dir):
    """Load the scripted (TorchScript) EfficientNet model from *model_dir*.

    Args:
        model_dir: directory containing ``model_scripted_efficientnet.pt``.

    Returns:
        The loaded model, switched to evaluation mode and ready for inference.
    """
    scripted_model = torch.jit.load(f"{model_dir}/model_scripted_efficientnet.pt")
    # ``eval()`` returns the model itself, so load-and-return can be chained.
    return scripted_model.eval()
12
 
13
# Preprocess the input image
def input_fn(request_body, request_content_type):
    """Deserialize a JPEG/PNG request body into a normalized model input tensor.

    Args:
        request_body: raw image bytes from the request.
        request_content_type: MIME type of the payload.

    Returns:
        A float tensor of shape (1, 3, 224, 224), resized and normalized with
        ImageNet statistics (matching EfficientNet pretraining).

    Raises:
        ValueError: if the content type is not image/jpeg or image/png.
    """
    # Guard clause: reject anything that is not a JPEG/PNG payload up front.
    if request_content_type not in ("image/jpeg", "image/png"):
        raise ValueError(f"Unsupported content type: {request_content_type}")

    # Load the image from the request body and force 3-channel RGB.
    image = Image.open(io.BytesIO(request_body)).convert("RGB")

    # Define the image transformation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet channel means
            std=[0.229, 0.224, 0.225],   # ImageNet channel stds
        ),
    ])

    # Apply the transformation and add the batch dimension the model expects.
    return transform(image).unsqueeze(0)
34
 
35
# Run inference
def predict_fn(input_data, model):
    """Run the model on a preprocessed batch and return per-label scores.

    The downstream ``output_fn`` treats each of the 10 labels independently and
    keeps those with score > 0.5 (multi-label classification). Softmax is wrong
    for that contract: it normalizes the scores to sum to 1 across the labels,
    so at most one label could ever clear the 0.5 threshold. The previous
    handler thresholded the raw model outputs directly, so that behavior is
    preserved here.
    NOTE(review): assumes the scripted model already emits per-label
    probabilities/scores in [0, 1] — confirm against the training code.

    Args:
        input_data: tensor of shape (1, 3, 224, 224) produced by ``input_fn``.
        model: the loaded TorchScript model.

    Returns:
        A flat Python list of per-label scores.
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(input_data)
    return outputs.squeeze().tolist()
43
 
44
# Postprocess the output
def output_fn(predictions, content_type):
    """Map per-label scores to the names of the issues scoring above 0.5.

    Args:
        predictions: sequence of per-label scores, ordered to match the label
            list below.
        content_type: requested response content type (unused; a JSON-style
            dict is always returned).

    Returns:
        ``{"predictions": [...]}`` containing every label whose score
        exceeds 0.5.
    """
    # Supported issues, in the order the model emits its scores.
    labels = (
        "Dark Spots",
        "Dry Lips",
        "Forehead Wrinkles",
        "Jowls",
        "Nasolabial Folds",
        "Prejowl Sulcus",
        "Thin Lips",
        "Under Eye Hollow",
        "Under Eye Wrinkles",
        "Brow Asymmetry",
    )

    # Keep only the issues whose score clears the 0.5 threshold.
    flagged = []
    for label, score in zip(labels, predictions):
        if score > 0.5:
            flagged.append(label)

    # Return the output as a JSON-style response.
    return {"predictions": flagged}