tejani commited on
Commit
7e25f9d
·
verified ·
1 Parent(s): 82a4cce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -35
app.py CHANGED
@@ -5,57 +5,95 @@ import torch
5
  import os
6
  from pathlib import Path
7
 
8
- # Placeholder for the actual model loading and processing logic
9
- # Replace this with the actual code from generate.py
 
 
 
 
 
 
10
  def load_model(model_path, use_cpu=False):
11
- # Example: Load your ESRGAN model here
12
- # This is a placeholder; adapt it based on the actual model loading in generate.py
13
- model = torch.load(model_path)
14
  if not use_cpu and torch.cuda.is_available():
15
  model = model.cuda()
16
  model.eval()
17
  return model
18
 
19
- def process_image(input_image, tile_size=512, seamless=False, use_cpu=False):
20
- # Convert Gradio input (PIL image) to OpenCV format
21
- img = np.array(input_image)
22
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
 
 
 
 
 
23
 
24
- # Load models (adjust paths to match your uploaded model files)
25
- normal_model = load_model("models/normal_model.pth", use_cpu)
26
- disp_model = load_model("models/displacement_model.pth", use_cpu)
27
- rough_model = load_model("models/roughness_model.pth", use_cpu)
28
 
29
- # Placeholder processing logic (replace with actual generate.py logic)
30
- # For example, apply the model to the image
31
  with torch.no_grad():
32
- # Convert image to tensor
33
- img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
34
- if not use_cpu and torch.cuda.is_available():
35
- img_tensor = img_tensor.cuda()
36
 
37
- # Generate maps (simplified example)
38
- normal_map = normal_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze().transpose(1, 2, 0)
39
- disp_map = disp_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze().transpose(1, 2, 0)
40
- rough_map = rough_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze().transpose(1, 2, 0)
 
 
 
41
 
42
- # Convert back to uint8 for display
43
- normal_map = (normal_map * 255).astype(np.uint8)
44
- disp_map = (disp_map * 255).astype(np.uint8)
45
- rough_map = (rough_map * 255).astype(np.uint8)
 
 
 
46
 
47
- # Convert to RGB for Gradio output
48
- normal_map = cv2.cvtColor(normal_map, cv2.COLOR_BGR2RGB)
49
- disp_map = cv2.cvtColor(disp_map, cv2.COLOR_BGR2RGB)
50
- rough_map = cv2.cvtColor(rough_map, cv2.COLOR_BGR2RGB)
51
 
52
- return normal_map, disp_map, rough_map
 
 
 
 
53
 
54
- # Gradio interface
 
 
 
 
 
 
 
55
  def generate_maps(input_image, tile_size, seamless, use_cpu):
56
- normal_map, disp_map, rough_map = process_image(input_image, tile_size, seamless, use_cpu)
 
 
 
 
 
 
 
 
 
 
 
 
57
  return input_image, normal_map, disp_map, rough_map
58
 
 
59
  interface = gr.Interface(
60
  fn=generate_maps,
61
  inputs=[
@@ -71,7 +109,7 @@ interface = gr.Interface(
71
  gr.Image(type="numpy", label="Roughness Map"),
72
  ],
73
  title="Material Map Generator",
74
- description="Upload a diffuse texture to generate AI-generated Normal, Displacement, and Roughness maps."
75
  )
76
 
77
  if __name__ == "__main__":
 
5
  import os
6
  from pathlib import Path
7
 
8
+ # Ensure directories exist
9
+ INPUT_DIR = Path("input")
10
+ OUTPUT_DIR = Path("output")
11
+ MODELS_DIR = Path("models")
12
+ INPUT_DIR.mkdir(exist_ok=True)
13
+ OUTPUT_DIR.mkdir(exist_ok=True)
14
+
15
+ # Load pre-trained models
16
  def load_model(model_path, use_cpu=False):
17
+ model = torch.load(model_path, map_location="cpu" if use_cpu else None)
 
 
18
  if not use_cpu and torch.cuda.is_available():
19
  model = model.cuda()
20
  model.eval()
21
  return model
22
 
23
+ # Process image and save to output folder
24
+ def process_image(input_path, tile_size=512, seamless=False, use_cpu=False):
25
+ # Read input image
26
+ img = cv2.imread(str(input_path))
27
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
28
+
29
+ # Load models
30
+ normal_model = load_model(MODELS_DIR / "NormalMapGenerator-CX-Lite_200000_G.pth", use_cpu)
31
+ franken_model = load_model(MODELS_DIR / "frankenMapGenerator-CX-Lite_215000_G.pth", use_cpu)
32
 
33
+ # Convert to tensor
34
+ img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
35
+ if not use_cpu and torch.cuda.is_available():
36
+ img_tensor = img_tensor.cuda()
37
 
38
+ # Generate maps
 
39
  with torch.no_grad():
40
+ # Normal map
41
+ normal_map = normal_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze()
42
+ # Franken map (contains Displacement and Roughness)
43
+ franken_map = franken_model(img_tensor.unsqueeze(0)).cpu().numpy().squeeze()
44
 
45
+ # Post-process maps
46
+ # Normal map (RGB)
47
+ if normal_map.ndim == 3:
48
+ normal_map = normal_map.transpose(1, 2, 0)
49
+ else:
50
+ normal_map = np.stack([normal_map] * 3, axis=-1) # Convert grayscale to RGB if needed
51
+ normal_map = (normal_map * 255).clip(0, 255).astype(np.uint8)
52
 
53
+ # Franken map: Extract Displacement (red) and Roughness (green)
54
+ if franken_map.ndim == 3:
55
+ franken_map = franken_map.transpose(1, 2, 0)
56
+ # Displacement map (red channel, grayscale)
57
+ disp_map = franken_map[:, :, 0] # Red channel
58
+ disp_map = (disp_map * 255).clip(0, 255).astype(np.uint8)
59
+ disp_map = np.stack([disp_map] * 3, axis=-1) # Convert to RGB for Gradio display
60
 
61
+ # Roughness map (green channel, grayscale)
62
+ rough_map = franken_map[:, :, 1] # Green channel
63
+ rough_map = (rough_map * 255).clip(0, 255).astype(np.uint8)
64
+ rough_map = np.stack([rough_map] * 3, axis=-1) # Convert to RGB for Gradio display
65
 
66
+ # Define output paths
67
+ base_name = input_path.stem
68
+ normal_path = OUTPUT_DIR / f"{base_name}_normal.png"
69
+ disp_path = OUTPUT_DIR / f"{base_name}_displacement.png"
70
+ rough_path = OUTPUT_DIR / f"{base_name}_roughness.png"
71
 
72
+ # Save outputs
73
+ cv2.imwrite(str(normal_path), cv2.cvtColor(normal_map, cv2.COLOR_RGB2BGR))
74
+ cv2.imwrite(str(disp_path), cv2.cvtColor(disp_map, cv2.COLOR_RGB2BGR))
75
+ cv2.imwrite(str(rough_path), cv2.cvtColor(rough_map, cv2.COLOR_RGB2BGR))
76
+
77
+ return normal_path, disp_path, rough_path
78
+
79
+ # Gradio function
80
  def generate_maps(input_image, tile_size, seamless, use_cpu):
81
+ # Save uploaded image to input folder
82
+ input_path = INPUT_DIR / "uploaded_texture.png"
83
+ input_img = np.array(input_image)
84
+ cv2.imwrite(str(input_path), cv2.cvtColor(input_img, cv2.COLOR_RGB2BGR))
85
+
86
+ # Process image
87
+ normal_path, disp_path, rough_path = process_image(input_path, tile_size, seamless, use_cpu)
88
+
89
+ # Read outputs for display
90
+ normal_map = cv2.cvtColor(cv2.imread(str(normal_path)), cv2.COLOR_BGR2RGB)
91
+ disp_map = cv2.cvtColor(cv2.imread(str(disp_path)), cv2.COLOR_BGR2RGB)
92
+ rough_map = cv2.cvtColor(cv2.imread(str(rough_path)), cv2.COLOR_BGR2RGB)
93
+
94
  return input_image, normal_map, disp_map, rough_map
95
 
96
+ # Gradio interface
97
  interface = gr.Interface(
98
  fn=generate_maps,
99
  inputs=[
 
109
  gr.Image(type="numpy", label="Roughness Map"),
110
  ],
111
  title="Material Map Generator",
112
+ description="Upload a diffuse texture to generate Normal, Displacement, and Roughness maps."
113
  )
114
 
115
  if __name__ == "__main__":