CarolineM5 commited on
Commit
51f3d5f
·
verified ·
1 Parent(s): e5dea97

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +7 -3
  2. inference.py +2 -2
app.py CHANGED
@@ -71,7 +71,7 @@ pipe = pipe.to(torch.float32).to(device)
71
 
72
 
73
  # --- 3) FONCTION GRADIO D’INTERFACE ---
74
- def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Image:
75
  """
76
  Cette fonction sera appelée à chaque upload par Gradio.
77
  Elle doit retourner une PIL.Image (ou un chemin vers l’image enregistrée).
@@ -80,7 +80,7 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Im
80
  fibers_map = fibers_map.convert("RGB")
81
  rings_map = rings_map.convert("RGB")
82
 
83
- result_img = inference(pipe, device, rings_map, fibers_map)
84
 
85
  return result_img
86
 
@@ -89,7 +89,8 @@ iface = gr.Interface(
89
  fn=gradio_generate,
90
  inputs=[
91
  gr.Image(type="pil", label="Fibre orientation map"),
92
- gr.Image(type="pil", label="Growth ring map")
 
93
  ],
94
  outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
95
  title="Photorealistic wood generator",
@@ -97,6 +98,9 @@ iface = gr.Interface(
97
  Upload :
98
  1) a fibre orientation map,
99
  2) a growth ring map.
 
 
 
100
 
101
  The model will return a photo-realistic rendering of the wood that you can download.
102
  """
 
71
 
72
 
73
  # --- 3) FONCTION GRADIO D’INTERFACE ---
74
+ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image, num_steps): # -> Image.Image
75
  """
76
  Cette fonction sera appelée à chaque upload par Gradio.
77
  Elle doit retourner une PIL.Image (ou un chemin vers l’image enregistrée).
 
80
  fibers_map = fibers_map.convert("RGB")
81
  rings_map = rings_map.convert("RGB")
82
 
83
+ result_img = inference(pipe, device, rings_map, fibers_map, num_steps)
84
 
85
  return result_img
86
 
 
89
  fn=gradio_generate,
90
  inputs=[
91
  gr.Image(type="pil", label="Fibre orientation map"),
92
+ gr.Image(type="pil", label="Growth ring map"),
93
+ gr.Number(value=20, label="Number of inference steps")
94
  ],
95
  outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
96
  title="Photorealistic wood generator",
 
98
  Upload :
99
  1) a fibre orientation map,
100
  2) a growth ring map.
101
+
102
+ Set the number of inference steps.
103
+ Higher values can improve quality but increase processing time.
104
 
105
  The model will return a photo-realistic rendering of the wood that you can download.
106
  """
inference.py CHANGED
@@ -40,7 +40,7 @@ class UNetNoCondWrapper(nn.Module):
40
  # délègue à la vraie instance UNet2DModel
41
  return self.unet.save_pretrained(save_directory, **kwargs)
42
 
43
- def inference(pipe,device, img1, img2):
44
 
45
  generator = torch.Generator("cpu").manual_seed(0)
46
 
@@ -63,7 +63,7 @@ def inference(pipe,device, img1, img2):
63
  image = PIL.Image.fromarray(img_np)
64
  image = PIL.ImageOps.exif_transpose(image)
65
 
66
- num_inference_steps = 5
67
  image_guidance_scale = 1.9
68
  guidance_scale = 10
69
 
 
40
  # délègue à la vraie instance UNet2DModel
41
  return self.unet.save_pretrained(save_directory, **kwargs)
42
 
43
+ def inference(pipe,device, img1, img2, num_steps):
44
 
45
  generator = torch.Generator("cpu").manual_seed(0)
46
 
 
63
  image = PIL.Image.fromarray(img_np)
64
  image = PIL.ImageOps.exif_transpose(image)
65
 
66
+ num_inference_steps = num_steps
67
  image_guidance_scale = 1.9
68
  guidance_scale = 10
69