Tech-Meld committed on
Commit
a30a9fe
·
verified ·
1 Parent(s): 11bdf09

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -111
main.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
3
- from tqdm.auto import tqdm
4
- from huggingface_hub import hf_hub_url, login, HfApi, create_repo
5
- import os
6
- import traceback
7
- from peft import PeftModel
8
-
9
def display_image(image):
    """Show *image* with the viewer provided by the image object itself.

    Placeholder hook — replace with real display logic (e.g. a notebook
    renderer) if the default viewer is not available.
    """
    image.show()
12
-
13
def load_and_merge_lora(base_model_id, lora_id, lora_adapter_name, device="cuda"):
    """Load a diffusion pipeline and attach a LoRA adapter to its UNet.

    Parameters
    ----------
    base_model_id : str
        Hugging Face model ID of the base diffusion model.
    lora_id : str
        Hugging Face model ID (or local path) of the LoRA weights.
    lora_adapter_name : str
        Name under which the adapter is registered on the UNet.
    device : str, optional
        Torch device to move the pipeline to. Defaults to "cuda",
        preserving the original hard-coded behavior.

    Returns
    -------
    DiffusionPipeline or None
        The pipeline with the LoRA-wrapped UNet on success, or ``None``
        on failure (the full traceback is written to ``errors.txt``).
    """
    try:
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to(device)

        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config
        )

        # Wrap the pipeline's UNet with PEFT so the LoRA weights are applied,
        # then put the wrapped model back into the pipeline.
        pipe.unet = PeftModel.from_pretrained(
            pipe.unet,
            lora_id,
            torch_dtype=torch.float16,
            adapter_name=lora_adapter_name,
        )

        print("LoRA merged successfully!")
        return pipe

    except Exception as e:
        # Broad by design: any load/merge failure is reported to the user
        # and persisted to errors.txt instead of crashing the script.
        error_msg = traceback.format_exc()
        print(f"Error merging LoRA: {e}\n\nFull traceback saved to errors.txt")

        with open("errors.txt", "w") as f:
            f.write(error_msg)

        return None
51
-
52
def save_merged_model(pipe, save_path, push_to_hub=False, hf_token=None):
    """Save the merged pipeline locally and optionally push it to the Hub.

    Parameters
    ----------
    pipe : DiffusionPipeline
        Pipeline to persist via ``save_pretrained``.
    save_path : str
        Local directory to save the model into.
    push_to_hub : bool, optional
        When True, also upload the saved folder to the Hugging Face Hub.
    hf_token : str, optional
        Hugging Face write token; prompted for interactively if omitted.

    Errors are printed, not raised — the function is best-effort.
    """
    try:
        pipe.save_pretrained(save_path)
        print(f"Merged model saved successfully to: {save_path}")

        if push_to_hub:
            if hf_token is None:
                hf_token = input("Enter your Hugging Face write token: ")
            # Bug fix: log in unconditionally. Previously login() ran only
            # inside the prompt branch, so a caller-supplied token was
            # never used to authenticate.
            login(token=hf_token)

            repo_name = input("Enter the Hugging Face repository name "
                              "(e.g., your_username/your_model_name): ")

            # Create the repository if it doesn't exist
            create_repo(repo_name, token=hf_token, exist_ok=True)

            api = HfApi()
            api.upload_folder(
                folder_path=save_path,
                repo_id=repo_name,
                token=hf_token,
                repo_type="model",
            )
            print(f"Model pushed successfully to Hugging Face Hub: {repo_name}")

    except Exception as e:
        # Best-effort boundary: report and continue rather than crash.
        print(f"Error saving/pushing the merged model: {e}")
80
-
81
if __name__ == "__main__":
    # Interactive driver: load a base model, apply the fixed LoRA, generate
    # one image, then optionally save/push the merged pipeline.
    base_model_id = input("Enter the base model ID: ")
    lora_id = "Tech-Meld/life-fx"  # fixed LoRA repository for this script
    lora_adapter_name = input("Enter the LoRA adapter name: ")

    pipe = load_and_merge_lora(base_model_id, lora_id, lora_adapter_name)

    if pipe:
        prompt = input("Enter your prompt: ")
        lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): "))

        # Fixed seed keeps generations reproducible across runs.
        image = pipe(
            prompt,
            num_inference_steps=30,
            cross_attention_kwargs={"scale": lora_scale},
            generator=torch.manual_seed(0)
        ).images[0]

        # Display the image (make sure you have a display function)
        display_image(image)

        # Save the image. (Fixed: dropped the spurious f-prefix on a
        # placeholder-free string literal.)
        image.save("generated_image.png")
        print("Image saved to: generated_image.png")

        save_path = input("Enter the directory where you want to save the merged model: ")

        push_to_hub = input("Do you want to push the model to Hugging Face Hub? (yes/no): ")
        push_to_hub = push_to_hub.lower() == "yes"

        save_merged_model(pipe, save_path, push_to_hub=push_to_hub)