tejani commited on
Commit
fe62c07
·
verified ·
1 Parent(s): 7a89e4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -31
app.py CHANGED
@@ -12,12 +12,29 @@ MODELS_DIR = Path("models")
12
  INPUT_DIR.mkdir(exist_ok=True)
13
  OUTPUT_DIR.mkdir(exist_ok=True)
14
 
15
- # Load pre-trained models
16
  def load_model(model_path, use_cpu=False):
17
- # Explicitly set weights_only=False to load the full model object
18
- model = torch.load(model_path, map_location="cpu" if use_cpu else None, weights_only=False)
19
- if not use_cpu and torch.cuda.is_available():
20
- model = model.cuda()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model.eval()
22
  return model
23
 
@@ -33,36 +50,24 @@ def process_image(input_path, tile_size=512, seamless=False, use_cpu=False):
33
 
34
  # Convert to tensor
35
  img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
36
- if not use_cpu and torch.cuda.is_available():
37
- img_tensor = img_tensor.cuda()
 
 
38
 
39
  # Generate maps
40
  with torch.no_grad():
41
- # Normal map
42
- normal_map = normal_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze()
43
- # Franken map (contains Displacement and Roughness)
44
- franken_map = franken_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze()
45
 
46
  # Post-process maps
47
- # Normal map (RGB)
48
- if normal_map.ndim == 3:
49
- normal_map = normal_map.transpose(1, 2, 0)
50
- else:
51
- normal_map = np.stack([normal_map] * 3, axis=-1) # Convert grayscale to RGB if needed
52
- normal_map = (normal_map * 255).clip(0, 255).astype(np.uint8)
53
-
54
- # Franken map: Extract Displacement (red) and Roughness (green)
55
- if franken_map.ndim == 3:
56
- franken_map = franken_map.transpose(1, 2, 0)
57
- # Displacement map (red channel, grayscale)
58
- disp_map = franken_map[:, :, 0] # Red channel
59
- disp_map = (disp_map * 255).clip(0, 255).astype(np.uint8)
60
- disp_map = np.stack([disp_map] * 3, axis=-1) # Convert to RGB for Gradio display
61
-
62
- # Roughness map (green channel, grayscale)
63
- rough_map = franken_map[:, :, 1] # Green channel
64
- rough_map = (rough_map * 255).clip(0, 255).astype(np.uint8)
65
- rough_map = np.stack([rough_map] * 3, axis=-1) # Convert to RGB for Gradio display
66
 
67
  # Define output paths
68
  base_name = input_path.stem
@@ -114,4 +119,4 @@ interface = gr.Interface(
114
  )
115
 
116
  if __name__ == "__main__":
117
- interface.launch()
 
12
  INPUT_DIR.mkdir(exist_ok=True)
13
  OUTPUT_DIR.mkdir(exist_ok=True)
14
 
15
+ # Function to load pre-trained models
16
  def load_model(model_path, use_cpu=False):
17
+ if not model_path.exists():
18
+ raise FileNotFoundError(f"Model file not found: {model_path}")
19
+
20
+ device = "cpu" if use_cpu or not torch.cuda.is_available() else "cuda"
21
+
22
+ # Load state_dict if the model was saved that way
23
+ model_state = torch.load(model_path, map_location=device)
24
+
25
+ # If a full model object was saved, load it directly
26
+ if isinstance(model_state, torch.nn.Module):
27
+ model = model_state
28
+ else:
29
+ # If saved as state_dict, we need a model architecture (Assuming CNN or custom model)
30
+ model = torch.nn.Sequential(
31
+ torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
32
+ torch.nn.ReLU(),
33
+ torch.nn.Conv2d(64, 3, kernel_size=3, padding=1)
34
+ )
35
+ model.load_state_dict(model_state)
36
+
37
+ model.to(device)
38
  model.eval()
39
  return model
40
 
 
50
 
51
  # Convert to tensor
52
  img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
53
+ img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
54
+
55
+ device = "cpu" if use_cpu or not torch.cuda.is_available() else "cuda"
56
+ img_tensor = img_tensor.to(device)
57
 
58
  # Generate maps
59
  with torch.no_grad():
60
+ normal_map = normal_model(img_tensor).cpu().numpy().squeeze()
61
+ franken_map = franken_model(img_tensor).cpu().numpy().squeeze()
 
 
62
 
63
  # Post-process maps
64
+ normal_map = (normal_map.transpose(1, 2, 0) * 255).clip(0, 255).astype(np.uint8)
65
+ disp_map = (franken_map[0] * 255).clip(0, 255).astype(np.uint8)
66
+ rough_map = (franken_map[1] * 255).clip(0, 255).astype(np.uint8)
67
+
68
+ # Convert grayscale to RGB for Gradio display
69
+ disp_map = np.stack([disp_map] * 3, axis=-1)
70
+ rough_map = np.stack([rough_map] * 3, axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Define output paths
73
  base_name = input_path.stem
 
119
  )
120
 
121
  if __name__ == "__main__":
122
+ interface.launch()