mboss commited on
Commit
97665a5
·
1 Parent(s): 7349148

Updates for spaces

Browse files
app.py CHANGED
@@ -15,11 +15,9 @@ with gr.Blocks() as demo:
15
  ReSWD is a method for distribution matching with reduced variance.
16
  """
17
  )
18
- fabric = L.Fabric(devices=1, accelerator="auto", precision="16-mixed")
19
-
20
  with gr.Tab("SW Guidance (SD 3.5 Large Turbo)"):
21
- create_sw_guidance(fabric, "stabilityai/stable-diffusion-3.5-large-turbo")
22
  with gr.Tab("Color Matching"):
23
- create_color_matching(fabric)
24
 
25
  demo.launch()
 
15
  ReSWD is a method for distribution matching with reduced variance.
16
  """
17
  )
 
 
18
  with gr.Tab("SW Guidance (SD 3.5 Large Turbo)"):
19
+ create_sw_guidance("stabilityai/stable-diffusion-3.5-large-turbo")
20
  with gr.Tab("Color Matching"):
21
+ create_color_matching()
22
 
23
  demo.launch()
src/color_matcher.py CHANGED
@@ -61,7 +61,6 @@ class CDL(torch.nn.Module):
61
 
62
 
63
  def train(
64
- fabric: L.Fabric,
65
  criteria: AbstractLoss,
66
  source_img: Float[torch.Tensor, "B C H W"],
67
  target_img: Float[torch.Tensor, "B C H W"],
@@ -71,25 +70,25 @@ def train(
71
  silent: bool = False,
72
  write_video_animation_path: Optional[str] = None,
73
  ) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
74
- criteria = fabric.setup(criteria)
75
 
76
  source_max_res = Resize(match_resolution, antialias=True)(source_img)
77
  target_max_res = Resize(match_resolution, antialias=True)(target_img)
78
 
79
  target_cielab = (
80
- fabric.to_device(rgb_to_lab(target_max_res).permute(0, 3, 1, 2))
81
  .permute(0, 2, 3, 1)
82
  .contiguous()
83
  )
84
 
85
- source_max_res = fabric.to_device(source_max_res)
86
- source_img = fabric.to_device(source_img)
87
 
88
  batch_size = source_img.shape[0]
89
  cdl = CDL(batch_size)
90
 
91
  optim = torch.optim.Adam(cdl.parameters(), lr=lr)
92
- cdl, optim = fabric.setup(cdl, optim)
93
 
94
  lossses = []
95
  for i in tqdm(range(num_steps), disable=silent):
@@ -106,7 +105,7 @@ def train(
106
  i,
107
  )
108
 
109
- fabric.backward(loss)
110
  optim.step()
111
 
112
  lossses.append(loss.item())
 
61
 
62
 
63
  def train(
 
64
  criteria: AbstractLoss,
65
  source_img: Float[torch.Tensor, "B C H W"],
66
  target_img: Float[torch.Tensor, "B C H W"],
 
70
  silent: bool = False,
71
  write_video_animation_path: Optional[str] = None,
72
  ) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
73
+ criteria = criteria.cuda()
74
 
75
  source_max_res = Resize(match_resolution, antialias=True)(source_img)
76
  target_max_res = Resize(match_resolution, antialias=True)(target_img)
77
 
78
  target_cielab = (
79
+ rgb_to_lab(target_max_res).cuda().permute(0, 3, 1, 2)
80
  .permute(0, 2, 3, 1)
81
  .contiguous()
82
  )
83
 
84
+ source_max_res = source_max_res.cuda()
85
+ source_img = source_img.cuda()
86
 
87
  batch_size = source_img.shape[0]
88
  cdl = CDL(batch_size)
89
 
90
  optim = torch.optim.Adam(cdl.parameters(), lr=lr)
91
+ cdl, optim = cdl.cuda(), optim.cuda()
92
 
93
  lossses = []
94
  for i in tqdm(range(num_steps), disable=silent):
 
105
  i,
106
  )
107
 
108
+ loss.backward()
109
  optim.step()
110
 
111
  lossses.append(loss.item())
src/gradio_demo/color_matching.py CHANGED
@@ -10,7 +10,7 @@ from src.loss import VectorSWDLoss
10
  from src.utils.image import from_torch, to_torch
11
 
12
 
13
- def create_color_matching(fabric: L.Fabric):
14
  """
15
  Creates the Gradio interface for color matching between source and target images.
16
  """
 
10
  from src.utils.image import from_torch, to_torch
11
 
12
 
13
+ def create_color_matching():
14
  """
15
  Creates the Gradio interface for color matching between source and target images.
16
  """
src/gradio_demo/sw_guidance.py CHANGED
@@ -44,7 +44,7 @@ def log_slider_to_lr(log_lr):
44
 
45
 
46
  def create_sw_guidance(
47
- fabric: L.Fabric, model_name: str = "stabilityai/stable-diffusion-3.5-large"
48
  ):
49
  """
50
  Creates the Gradio interface for SW guidance with SD3.5.
@@ -63,7 +63,7 @@ def create_sw_guidance(
63
  pipe = create_pipeline(
64
  model_name,
65
  device="cuda",
66
- compile=True,
67
  )
68
 
69
  model_config = models[model_name]
 
44
 
45
 
46
  def create_sw_guidance(
47
+ model_name: str = "stabilityai/stable-diffusion-3.5-large"
48
  ):
49
  """
50
  Creates the Gradio interface for SW guidance with SD3.5.
 
63
  pipe = create_pipeline(
64
  model_name,
65
  device="cuda",
66
+ compile=False,
67
  )
68
 
69
  model_config = models[model_name]