IFMedTechdemo commited on
Commit
d09919c
·
verified ·
1 Parent(s): e0b1c0e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -1
main.py CHANGED
@@ -16,6 +16,28 @@ PY_MODULES = {
16
  "wrinkle_unet.py": "WrinkleDetector"
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def dynamic_import(module_path, class_name):
20
  spec = importlib.util.spec_from_file_location(class_name, module_path)
21
  module = importlib.util.module_from_spec(spec)
@@ -65,7 +87,8 @@ def analyze_skin(image: np.ndarray, analysis_type: str) -> np.ndarray:
65
  # else:
66
  # print(f"Oiliness detection error: {result.get('error')}")
67
  elif analysis_type == "Wrinkles":
68
- detector = detector_classes["WrinkleDetector"](image)
 
69
  result = detector.predict_json()
70
  if result.get("detected") is not None:
71
  output = detector.draw_json(result)
 
16
  "wrinkle_unet.py": "WrinkleDetector"
17
  }
18
 
19
+ def load_model(token):
20
+ repo_id = "IFMedTech/Skin-Analysis"
21
+ filename = "model/wrinkles_unet_v1.pth"
22
+ # token = os.environ.get("HUGGINGFACE_HUB_TOKEN") # Set this env var with your token
23
+
24
+ # if not token:
25
+ # raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")
26
+
27
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
28
+
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model = smp.Unet(
31
+ encoder_name="resnet34",
32
+ encoder_weights=None,
33
+ in_channels=3,
34
+ classes=1
35
+ )
36
+ model.load_state_dict(torch.load(model_path, map_location=device))
37
+ model.to(device)
38
+ model.eval()
39
+ return model, device
40
+
41
  def dynamic_import(module_path, class_name):
42
  spec = importlib.util.spec_from_file_location(class_name, module_path)
43
  module = importlib.util.module_from_spec(spec)
 
87
  # else:
88
  # print(f"Oiliness detection error: {result.get('error')}")
89
  elif analysis_type == "Wrinkles":
90
+ model, device = load_model(token)
91
+ detector = detector_classes["WrinkleDetector"](image, model, device)
92
  result = detector.predict_json()
93
  if result.get("detected") is not None:
94
  output = detector.draw_json(result)