Spaces:
Sleeping
Sleeping
| import torch | |
| from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import cached_download, hf_hub_url | |
| import os | |
def display_image(image):
    """Show ``image`` by calling its ``show()`` method.

    This is a placeholder hook: swap the body for whatever display
    mechanism the surrounding application actually uses.
    """
    image.show()
def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
    """Load a base diffusion pipeline onto CUDA and attach a LoRA adapter.

    Args:
        base_model_id: Model identifier handed to
            ``DiffusionPipeline.from_pretrained``.
        lora_id: Hugging Face Hub repository ID that holds the LoRA weights.
        lora_weight_name: File name of the LoRA weights inside ``lora_id``.
        lora_adapter_name: Name under which the adapter is registered on the
            pipeline.

    Returns:
        The pipeline with the LoRA loaded, or ``None`` if any step failed
        (the error is printed rather than raised).

    NOTE(review): this loads the adapter but does not call ``fuse_lora()``;
    confirm whether the weights are meant to be baked into the base model
    before the pipeline is saved.
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # The replacement scheduler is derived from the pipeline's own
        # scheduler config, so it can only be built once the pipeline exists.
        # Referencing ``pipe`` inside the ``from_pretrained`` call raised
        # UnboundLocalError and made this function always return None.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )
        pipe = pipe.to("cuda")

        # ``load_lora_weights`` accepts a Hub repo ID plus ``weight_name`` and
        # fetches the file itself, so the deprecated ``cached_download`` /
        # ``hf_hub_url`` round trip is not needed. The former
        # ``progress_callback`` argument is not a documented parameter, so the
        # progress bar wrapped around this call never advanced.
        pipe.load_lora_weights(
            lora_id,
            weight_name=lora_weight_name,
            adapter_name=lora_adapter_name,
        )
        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        # Deliberate best-effort boundary: the caller checks for None.
        print(f"Error merging LoRA: {e}")
        return None
def save_merged_model(pipe, save_path):
    """Write ``pipe`` to ``save_path`` via its ``save_pretrained`` method.

    Any exception raised while saving is reported on stdout instead of being
    propagated; the function returns ``None`` in both cases.
    """
    try:
        pipe.save_pretrained(save_path)
        success_message = f"Merged model saved successfully to: {save_path}"
        print(success_message)
    except Exception as err:
        failure_message = f"Error saving the merged model: {err}"
        print(failure_message)
if __name__ == "__main__":
    # Interactive driver: load a base model plus LoRA, render one image,
    # then save the pipeline to a user-chosen directory.
    base_model_id = input("Enter the base model ID: ")
    lora_id = input("Enter the LoRA Hugging Face Hub ID: ")
    lora_weight_name = input("Enter the LoRA weight file name: ")
    lora_adapter_name = input("Enter the LoRA adapter name: ")

    pipe = load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name)

    # load_and_merge_lora returns None after printing the error on failure.
    if pipe is not None:
        prompt = input("Enter your prompt: ")

        # Re-prompt on a malformed number instead of crashing: by this point
        # the expensive model load has already succeeded.
        while True:
            try:
                lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): "))
                break
            except ValueError:
                print("Invalid LoRA scale; please enter a number such as 0.9.")

        image = pipe(
            prompt,
            num_inference_steps=30,
            cross_attention_kwargs={"scale": lora_scale},
            # Fixed seed so repeated runs with the same inputs are comparable.
            generator=torch.manual_seed(0),
        ).images[0]
        display_image(image)

        # Ask the user for a directory to save the model
        save_path = input(
            "Enter the directory where you want to save the merged model: "
        )
        save_merged_model(pipe, save_path)