ixarchakos commited on
Commit
a8bc052
·
verified ·
1 Parent(s): 6a18796

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +63 -13
  2. config.py +8 -3
app.py CHANGED
@@ -18,18 +18,61 @@ pipe2 = AutoPipelineForText2Image.from_pretrained(
18
 
19
  @spaces.GPU
20
  def infer(
21
- prompt,
22
- negative_prompt,
23
- seed,
24
- randomize_seed,
25
- width,
26
- height,
27
- guidance_scale,
28
- num_inference_steps,
29
- model_name,
30
- checkpoint,
31
  progress=gr.Progress(track_tqdm=True),
32
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  pipe = pipe2 if model_name == 'runwayml/stable-diffusion-v1-5' else pipe1
34
  pipe.load_lora_weights(os.path.join('checkpoints', checkpoint),
35
  weight_name="pytorch_lora_weights.safetensors")
@@ -52,10 +95,17 @@ def infer(
52
 
53
 
54
  def main():
 
 
 
 
 
 
 
55
  examples = [
56
- "Ioannis hiking in the mountains",
57
  "Ioannis at an office desk",
58
- "Ioannis skiing in the Alps",
59
  ]
60
 
61
  css = """
@@ -140,7 +190,7 @@ def main():
140
  minimum=1,
141
  maximum=400,
142
  step=1,
143
- value=150, # Replace with defaults that work for your checkpoints
144
  )
145
 
146
  with gr.Row():
 
18
 
19
  @spaces.GPU
20
  def infer(
21
+ prompt: str,
22
+ negative_prompt: str,
23
+ seed: int,
24
+ randomize_seed: bool,
25
+ width: int,
26
+ height: int,
27
+ guidance_scale: float,
28
+ num_inference_steps: int,
29
+ model_name: str,
30
+ checkpoint: str,
31
  progress=gr.Progress(track_tqdm=True),
32
  ):
33
+ """
34
+ Generate an image using the specified parameters and LoRA checkpoint.
35
+
36
+ This function loads a LoRA checkpoint and generates an image based on the provided
37
+ prompt and generation parameters. It supports both Stable Diffusion v1.5 and SDXL models.
38
+
39
+ Parameters
40
+ ----------
41
+ prompt : str
42
+ Text prompt describing the image to generate
43
+ negative_prompt : str
44
+ Text prompt describing what to avoid in the generated image
45
+ seed : int
46
+ Random seed for reproducible generation
47
+ randomize_seed : bool
48
+ Whether to use a random seed instead of the provided one
49
+ width : int
50
+ Width of the generated image
51
+ height : int
52
+ Height of the generated image
53
+ guidance_scale : float
54
+ Strength of prompt guidance (higher values = stronger adherence to prompt)
55
+ num_inference_steps : int
56
+ Number of denoising steps for generation
57
+ model_name : str
58
+ Name of the base model to use
59
+ checkpoint : str
60
+ Name of the LoRA checkpoint to load
61
+ progress : gr.Progress, optional
62
+ Gradio progress tracker
63
+
64
+ Returns
65
+ -------
66
+ tuple
67
+ Tuple containing (generated_image, seed_used)
68
+
69
+ Notes
70
+ -----
71
+ - Automatically selects the appropriate pipeline based on model_name
72
+ - Loads LoRA weights from the checkpoints directory
73
+ - Uses torch.Generator for reproducible results
74
+ - Returns both the generated image and the seed used
75
+ """
76
  pipe = pipe2 if model_name == 'runwayml/stable-diffusion-v1-5' else pipe1
77
  pipe.load_lora_weights(os.path.join('checkpoints', checkpoint),
78
  weight_name="pytorch_lora_weights.safetensors")
 
95
 
96
 
97
  def main():
98
+ """
99
+ Create and configure the Gradio web interface.
100
+
101
+ This function sets up the complete Gradio interface with all necessary components
102
+ for image generation, including input controls, output display, and example prompts.
103
+
104
+ """
105
  examples = [
106
+ "Ioannis hiking in the mountains looking at the camera",
107
  "Ioannis at an office desk",
108
+ "Ioannis skiing in the Alps looking at the camera without mask",
109
  ]
110
 
111
  css = """
 
190
  minimum=1,
191
  maximum=400,
192
  step=1,
193
+ value=200, # Replace with defaults that work for your checkpoints
194
  )
195
 
196
  with gr.Row():
config.py CHANGED
@@ -2,17 +2,22 @@ import torch
2
  import numpy as np
3
  import os
4
 
5
-
6
  device = "cuda" if torch.cuda.is_available() else "mps"
7
  torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
8
 
 
9
  MAX_SEED = np.iinfo(np.int32).max
10
  MAX_IMAGE_SIZE = 1024
11
 
12
- PROJECT_PATH = "/Users/ioannisxarchakos/Dropbox/Projects/Dyla/"
 
 
 
13
  PRETRAINED_MODEL = 'runwayml/stable-diffusion-v1-5'
14
  PRETRAINED_MODEL_XL = 'stabilityai/stable-diffusion-xl-base-1.0'
15
  LORA_WEIGHTS_PATH = os.path.join(PROJECT_PATH, 'sd-checkpoints-finetuned-lora/')
16
- DATA_DIR = os.path.join(PROJECT_PATH, 'ioannis_images/')
 
17
  CSV_PATH = os.path.join(DATA_DIR, 'prompts.csv')
18
  OUTPUT_DIR = 'generated_images'
 
2
  import numpy as np
3
  import os
4
 
5
+ # Device configuration
6
  device = "cuda" if torch.cuda.is_available() else "mps"
7
  torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
8
 
9
+ # Constants
10
  MAX_SEED = np.iinfo(np.int32).max
11
  MAX_IMAGE_SIZE = 1024
12
 
13
+ # Project paths
14
+ PROJECT_PATH = ""
15
+
16
+ # Model configurations
17
  PRETRAINED_MODEL = 'runwayml/stable-diffusion-v1-5'
18
  PRETRAINED_MODEL_XL = 'stabilityai/stable-diffusion-xl-base-1.0'
19
  LORA_WEIGHTS_PATH = os.path.join(PROJECT_PATH, 'sd-checkpoints-finetuned-lora/')
20
+ IMAGE_FOLDER = "ioannis_images"
21
+ DATA_DIR = os.path.join(PROJECT_PATH, IMAGE_FOLDER)
22
  CSV_PATH = os.path.join(DATA_DIR, 'prompts.csv')
23
  OUTPUT_DIR = 'generated_images'