MPS incompatibility: float64 hard-coded in Euler–Maruyama scheduler

#1
by milkzheng - opened

Hi, thank you for developing and releasing this generative model — it’s a very interesting and useful project.

I tried running the pipeline on a MacBook using the PyTorch MPS backend. However, I encountered an error originating from the scheduler, where float64 is hard-coded when constructing the final timestep tensor (in EulerMaruyamaScheduler.set_timesteps). Since the MPS backend does not support float64, this causes a runtime error when the tensor is moved to the MPS device.

I was wondering whether there is a specific numerical reason for requiring float64 precision at this step of the scheduler. Would it be possible (or safe) to instead use float32, or inherit the dtype from the existing timesteps tensor (as done in some diffusers schedulers), to improve compatibility with MPS?
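To make the suggestion concrete, here is a minimal sketch of the two options. It is not the actual CytoSyn scheduler code: `append_final_timestep` and `timestep_dtype_for` are hypothetical helper names, and the final timestep value of 0.0 is an assumption chosen for illustration.

```python
import torch


def timestep_dtype_for(device) -> torch.dtype:
    """Hypothetical helper: keep float64 where it is supported, use float32 on MPS."""
    # The MPS backend does not support float64, so fall back to float32 there.
    return torch.float32 if torch.device(device).type == "mps" else torch.float64


def append_final_timestep(timesteps: torch.Tensor, final_value: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: append the last timestep without hard-coding a dtype.

    The new element inherits dtype and device from the existing tensor, so no
    float64 tensor is ever created for a float32 schedule. The default
    final_value of 0.0 is an assumption, not taken from the scheduler.
    """
    final = torch.tensor([final_value], dtype=timesteps.dtype, device=timesteps.device)
    return torch.cat([timesteps, final])


# Usage: build a float32 schedule on CPU and extend it.
schedule = torch.linspace(1.0, 0.04, 5, dtype=timestep_dtype_for("mps"))
schedule = append_final_timestep(schedule)
print(schedule.dtype, schedule.shape)
```

The first helper would preserve the float64 behaviour on CPU and CUDA, while the second simply follows whatever dtype the caller already chose.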

Thanks again for your work, and I’d be happy to help test a fix if useful.

Owkin x Bioptimus org

Hi!

Sorry for the delay in responding; I only just saw your comment. Thank you for trying out the model and posting your feedback!

I am aware of the hard-coded float64. It comes from the REPA-E repository we used to train CytoSyn. Since our main results were obtained with it, we decided to keep it here to ensure reproducibility. Nonetheless, I believe you can use float32 without noticeably altering the results. One quick way to check this would be to run the pipeline twice with the same conditioning vector, once on CPU with float64 and once on the MPS device with float32, and compare the generated images: small discrepancies can be attributed to differences between CPU and GPU computation, whereas anything visually significant would point to the dtype change. I will look deeper into it in the coming days.
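The comparison described above could be sketched as follows. This is only an illustration: `compare_images` is a hypothetical helper, the two inputs are assumed to be the generated images as tensors of identical shape, and both runs are assumed to start from the same initial noise and noise sequence, since the sampler is stochastic and a shared conditioning vector alone would not make the outputs comparable.

```python
import torch


def compare_images(reference: torch.Tensor, candidate: torch.Tensor) -> dict:
    """Hypothetical helper: summarize per-pixel differences between two images."""
    # Move to CPU first, then upcast, so no float64 tensor is created on MPS.
    ref = reference.detach().cpu().to(torch.float64)
    cand = candidate.detach().cpu().to(torch.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {tuple(ref.shape)} vs {tuple(cand.shape)}")
    diff = (ref - cand).abs()
    return {
        "max_abs_diff": diff.max().item(),
        "mean_abs_diff": diff.mean().item(),
    }


# Usage with placeholder tensors standing in for the two pipeline outputs.
image_cpu_f64 = torch.zeros(3, 4, 4, dtype=torch.float64)
image_mps_f32 = torch.zeros(3, 4, 4, dtype=torch.float32)
print(compare_images(image_cpu_f64, image_mps_f32))
```

Differences near float32 rounding error would be consistent with backend noise, while large values would be worth a visual check. Where to draw that line is a judgment call rather than something this sketch decides.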
