nsave commited on
Commit
7c7cc20
·
verified ·
1 Parent(s): 51128ae

Create config.py

Browse files
Files changed (1) hide show
  1. config.py +50 -0
config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Literal
3
+
4
+ import torch
5
+ import os
6
+
7
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "False") == "True"
8
+
9
+
10
+ @dataclass
11
+ class Config:
12
+ """
13
+ The configuration for the API.
14
+ """
15
+
16
+ ####################################################################
17
+ # Server
18
+ ####################################################################
19
+ # In most cases, you should leave this as it is.
20
+ host: str = "127.0.0.1"
21
+ port: int = 9090
22
+ workers: int = 1
23
+
24
+ ####################################################################
25
+ # Model configuration
26
+ ####################################################################
27
+ mode: Literal["txt2img", "img2img"] = "txt2img"
28
+ # SD1.x variant model
29
+ model_id_or_path: str = os.environ.get("MODEL", "KBlueLeaf/kohaku-v2.1")
30
+ # LoRA dictionary write like field(default_factory=lambda: {'E:/stable-diffusion-webui/models/Lora_1.safetensors' : 1.0 , 'E:/stable-diffusion-webui/models/Lora_2.safetensors' : 0.2})
31
+ lora_dict: dict = None
32
+ # LCM-LORA model
33
+ lcm_lora_id: str = os.environ.get("LORA", "latent-consistency/lcm-lora-sdv1-5")
34
+ # TinyVAE model
35
+ vae_id: str = os.environ.get("VAE", "madebyollin/taesd")
36
+ # Device to use
37
+ device: torch.device = torch.device("cuda")
38
+ # Data type
39
+ dtype: torch.dtype = torch.float16
40
+ # acceleration
41
+ acceleration: Literal["none", "xformers", "tensorrt"] = "xformers"
42
+
43
+ ####################################################################
44
+ # Inference configuration
45
+ ####################################################################
46
+ # Number of inference steps
47
+ t_index_list: List[int] = field(default_factory=lambda: [0, 16, 32, 45])
48
+ # Number of warmup steps
49
+ warmup: int = 10
50
+ use_safety_checker: bool = SAFETY_CHECKER