kaveh commited on
Commit
fbe1a2c
·
1 Parent(s): 57a673f

separate models for prediction

Browse files
Files changed (1) hide show
  1. S2FApp/predictor.py +21 -6
S2FApp/predictor.py CHANGED
@@ -89,31 +89,44 @@ class S2FPredictor:
89
  """
90
  self.model_type = model_type
91
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
92
- ckp_folder = ckp_folder or os.path.join(S2F_ROOT, "ckp")
 
 
 
 
 
 
 
 
93
 
94
  in_channels = 3 if model_type == "single_cell" else 1
95
- generator, _ = create_s2f_model(in_channels=in_channels)
 
96
  self.generator = generator
97
 
98
  if checkpoint_path:
99
  full_path = checkpoint_path
100
  if not os.path.isabs(checkpoint_path):
101
- full_path = os.path.join(ckp_folder, checkpoint_path)
 
 
102
  if not os.path.exists(full_path):
103
  raise FileNotFoundError(f"Checkpoint not found: {full_path}")
104
 
105
- # Single-cell: use load_checkpoint_with_expansion (handles 1ch->3ch if needed)
106
  if model_type == "single_cell":
107
  self.generator.load_checkpoint_with_expansion(full_path, strict=True)
108
  else:
109
  checkpoint = torch.load(full_path, map_location="cpu", weights_only=False)
110
- state = checkpoint.get("generator_state_dict", checkpoint)
111
  self.generator.load_state_dict(state, strict=True)
 
 
112
 
113
  self.generator = self.generator.to(self.device)
114
  self.generator.eval()
115
 
116
  self.norm_params = compute_settings_normalization() if model_type == "single_cell" else None
 
117
  self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")
118
 
119
  def predict(self, image_path=None, image_array=None, substrate="fibroblasts_PDMS",
@@ -156,7 +169,9 @@ class S2FPredictor:
156
  with torch.no_grad():
157
  pred = self.generator(x)
158
 
159
- pred = (pred + 1.0) / 2.0 # Tanh to [0, 1]
 
 
160
  heatmap = pred[0, 0].cpu().numpy()
161
  force = sum_force_map(pred).item()
162
  pixel_sum = float(np.sum(heatmap))
 
89
  """
90
  self.model_type = model_type
91
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
92
+ ckp_base = os.path.join(S2F_ROOT, "ckp")
93
+ if not os.path.isdir(ckp_base):
94
+ project_root = os.path.dirname(S2F_ROOT)
95
+ if os.path.isdir(os.path.join(project_root, "ckp")):
96
+ ckp_base = os.path.join(project_root, "ckp")
97
+ subfolder = "single_cell" if model_type == "single_cell" else "spheroid"
98
+ ckp_dir = ckp_folder if ckp_folder else os.path.join(ckp_base, subfolder)
99
+ if not os.path.isdir(ckp_dir):
100
+ ckp_dir = ckp_base # fallback if subfolders not used
101
 
102
  in_channels = 3 if model_type == "single_cell" else 1
103
+ s2f_model_type = "s2f" if model_type == "single_cell" else "s2f_spheroid"
104
+ generator, _ = create_s2f_model(in_channels=in_channels, model_type=s2f_model_type)
105
  self.generator = generator
106
 
107
  if checkpoint_path:
108
  full_path = checkpoint_path
109
  if not os.path.isabs(checkpoint_path):
110
+ full_path = os.path.join(ckp_dir, checkpoint_path)
111
+ if not os.path.exists(full_path):
112
+ full_path = os.path.join(ckp_base, checkpoint_path) # try base folder
113
  if not os.path.exists(full_path):
114
  raise FileNotFoundError(f"Checkpoint not found: {full_path}")
115
 
 
116
  if model_type == "single_cell":
117
  self.generator.load_checkpoint_with_expansion(full_path, strict=True)
118
  else:
119
  checkpoint = torch.load(full_path, map_location="cpu", weights_only=False)
120
+ state = checkpoint.get("generator_state_dict") or checkpoint.get("model_state_dict") or checkpoint
121
  self.generator.load_state_dict(state, strict=True)
122
+ if hasattr(self.generator, "set_output_mode"):
123
+ self.generator.set_output_mode(use_tanh=False) # sigmoid [0,1] for inference
124
 
125
  self.generator = self.generator.to(self.device)
126
  self.generator.eval()
127
 
128
  self.norm_params = compute_settings_normalization() if model_type == "single_cell" else None
129
+ self._use_tanh_output = model_type == "single_cell" # single_cell uses tanh, spheroid uses sigmoid
130
  self.config_path = os.path.join(S2F_ROOT, "config", "substrate_settings.json")
131
 
132
  def predict(self, image_path=None, image_array=None, substrate="fibroblasts_PDMS",
 
169
  with torch.no_grad():
170
  pred = self.generator(x)
171
 
172
+ if self._use_tanh_output:
173
+ pred = (pred + 1.0) / 2.0 # Tanh [-1,1] to [0, 1]
174
+ # else: spheroid already outputs sigmoid [0, 1]
175
  heatmap = pred[0, 0].cpu().numpy()
176
  force = sum_force_map(pred).item()
177
  pixel_sum = float(np.sum(heatmap))