Spaces:
Sleeping
Sleeping
# Copyright 2022 Roblox Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from SmoothCache import DiffuserCalibrationHelper
def main():
    """Calibrate a SmoothCache schedule on DiT-XL-2-256 and save sample images.

    Loads the ``facebook/DiT-XL-2-256`` pipeline in float16, swaps in a
    DPM-Solver multistep scheduler, moves the pipeline to CUDA, and runs one
    class-conditional generation with the calibration helper enabled.
    Disabling the helper afterwards is the step that generates the schedule
    (written to ``smoothcache_schedules/diffuser_schedule.json``). Finally,
    one PNG per class label is saved into the current working directory.

    Requires a CUDA device; takes no arguments and returns nothing.
    """
    num_inference_steps = 50

    # Single source of truth for the schedule location, so the helper's
    # ``log_file`` and the completion message below can never disagree.
    schedule_path = "smoothcache_schedules/diffuser_schedule.json"

    # Create the output directory before doing any expensive work, in case
    # the helper does not create it itself (not verifiable from this file);
    # otherwise a missing directory would only surface after the full run.
    os.makedirs(os.path.dirname(schedule_path), exist_ok=True)

    # Load pipeline
    pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    # Initialize calibration helper
    calibration_helper = DiffuserCalibrationHelper(
        model=pipe.transformer,
        calibration_lookahead=3,
        calibration_threshold=0.15,
        # Must match the num_inference_steps passed to the pipeline call below.
        schedule_length=num_inference_steps,
        log_file=schedule_path,
    )

    # Enable calibration
    calibration_helper.enable()

    # Run pipeline normally
    words = ["Labrador retriever", "combination lock", "cassette player"]
    class_ids = pipe.get_label_ids(words)
    generator = torch.manual_seed(33)  # fixed seed for reproducible samples
    images = pipe(
        class_labels=class_ids,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images  # Normal pipeline call

    # Disable calibration and generate schedule
    calibration_helper.disable()
    print(f"Calibration complete. Schedule saved to {schedule_path}")

    # One output image per requested class, named after its label.
    for prompt, image in zip(words, images):
        image.save(f"{prompt}.png")
# Script entry point: run the calibration only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()