IFMedTechdemo commited on
Commit
fd03932
·
verified ·
1 Parent(s): eb6bc72

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -19
main.py CHANGED
@@ -15,31 +15,47 @@ PY_MODULES = {
15
  "texture.py": "TextureDetector",
16
  "skin_tone.py": "SkinToneDetector",
17
  "oiliness.py": "OilinessDetector",
18
- "wrinkle_unet.py": "WrinkleDetector",
19
  "age.py": "AgePredictor"
20
  }
21
-
22
  def load_model(token):
23
- repo_id = "IFMedTech/Skin-Analysis"
24
- filename = "model/wrinkles_unet_v1.pth"
25
- # token = os.environ.get("HUGGINGFACE_HUB_TOKEN") # Set this env var with your token
 
 
 
26
 
27
- # if not token:
28
- # raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")
 
 
 
 
29
 
30
- model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
 
 
 
 
 
31
 
32
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- model = smp.Unet(
34
- encoder_name="resnet34",
35
- encoder_weights=None,
36
- in_channels=3,
37
- classes=1
38
- )
39
- model.load_state_dict(torch.load(model_path, map_location=device))
40
- model.to(device)
41
- model.eval()
42
- return model, device
 
 
 
 
 
43
 
44
  def dynamic_import(module_path, class_name):
45
  spec = importlib.util.spec_from_file_location(class_name, module_path)
 
15
  "texture.py": "TextureDetector",
16
  "skin_tone.py": "SkinToneDetector",
17
  "oiliness.py": "OilinessDetector",
18
+ "wrinkle.py": "WrinkleDetector",
19
  "age.py": "AgePredictor"
20
  }
 
21
  def load_model(token):
22
+ """Download and load ONNX model"""
23
+ model_path = hf_hub_download(
24
+ repo_id=REPO_ID,
25
+ filename="model/wrinkle_model.onnx", # Adjust path if needed
26
+ token=token
27
+ )
28
 
29
+ # Create ONNX Runtime session
30
+ session = ort.InferenceSession(
31
+ model_path,
32
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
33
+ if torch.cuda.is_available() else ['CPUExecutionProvider']
34
+ )
35
 
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ return session, device
38
+ # def load_model(token):
39
+ # repo_id = "IFMedTech/Skin-Analysis"
40
+ # filename = "model/wrinkles_unet_v1.pth"
41
+ # # token = os.environ.get("HUGGINGFACE_HUB_TOKEN") # Set this env var with your token
42
 
43
+ # # if not token:
44
+ # # raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")
45
+
46
+ # model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
47
+
48
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ # model = smp.Unet(
50
+ # encoder_name="resnet34",
51
+ # encoder_weights=None,
52
+ # in_channels=3,
53
+ # classes=1
54
+ # )
55
+ # model.load_state_dict(torch.load(model_path, map_location=device))
56
+ # model.to(device)
57
+ # model.eval()
58
+ # return model, device
59
 
60
  def dynamic_import(module_path, class_name):
61
  spec = importlib.util.spec_from_file_location(class_name, module_path)