DiZH797 commited on
Commit
79a32c6
·
verified ·
1 Parent(s): ddbfbb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -45,15 +45,19 @@ def get_pipe(model_id: str, lora_scale: float = 1.0):
45
  if cache_key in PIPE_CACHE:
46
  return PIPE_CACHE[cache_key]
47
 
 
48
  # Check if the selected model is the LoRA adapter
49
  if model_id == LORA_MODEL_ID:
50
- # Load the base model for LoRA
51
  pipe = DiffusionPipeline.from_pretrained(
52
  BASE_MODEL_FOR_LORA,
53
  torch_dtype=torch_dtype
54
  ).to(device)
55
- # Load and merge the LoRA weights with the specified scale
56
- pipe.load_lora_weights(LORA_MODEL_ID)
 
 
 
57
  pipe.fuse_lora(lora_scale=lora_scale)
58
  else:
59
  # Load a standard model without LoRA
@@ -77,6 +81,7 @@ def infer(
77
  guidance_scale: float = 7.0,
78
  num_inference_steps: int = 20,
79
  scheduler_name: Optional[str] = None,
 
80
  progress=gr.Progress(track_tqdm=True),
81
  ):
82
  # получаем/загружаем нужный pipe
 
45
  if cache_key in PIPE_CACHE:
46
  return PIPE_CACHE[cache_key]
47
 
48
+
49
  # Check if the selected model is the LoRA adapter
50
  if model_id == LORA_MODEL_ID:
51
+ # Укажите правильные имена файлов
52
  pipe = DiffusionPipeline.from_pretrained(
53
  BASE_MODEL_FOR_LORA,
54
  torch_dtype=torch_dtype
55
  ).to(device)
56
+ pipe.load_lora_weights(
57
+ LORA_MODEL_ID,
58
+ weight_name=["adapter_model_unet.safetensors", "adapter_model_text_encoder.safetensors"] # Замените на ваши имена файлов
59
+ )
60
+ # Применяем масштаб LoRA
61
  pipe.fuse_lora(lora_scale=lora_scale)
62
  else:
63
  # Load a standard model without LoRA
 
81
  guidance_scale: float = 7.0,
82
  num_inference_steps: int = 20,
83
  scheduler_name: Optional[str] = None,
84
+ lora_scale: float = 1.0,
85
  progress=gr.Progress(track_tqdm=True),
86
  ):
87
  # получаем/загружаем нужный pipe