ishworrsubedii committed on
Commit 45ac234 · 1 Parent(s): 070b382

refactor: remove fooocus api

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fooocus_api_version.py +0 -1
  2. fooocusapi/api.py +0 -41
  3. fooocusapi/args.py +0 -20
  4. fooocusapi/base_args.py +0 -27
  5. fooocusapi/configs/default.py +0 -92
  6. fooocusapi/models/common/base.py +0 -189
  7. fooocusapi/models/common/image_meta.py +0 -118
  8. fooocusapi/models/common/requests.py +0 -132
  9. fooocusapi/models/common/response.py +0 -90
  10. fooocusapi/models/common/task.py +0 -60
  11. fooocusapi/models/requests_v1.py +0 -274
  12. fooocusapi/models/requests_v2.py +0 -50
  13. fooocusapi/parameters.py +0 -94
  14. fooocusapi/routes/__init__.py +0 -0
  15. fooocusapi/routes/generate_v1.py +0 -186
  16. fooocusapi/routes/generate_v2.py +0 -199
  17. fooocusapi/routes/query.py +0 -135
  18. fooocusapi/sql_client.py +0 -269
  19. fooocusapi/task_queue.py +0 -323
  20. fooocusapi/utils/api_utils.py +0 -291
  21. fooocusapi/utils/call_worker.py +0 -97
  22. fooocusapi/utils/file_utils.py +0 -143
  23. fooocusapi/utils/img_utils.py +0 -198
  24. fooocusapi/utils/logger.py +0 -132
  25. fooocusapi/utils/lora_manager.py +0 -71
  26. fooocusapi/utils/model_loader.py +0 -46
  27. fooocusapi/utils/tools.py +0 -159
  28. fooocusapi/worker.py +0 -1044
  29. predict.py +0 -316
  30. repositories/Fooocus/__init__.py +0 -4
  31. repositories/Fooocus/args_manager.py +0 -55
  32. repositories/Fooocus/extras/BLIP/configs/bert_config.json +0 -21
  33. repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml +0 -33
  34. repositories/Fooocus/extras/BLIP/configs/med_config.json +0 -21
  35. repositories/Fooocus/extras/BLIP/configs/nlvr.yaml +0 -21
  36. repositories/Fooocus/extras/BLIP/configs/nocaps.yaml +0 -15
  37. repositories/Fooocus/extras/BLIP/configs/pretrain.yaml +0 -27
  38. repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml +0 -34
  39. repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml +0 -34
  40. repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml +0 -12
  41. repositories/Fooocus/extras/BLIP/configs/vqa.yaml +0 -25
  42. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json +0 -23
  43. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json +0 -0
  44. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json +0 -3
  45. repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt +0 -0
  46. repositories/Fooocus/extras/BLIP/models/blip.py +0 -239
  47. repositories/Fooocus/extras/BLIP/models/blip_itm.py +0 -76
  48. repositories/Fooocus/extras/BLIP/models/blip_nlvr.py +0 -105
  49. repositories/Fooocus/extras/BLIP/models/blip_pretrain.py +0 -339
  50. repositories/Fooocus/extras/BLIP/models/blip_retrieval.py +0 -319
fooocus_api_version.py DELETED
@@ -1 +0,0 @@
- version = '0.4.1.1'

fooocusapi/api.py DELETED
@@ -1,41 +0,0 @@
- """
- Entry for startup fastapi server
- """
- from fastapi import FastAPI
- from fastapi.staticfiles import StaticFiles
- from fastapi.middleware.cors import CORSMiddleware
-
- import uvicorn
-
- from fooocusapi.utils import file_utils
- from fooocusapi.routes.generate_v1 import secure_router as generate_v1
- from fooocusapi.routes.generate_v2 import secure_router as generate_v2
- from fooocusapi.routes.query import secure_router as query
- from mannequin_to_model import secure_router as mannequin_to_model
-
- app = FastAPI()
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # Allow access from all sources
-     allow_credentials=True,
-     allow_methods=["*"],  # Allow all HTTP methods
-     allow_headers=["*"],  # Allow all request headers
- )
-
- app.mount("/files", StaticFiles(directory=file_utils.output_dir), name="files")
-
- app.include_router(query)
- app.include_router(generate_v1)
- app.include_router(generate_v2)
- app.include_router(mannequin_to_model)
-
-
- def start_app(args):
-     """Start the FastAPI application"""
-     file_utils.STATIC_SERVER_BASE = args.base_url + "/files/"
-     uvicorn.run(
-         app="fooocusapi.api:app",
-         host="0.0.0.0",
-         port=8000,
-         log_level=args.log_level)

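For reference, a minimal sketch of how a client might have called one of the removed endpoints. The route and JSON fields are taken from the generate_v1.py and common request model diffs below; the host and port assume the uvicorn defaults in start_app above, and the requests library is an assumption, not part of this codebase.

import requests  # assumed third-party HTTP client, not part of the removed code

resp = requests.post(
    "http://127.0.0.1:8000/v1/generation/text-to-image",
    json={
        "prompt": "a photo of a cat",  # CommonRequest.prompt
        "image_number": 1,             # 1-32
        "require_base64": True,        # return base64 image data
        "async_process": False,        # block until generation finishes
    },
    timeout=600,
)
print(resp.json()[0]["finish_reason"])  # "SUCCESS" on a completed sync job
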
fooocusapi/args.py DELETED
@@ -1,20 +0,0 @@
- """
- Do not modify the import order
- """
- from fooocusapi.base_args import add_base_args
- import ldm_patched.modules.args_parser as args_parser
-
- # Add Fooocus-API args to parser
- add_base_args(args_parser.parser, False)
-
- # Apply Fooocus args
- from args_manager import args_parser
-
- # Override the port default value
- args_parser.parser.set_defaults(
-     port=8888
- )
-
- # Execute args parse again
- args_parser.args = args_parser.parser.parse_args()
- args = args_parser.args

fooocusapi/base_args.py DELETED
@@ -1,27 +0,0 @@
- """
- base_args.py
- """
- from argparse import ArgumentParser
-
-
- def add_base_args(parser: ArgumentParser, before_prepared: bool):
-     """
-     Add base args for fooocusapi
-     Args:
-         parser: ArgumentParser
-         before_prepared: before prepare environment
-     Returns:
-     """
-     if before_prepared:
-         parser.add_argument("--port", type=int, default=8888, help="Set the listen port, default: 8888")
-
-     parser.add_argument("--host", type=str, default='127.0.0.1', help="Set the listen host, default: 127.0.0.1")
-     parser.add_argument("--base-url", type=str, default=None, help="Set base url for outside visit, default is http://host:port")
-     parser.add_argument("--log-level", type=str, default='info', help="Log info for Uvicorn, default: info")
-     parser.add_argument("--skip-pip", default=False, action="store_true", help="Skip automatic pip install when setup")
-     parser.add_argument("--preload-pipeline", default=False, action="store_true", help="Preload pipeline before start http server")
-     parser.add_argument("--queue-size", type=int, default=100, help="Working queue size, default: 100, generation requests exceeding working queue size will return failure")
-     parser.add_argument("--queue-history", type=int, default=0, help="Finished jobs reserve size, tasks exceeding the limit will be deleted, including output image files, default: 0, means no limit")
-     parser.add_argument('--webhook-url', type=str, default=None, help='The URL to send a POST request when a job is finished')
-     parser.add_argument('--persistent', default=False, action="store_true", help="Store history to db")
-     parser.add_argument("--apikey", type=str, default=None, help="API key for authenticating requests")

fooocusapi/configs/default.py DELETED
@@ -1,92 +0,0 @@
- """
- Static variables for Fooocus API
- """
- img_generate_responses = {
-     "200": {
-         "description": "PNG bytes if request's 'Accept' header is 'image/png', otherwise JSON",
-         "content": {
-             "application/json": {
-                 "example": [{
-                     "base64": "...very long string...",
-                     "seed": "1050625087",
-                     "finish_reason": "SUCCESS",
-                 }]
-             },
-             "application/json async": {
-                 "example": {
-                     "job_id": 1,
-                     "job_type": "Text to Image"
-                 }
-             },
-             "image/png": {
-                 "example": "PNG bytes, what did you expect?"
-             },
-         },
-     }
- }
-
- default_inpaint_engine_version = "v2.6"
-
- default_styles = ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
- default_base_model_name = "juggernautXL_v8Rundiffusion.safetensors"
- default_refiner_model_name = "None"
- default_refiner_switch = 0.5
- default_loras = [[True, "sd_xl_offset_example-lora_1.0.safetensors", 0.1]]
- default_cfg_scale = 7.0
- default_prompt_negative = ""
- default_aspect_ratio = "1152*896"
- default_sampler = "dpmpp_2m_sde_gpu"
- default_scheduler = "karras"
-
- available_aspect_ratios = [
-     "704*1408",
-     "704*1344",
-     "768*1344",
-     "768*1280",
-     "832*1216",
-     "832*1152",
-     "896*1152",
-     "896*1088",
-     "960*1088",
-     "960*1024",
-     "1024*1024",
-     "1024*960",
-     "1088*960",
-     "1088*896",
-     "1152*896",
-     "1152*832",
-     "1216*832",
-     "1280*768",
-     "1344*768",
-     "1344*704",
-     "1408*704",
-     "1472*704",
-     "1536*640",
-     "1600*640",
-     "1664*576",
-     "1728*576",
- ]
-
- uov_methods = [
-     "Disabled",
-     "Vary (Subtle)",
-     "Vary (Strong)",
-     "Upscale (1.5x)",
-     "Upscale (2x)",
-     "Upscale (Fast 2x)",
-     "Upscale (Custom)",
- ]
-
- outpaint_expansions = ["Left", "Right", "Top", "Bottom"]
-
-
- def get_aspect_ratio_value(label: str) -> str:
-     """
-     Get aspect ratio
-     Args:
-         label: str, aspect ratio
-
-     Returns:
-
-     """
-     return label.split(" ")[0].replace("×", "*")

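A quick illustration of the get_aspect_ratio_value helper above, reproduced standalone; the labelled input format ("W×H" plus a trailing annotation) is an assumption about how the UI formats ratio labels.

def get_aspect_ratio_value(label: str) -> str:
    # copy of the removed helper: keep the first token and normalize "×" to "*"
    return label.split(" ")[0].replace("×", "*")

assert get_aspect_ratio_value("1152*896") == "1152*896"
assert get_aspect_ratio_value("1152×896 (9:7)") == "1152*896"  # hypothetical UI label
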
fooocusapi/models/common/base.py DELETED
@@ -1,189 +0,0 @@
- """Common models"""
- from typing import List, Tuple
- from enum import Enum
- from fastapi import UploadFile
- from fastapi.exceptions import RequestValidationError
- from pydantic import (
-     ValidationError,
-     ConfigDict,
-     BaseModel,
-     TypeAdapter,
-     Field
- )
- from pydantic_core import InitErrorDetails
-
- from fooocusapi.configs.default import default_loras
-
-
- class PerformanceSelection(str, Enum):
-     """Performance selection"""
-     speed = 'Speed'
-     quality = 'Quality'
-     extreme_speed = 'Extreme Speed'
-     lightning = 'Lightning'
-     hyper_sd = 'Hyper-SD'
-
-
- class Lora(BaseModel):
-     """Common params lora model"""
-     enabled: bool
-     model_name: str
-     weight: float = Field(default=0.5, ge=-2, le=2)
-
-     model_config = ConfigDict(
-         protected_namespaces=('protect_me_', 'also_protect_')
-     )
-
-
- LoraList = TypeAdapter(List[Lora])
- default_loras_model = []
- for lora in default_loras:
-     if lora[0] != 'None':
-         default_loras_model.append(
-             Lora(
-                 enabled=lora[0],
-                 model_name=lora[1],
-                 weight=lora[2])
-         )
- default_loras_json = LoraList.dump_json(default_loras_model)
-
-
- class UpscaleOrVaryMethod(str, Enum):
-     """Upscale or Vary method"""
-     subtle_variation = 'Vary (Subtle)'
-     strong_variation = 'Vary (Strong)'
-     upscale_15 = 'Upscale (1.5x)'
-     upscale_2 = 'Upscale (2x)'
-     upscale_fast = 'Upscale (Fast 2x)'
-     upscale_custom = 'Upscale (Custom)'
-
-
- class OutpaintExpansion(str, Enum):
-     """Outpaint expansion"""
-     left = 'Left'
-     right = 'Right'
-     top = 'Top'
-     bottom = 'Bottom'
-
-
- class ControlNetType(str, Enum):
-     """ControlNet Type"""
-     cn_ip = "ImagePrompt"
-     cn_ip_face = "FaceSwap"
-     cn_canny = "PyraCanny"
-     cn_cpds = "CPDS"
-
-
- class ImagePrompt(BaseModel):
-     """Common params object ImagePrompt"""
-     cn_img: UploadFile | None = Field(default=None)
-     cn_stop: float | None = Field(default=None, ge=0, le=1)
-     cn_weight: float | None = Field(default=None, ge=0, le=2, description="None for default value")
-     cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)
-
-
- class DescribeImageType(str, Enum):
-     """Image type for image to prompt"""
-     photo = 'Photo'
-     anime = 'Anime'
-
-
- class ImageMetaScheme(str, Enum):
-     """Scheme for save image meta
-     Attributes:
-         Fooocus: json format
-         A111: string
-     """
-     Fooocus = 'fooocus'
-     A111 = 'a111'
-
-
- def style_selection_parser(style_selections: str | List[str]) -> List[str]:
-     """
-     Parse style selections, Convert to list
-     Args:
-         style_selections: str, comma separated Fooocus style selections
-         e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp
-     Returns:
-         List[str]
-     """
-     style_selection_arr: List[str] = []
-     if style_selections is None or len(style_selections) == 0:
-         return []
-     for part in style_selections:
-         if len(part) > 0:
-             for s in part.split(','):
-                 style = s.strip()
-                 style_selection_arr.append(style)
-     return style_selection_arr
-
-
- def lora_parser(loras: str) -> List[Lora]:
-     """
-     Parse lora config, Convert to list
-     Args:
-         loras: a json string for loras
-     Returns:
-         List[Lora]
-     """
-     loras_model: List[Lora] = []
-     if loras is None or len(loras) == 0:
-         return loras_model
-     try:
-         loras_model = LoraList.validate_json(loras)
-         return loras_model
-     except ValidationError as ve:
-         errs = ve.errors()
-         raise RequestValidationError from errs
-
-
- def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
-     """
-     Parse outpaint selections, Convert to list
-     Args:
-         outpaint_selections: str, comma separated Left, Right, Top, Bottom
-         e.g. Left, Right, Top, Bottom
-     Returns:
-         List[OutpaintExpansion]
-     """
-     outpaint_selections_arr: List[OutpaintExpansion] = []
-     if outpaint_selections is None or len(outpaint_selections) == 0:
-         return []
-     for part in outpaint_selections:
-         if len(part) > 0:
-             for s in part.split(','):
-                 try:
-                     expansion = OutpaintExpansion(s)
-                     outpaint_selections_arr.append(expansion)
-                 except ValueError:
-                     errs = InitErrorDetails(
-                         type='enum',
-                         loc=tuple('outpaint_selections'),
-                         input=outpaint_selections,
-                         ctx={
-                             'expected': "str, comma separated Left, Right, Top, Bottom"
-                         })
-                     raise RequestValidationError from errs
-     return outpaint_selections_arr
-
-
- def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
-     """
-     Image prompt parser, Convert to List[ImagePrompt]
-     Args:
-         image_prompts_config: List[Tuple]
-         e.g. ('image1.jpg', 0.5, 1.0, 'normal'), ('image2.jpg', 0.5, 1.0, 'normal')
-     returns:
-         List[ImagePrompt]
-     """
-     image_prompts: List[ImagePrompt] = []
-     if image_prompts_config is None or len(image_prompts_config) == 0:
-         return []
-     for config in image_prompts_config:
-         cn_img, cn_stop, cn_weight, cn_type = config
-         image_prompts.append(ImagePrompt(
-             cn_img=cn_img,
-             cn_stop=cn_stop,
-             cn_weight=cn_weight,
-             cn_type=cn_type))
-     return image_prompts

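A self-contained sketch of the Lora model and TypeAdapter round-trip used above, assuming pydantic v2. The JSON string mirrors the loras form field documented in the v1 request diffs below; the enabled default is added here so that field may be omitted, a simplification of the removed model.

from typing import List
from pydantic import BaseModel, Field, TypeAdapter

class Lora(BaseModel):
    # trimmed copy of the removed model (the protected_namespaces config is omitted)
    enabled: bool = True
    model_name: str
    weight: float = Field(default=0.5, ge=-2, le=2)

LoraList = TypeAdapter(List[Lora])
loras = LoraList.validate_json(
    '[{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.1}]'
)
print(LoraList.dump_json(loras))  # serialize back to JSON bytes, as default_loras_json does
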
fooocusapi/models/common/image_meta.py DELETED
@@ -1,118 +0,0 @@
- """
- Image meta schema
- """
- from typing import List
-
- from fooocus_version import version
- from pydantic import BaseModel
-
-
- class ImageMeta(BaseModel):
-     """
-     Image meta data model
-     """
-
-     metadata_scheme: str = "fooocus"
-
-     base_model: str
-     base_model_hash: str
-
-     prompt: str
-     full_prompt: List[str]
-     prompt_expansion: str
-
-     negative_prompt: str
-     full_negative_prompt: List[str]
-
-     performance: str
-
-     style: str
-
-     refiner_model: str = "None"
-     refiner_switch: float = 0.5
-
-     loras: List[list]
-
-     resolution: str
-
-     sampler: str = "dpmpp_2m_sde_gpu"
-     scheduler: str = "karras"
-     seed: str
-     adm_guidance: str
-     guidance_scale: float
-     sharpness: float
-     steps: int
-     vae_name: str
-
-     version: str = version
-
-     def __repr__(self):
-         return ""
-
-
- def loras_parser(loras: list) -> list:
-     """
-     Parse lora list
-     """
-     return [
-         [
-             lora[0].rsplit('.', maxsplit=1)[:1][0],
-             lora[1],
-             "hash_not_calculated",
-         ] for lora in loras if lora[0] != 'None' and lora[0] is not None]
-
-
- def image_parse(
-         async_tak: object,
-         task: dict
- ) -> dict | str:
-     """
-     Parse image meta data
-     Generate meta data for image from task and async task object
-     Args:
-         async_tak: async task obj
-         task: task obj
-
-     Returns:
-         dict: image meta data
-     """
-     req_param = async_tak.req_param
-     meta = ImageMeta(
-         metadata_scheme=req_param.meta_scheme,
-         base_model=req_param.base_model_name.rsplit('.', maxsplit=1)[:1][0],
-         base_model_hash='',
-         prompt=req_param.prompt,
-         full_prompt=task['positive'],
-         prompt_expansion=task['expansion'],
-         negative_prompt=req_param.negative_prompt,
-         full_negative_prompt=task['negative'],
-         performance=req_param.performance_selection,
-         style=str(req_param.style_selections),
-         refiner_model=req_param.refiner_model_name,
-         refiner_switch=req_param.refiner_switch,
-         loras=loras_parser(req_param.loras),
-         resolution=str(tuple([int(n) for n in req_param.aspect_ratios_selection.split('*')])),
-         sampler=req_param.advanced_params.sampler_name,
-         scheduler=req_param.advanced_params.scheduler_name,
-         seed=str(task['task_seed']),
-         adm_guidance=str((
-             req_param.advanced_params.adm_scaler_positive,
-             req_param.advanced_params.adm_scaler_negative,
-             req_param.advanced_params.adm_scaler_end)),
-         guidance_scale=req_param.guidance_scale,
-         sharpness=req_param.sharpness,
-         steps=-1,
-         vae_name=req_param.advanced_params.vae_name,
-         version=version
-     )
-     if meta.metadata_scheme not in ["fooocus", "a111"]:
-         meta.metadata_scheme = "fooocus"
-     if meta.metadata_scheme == "fooocus":
-         meta_dict = meta.model_dump()
-         for i, lora in enumerate(meta.loras):
-             attr_name = f"lora_combined_{i+1}"
-             lr = [str(x) for x in lora]
-             meta_dict[attr_name] = f"{lr[0]} : {lr[1]}"
-     else:
-         meta_dict = meta.model_dump()
-     return meta_dict

fooocusapi/models/common/requests.py DELETED
@@ -1,132 +0,0 @@
- """Common model for requests"""
- from typing import List
- from pydantic import (
-     BaseModel,
-     Field,
-     ValidationError
- )
-
- from modules.config import (
-     default_sampler,
-     default_scheduler,
-     default_prompt,
-     default_prompt_negative,
-     default_aspect_ratio,
-     default_base_model_name,
-     default_refiner_model_name,
-     default_refiner_switch,
-     default_cfg_scale,
-     default_styles,
-     default_overwrite_step,
-     default_inpaint_engine_version,
-     default_overwrite_switch,
-     default_cfg_tsnr,
-     default_sample_sharpness,
-     default_vae,
-     default_clip_skip
- )
-
- from modules.flags import clip_skip_max
-
- from fooocusapi.models.common.base import (
-     PerformanceSelection,
-     Lora,
-     default_loras_model
- )
-
- default_aspect_ratio = default_aspect_ratio.split(" ")[0].replace("×", "*")
-
-
- class QueryJobRequest(BaseModel):
-     """Query job request"""
-     job_id: str = Field(description="Job ID to query")
-     require_step_preview: bool = Field(
-         default=False,
-         description="Set to true will return preview image of generation steps at current time")
-
-
- class AdvancedParams(BaseModel):
-     """Common params object AdvancedParams"""
-     disable_preview: bool = Field(False, description="Disable preview during generation")
-     disable_intermediate_results: bool = Field(False, description="Disable intermediate results")
-     disable_seed_increment: bool = Field(False, description="Disable Seed Increment")
-     adm_scaler_positive: float = Field(1.5, description="Positive ADM Guidance Scaler", ge=0.1, le=3.0)
-     adm_scaler_negative: float = Field(0.8, description="Negative ADM Guidance Scaler", ge=0.1, le=3.0)
-     adm_scaler_end: float = Field(0.3, description="ADM Guidance End At Step", ge=0.0, le=1.0)
-     adaptive_cfg: float = Field(default_cfg_tsnr, description="CFG Mimicking from TSNR", ge=1.0, le=30.0)
-     clip_skip: int = Field(default_clip_skip, description="Clip Skip", ge=1, le=clip_skip_max)
-     sampler_name: str = Field(default_sampler, description="Sampler")
-     scheduler_name: str = Field(default_scheduler, description="Scheduler")
-     overwrite_step: int = Field(default_overwrite_step, description="Forced Overwrite of Sampling Step", ge=-1, le=200)
-     overwrite_switch: float = Field(default_overwrite_switch, description="Forced Overwrite of Refiner Switch Step", ge=-1, le=1)
-     overwrite_width: int = Field(-1, description="Forced Overwrite of Generating Width", ge=-1, le=2048)
-     overwrite_height: int = Field(-1, description="Forced Overwrite of Generating Height", ge=-1, le=2048)
-     overwrite_vary_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Vary"', ge=-1, le=1.0)
-     overwrite_upscale_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Upscale"', ge=-1, le=1.0)
-     mixing_image_prompt_and_vary_upscale: bool = Field(False, description="Mixing Image Prompt and Vary/Upscale")
-     mixing_image_prompt_and_inpaint: bool = Field(False, description="Mixing Image Prompt and Inpaint")
-     debugging_cn_preprocessor: bool = Field(False, description="Debug Preprocessors")
-     skipping_cn_preprocessor: bool = Field(False, description="Skip Preprocessors")
-     canny_low_threshold: int = Field(64, description="Canny Low Threshold", ge=1, le=255)
-     canny_high_threshold: int = Field(128, description="Canny High Threshold", ge=1, le=255)
-     refiner_swap_method: str = Field('joint', description="Refiner swap method")
-     controlnet_softness: float = Field(0.25, description="Softness of ControlNet", ge=0.0, le=1.0)
-     freeu_enabled: bool = Field(False, description="FreeU enabled")
-     freeu_b1: float = Field(1.01, description="FreeU B1")
-     freeu_b2: float = Field(1.02, description="FreeU B2")
-     freeu_s1: float = Field(0.99, description="FreeU B3")
-     freeu_s2: float = Field(0.95, description="FreeU B4")
-     debugging_inpaint_preprocessor: bool = Field(False, description="Debug Inpaint Preprocessing")
-     inpaint_disable_initial_latent: bool = Field(False, description="Disable initial latent in inpaint")
-     inpaint_engine: str = Field(default_inpaint_engine_version, description="Inpaint Engine")
-     inpaint_strength: float = Field(1.0, description="Inpaint Denoising Strength", ge=0.0, le=1.0)
-     inpaint_respective_field: float = Field(1.0, description="Inpaint Respective Field", ge=0.0, le=1.0)
-     inpaint_mask_upload_checkbox: bool = Field(False, description="Upload Mask")
-     invert_mask_checkbox: bool = Field(False, description="Invert Mask")
-     inpaint_erode_or_dilate: int = Field(0, description="Mask Erode or Dilate", ge=-64, le=64)
-     black_out_nsfw: bool = Field(False, description="Block out NSFW")
-     vae_name: str = Field(default_vae, description="VAE name")
-
-
- class CommonRequest(BaseModel):
-     """All generate request based on this model"""
-     prompt: str = default_prompt
-     negative_prompt: str = default_prompt_negative
-     style_selections: List[str] = default_styles
-     performance_selection: PerformanceSelection = PerformanceSelection.speed
-     aspect_ratios_selection: str = default_aspect_ratio
-     image_number: int = Field(default=1, description="Image number", ge=1, le=32)
-     image_seed: int = Field(default=-1, description="Seed to generate image, -1 for random")
-     sharpness: float = Field(default=default_sample_sharpness, ge=0.0, le=30.0)
-     guidance_scale: float = Field(default=default_cfg_scale, ge=1.0, le=30.0)
-     base_model_name: str = default_base_model_name
-     refiner_model_name: str = default_refiner_model_name
-     refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
-     loras: List[Lora] = Field(default=default_loras_model)
-     advanced_params: AdvancedParams = AdvancedParams()
-     save_meta: bool = Field(default=True, description="Save meta data")
-     meta_scheme: str = Field(default='fooocus', description="Meta data scheme, one of [fooocus, a111]")
-     save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
-     save_name: str = Field(default='', description="Image name for output image, default is job id + seq")
-     read_wildcards_in_order: bool = Field(default=False, description="Read wildcards in order")
-     require_base64: bool = Field(default=False, description="Return base64 data of generated image")
-     async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generation result later")
-     webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
-                                     " This allows for asynchronous notification of task status.")
-
-
- def advanced_params_parser(advanced_params: str | None) -> AdvancedParams:
-     """
-     Parse advanced params, Convert to AdvancedParams
-     Args:
-         advanced_params: str, json format
-     Returns:
-         AdvancedParams object, if validate error return default value
-     """
-     if advanced_params is not None and len(advanced_params) > 0:
-         try:
-             advanced_params_obj = AdvancedParams.__pydantic_validator__.validate_json(advanced_params)
-             return AdvancedParams(**advanced_params_obj)
-         except ValidationError:
-             return AdvancedParams()
-     return AdvancedParams()

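A reduced sketch of the parse-with-fallback pattern in advanced_params_parser above, assuming pydantic v2. Only two of the model's fields are kept, and the public model_validate_json stands in for the internal __pydantic_validator__ call used by the removed code.

from pydantic import BaseModel, Field, ValidationError

class AdvancedParams(BaseModel):
    # two representative fields of the removed ~40-field model
    sampler_name: str = "dpmpp_2m_sde_gpu"
    canny_low_threshold: int = Field(64, ge=1, le=255)

def advanced_params_parser(raw: str | None) -> AdvancedParams:
    # invalid or empty JSON silently falls back to the defaults
    if raw:
        try:
            return AdvancedParams.model_validate_json(raw)
        except ValidationError:
            return AdvancedParams()
    return AdvancedParams()

print(advanced_params_parser('{"canny_low_threshold": 32}'))   # parsed
print(advanced_params_parser('{"canny_low_threshold": 999}'))  # out of range -> defaults
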
fooocusapi/models/common/response.py DELETED
@@ -1,90 +0,0 @@
- """Fooocus API models for response"""
- from typing import List
-
- from pydantic import (
-     BaseModel,
-     ConfigDict,
-     Field
- )
-
- from fooocusapi.models.common.task import (
-     GeneratedImageResult,
-     AsyncJobStage
- )
- from fooocusapi.task_queue import TaskType
-
-
- class DescribeImageResponse(BaseModel):
-     """
-     describe image response
-     """
-     describe: str
-
-
- class AsyncJobResponse(BaseModel):
-     """
-     Async job response
-     Attributes:
-         job_id: Job ID
-         job_type: Job type
-         job_stage: Job stage
-         job_progress: Job progress, 0-100
-         job_status: Job status
-         job_step_preview: Job step preview
-         job_result: Job result
-     """
-     job_id: str = Field(description="Job ID")
-     job_type: TaskType = Field(description="Job type")
-     job_stage: AsyncJobStage = Field(description="Job running stage")
-     job_progress: int = Field(description="Job running progress, 100 is for finished.")
-     job_status: str | None = Field(None, description="Job running status in text")
-     job_step_preview: str | None = Field(None, description="Preview image of generation steps at current time, as base64 image")
-     job_result: List[GeneratedImageResult] | None = Field(None, description="Job generation result")
-
-
- class JobQueueInfo(BaseModel):
-     """
-     job queue info
-     Attributes:
-         running_size: int, The current running and waiting job count
-         finished_size: int, The current finished job count
-         last_job_id: str, Last submit generation job id
-     """
-     running_size: int = Field(description="The current running and waiting job count")
-     finished_size: int = Field(description="Finished job count (after auto clean)")
-     last_job_id: str | None = Field(description="Last submit generation job id")
-
-
- # TODO May need more detail fields, will add later when someone need
- class JobHistoryInfo(BaseModel):
-     """
-     job history info
-     """
-     job_id: str
-     is_finished: bool = False
-
-
- # Response model for the historical tasks
- class JobHistoryResponse(BaseModel):
-     """
-     job history response
-     """
-     queue: List[JobHistoryInfo] = []
-     history: List[JobHistoryInfo] = []
-
-
- class AllModelNamesResponse(BaseModel):
-     """
-     all model list response
-     """
-     model_filenames: List[str] = Field(description="All available model filenames")
-     lora_filenames: List[str] = Field(description="All available lora filenames")
-
-     model_config = ConfigDict(
-         protected_namespaces=('protect_me_', 'also_protect_')
-     )
-
-
- class StopResponse(BaseModel):
-     """stop task response"""
-     msg: str

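For orientation, a hypothetical payload matching AsyncJobResponse above; every value is illustrative, and the enum strings come from the task.py diff below.

async_job_response = {  # illustrative values only
    "job_id": "a1b2c3d4",         # hypothetical ID
    "job_type": "Text to Image",  # TaskType.text_2_img
    "job_stage": "RUNNING",       # AsyncJobStage.running
    "job_progress": 42,
    "job_status": "Generating",
    "job_step_preview": None,     # base64 preview when requested
    "job_result": None,           # populated once job_stage is "SUCCESS"
}
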
fooocusapi/models/common/task.py DELETED
@@ -1,60 +0,0 @@
- """
- Task and job related models
- """
- from enum import Enum
- from pydantic import (
-     BaseModel,
-     Field
- )
-
-
- class TaskType(str, Enum):
-     """
-     Task type object
-     """
-     text_2_img = 'Text to Image'
-     img_uov = 'Image Upscale or Variation'
-     img_inpaint_outpaint = 'Image Inpaint or Outpaint'
-     img_prompt = 'Image Prompt'
-     not_found = 'Not Found'
-
-
- class GenerationFinishReason(str, Enum):
-     """
-     Generation finish reason
-     """
-     success = 'SUCCESS'
-     queue_is_full = 'QUEUE_IS_FULL'
-     user_cancel = 'USER_CANCEL'
-     error = 'ERROR'
-
-
- class ImageGenerationResult:
-     """
-     Image generation result
-     """
-     def __init__(self, im: str | None, seed: str, finish_reason: GenerationFinishReason):
-         self.im = im
-         self.seed = seed
-         self.finish_reason = finish_reason
-
-
- class AsyncJobStage(str, Enum):
-     """
-     Async job stage
-     """
-     waiting = 'WAITING'
-     running = 'RUNNING'
-     success = 'SUCCESS'
-     error = 'ERROR'
-
-
- class GeneratedImageResult(BaseModel):
-     """
-     Generated images result
-     """
-     base64: str | None = Field(
-         description="Image encoded in base64, or null if finishReason is not 'SUCCESS', only return when request require base64")
-     url: str | None = Field(description="Image file static serve url, or null if finishReason is not 'SUCCESS'")
-     seed: str = Field(description="The seed associated with this image")
-     finish_reason: GenerationFinishReason

fooocusapi/models/requests_v1.py DELETED
@@ -1,274 +0,0 @@
- """
- requests models for v1 endpoints
- """
- from typing import List
- from fastapi.params import File
- from fastapi import (
-     UploadFile,
-     Form
- )
- from fooocusapi.models.common.requests import (
-     CommonRequest,
-     advanced_params_parser
- )
- from fooocusapi.models.common.base import (
-     ImagePrompt,
-     ControlNetType,
-     OutpaintExpansion,
-     UpscaleOrVaryMethod,
-     PerformanceSelection
- )
-
- from fooocusapi.models.common.base import (
-     style_selection_parser,
-     lora_parser,
-     outpaint_selections_parser,
-     image_prompt_parser,
-     default_loras_json
- )
-
- from fooocusapi.configs.default import (
-     default_prompt_negative,
-     default_aspect_ratio,
-     default_base_model_name,
-     default_refiner_model_name,
-     default_refiner_switch,
-     default_cfg_scale,
-     default_styles,
- )
-
-
- class ImgUpscaleOrVaryRequest(CommonRequest):
-     """
-     Request for image upscale or variation
-     Attributes:
-         input_image: Input image
-         uov_method: Upscale or variation method
-         upscale_value: upscale value
-     Functions:
-         as_form: Convert request to form data
-     """
-     input_image: UploadFile
-     uov_method: UpscaleOrVaryMethod
-     upscale_value: float | None
-
-     @classmethod
-     def as_form(
-             cls,
-             input_image: UploadFile = Form(description="Init image for upscale or outpaint"),
-             uov_method: UpscaleOrVaryMethod = Form(),
-             upscale_value: float | None = Form(None, description="Upscale custom value, None for default value", ge=1.0, le=5.0),
-             prompt: str = Form(''),
-             negative_prompt: str = Form(default_prompt_negative),
-             style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
-             performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
-             aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
-             image_number: int = Form(default=1, description="Image number", ge=1, le=32),
-             image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
-             sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
-             guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
-             base_model_name: str = Form(default_base_model_name, description="checkpoint file name"),
-             refiner_model_name: str = Form(default_refiner_model_name, description="refiner file name"),
-             refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
-             loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
-             advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
-             save_meta: bool = Form(default=False, description="Save metadata to image"),
-             meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
-             save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
-             save_name: str = Form(default="", description="Save name, empty for auto generate"),
-             require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
-             read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
-             async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
-             webhook_url: str = Form(default="", description="Webhook url for generation result"),
-     ):
-         style_selection_arr = style_selection_parser(style_selections)
-         loras_model = lora_parser(loras)
-         advanced_params_obj = advanced_params_parser(advanced_params)
-
-         return cls(
-             input_image=input_image, uov_method=uov_method, upscale_value=upscale_value,
-             prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
-             performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
-             image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
-             base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
-             loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
-             save_extension=save_extension, save_name=save_name, require_base64=require_base64,
-             read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
-
-
- class ImgInpaintOrOutpaintRequest(CommonRequest):
-     """
-     Image Inpaint or Outpaint Request
-     """
-     input_image: UploadFile | None
-     input_mask: UploadFile | None
-     inpaint_additional_prompt: str | None
-     outpaint_selections: List[OutpaintExpansion]
-     outpaint_distance_left: int
-     outpaint_distance_right: int
-     outpaint_distance_top: int
-     outpaint_distance_bottom: int
-
-     @classmethod
-     def as_form(
-             cls,
-             input_image: UploadFile = Form(description="Init image for inpaint or outpaint"),
-             input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
-             inpaint_additional_prompt: str | None = Form("", description="Describe what you want to inpaint"),
-             outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
-             outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, -1 for default"),
-             outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, -1 for default"),
-             outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, -1 for default"),
-             outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, -1 for default"),
-             prompt: str = Form(''),
-             negative_prompt: str = Form(default_prompt_negative),
-             style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
-             performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
-             aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
-             image_number: int = Form(default=1, description="Image number", ge=1, le=32),
-             image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
-             sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
-             guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
-             base_model_name: str = Form(default_base_model_name),
-             refiner_model_name: str = Form(default_refiner_model_name),
-             refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
-             loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
-             advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
-             save_meta: bool = Form(default=False, description="Save metadata to image"),
-             meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
-             save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
-             save_name: str = Form(default="", description="Save name, empty for auto generate"),
-             require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
-             read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
-             async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
-             webhook_url: str = Form(default="", description="Webhook url for generation result"),
-     ):
-         if isinstance(input_mask, File):
-             input_mask = None
-
-         outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)
-         style_selection_arr = style_selection_parser(style_selections)
-         loras_model = lora_parser(loras)
-         advanced_params_obj = advanced_params_parser(advanced_params)
-
-         return cls(
-             input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt,
-             outpaint_selections=outpaint_selections_arr, outpaint_distance_left=outpaint_distance_left,
-             outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top,
-             outpaint_distance_bottom=outpaint_distance_bottom, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
-             performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
-             image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
-             base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
-             loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
-             save_extension=save_extension, save_name=save_name, require_base64=require_base64,
-             read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)
-
-
- class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
-     """
-     Image Prompt Request
-     """
-     image_prompts: List[ImagePrompt]
-
-     @classmethod
-     def as_form(
-             cls,
-             input_image: UploadFile = Form(File(None), description="Init image for inpaint or outpaint"),
-             input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
-             inpaint_additional_prompt: str | None = Form(None, description="Describe what you want to inpaint"),
-             outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
-             outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, 0 for default"),
-             outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, 0 for default"),
-             outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, 0 for default"),
-             outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, 0 for default"),
-             cn_img1: UploadFile = Form(File(None), description="Input image for image prompt"),
-             cn_stop1: float | None = Form(
-                 default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
-             cn_weight1: float | None = Form(
-                 default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
-             cn_type1: ControlNetType = Form(
-                 default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
-             cn_img2: UploadFile = Form(
-                 File(None), description="Input image for image prompt"),
-             cn_stop2: float | None = Form(
-                 default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
-             cn_weight2: float | None = Form(
-                 default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
-             cn_type2: ControlNetType = Form(
-                 default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
-             cn_img3: UploadFile = Form(
-                 File(None), description="Input image for image prompt"),
-             cn_stop3: float | None = Form(
-                 default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
-             cn_weight3: float | None = Form(
-                 default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
-             cn_type3: ControlNetType = Form(
-                 default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
-             cn_img4: UploadFile = Form(
-                 File(None), description="Input image for image prompt"),
-             cn_stop4: float | None = Form(
-                 default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
-             cn_weight4: float | None = Form(
-                 default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
-             cn_type4: ControlNetType = Form(
-                 default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
-             prompt: str = Form(''),
-             negative_prompt: str = Form(default_prompt_negative),
-             style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
-             performance_selection: PerformanceSelection = Form(
-                 PerformanceSelection.speed),
-             aspect_ratios_selection: str = Form(default_aspect_ratio),
-             image_number: int = Form(
-                 default=1, description="Image number", ge=1, le=32),
-             image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
-             sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
-             guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
-             base_model_name: str = Form(default_base_model_name),
-             refiner_model_name: str = Form(default_refiner_model_name),
-             refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
-             loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
-             advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
-             save_meta: bool = Form(default=False, description="Save metadata to image"),
-             meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
-             save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
-             save_name: str = Form(default="", description="Save name, empty for auto generate"),
-             require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
-             read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
-             async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
-             webhook_url: str = Form(default="", description="Webhook url for generation result"),
-     ):
-         if isinstance(input_image, File):
-             input_image = None
-         if isinstance(input_mask, File):
-             input_mask = None
-         if isinstance(cn_img1, File):
-             cn_img1 = None
-         if isinstance(cn_img2, File):
-             cn_img2 = None
-         if isinstance(cn_img3, File):
-             cn_img3 = None
-         if isinstance(cn_img4, File):
-             cn_img4 = None
-
-         outpaint_selections_arr = outpaint_selections_parser(outpaint_selections)
-
-         image_prompt_config = [
-             (cn_img1, cn_stop1, cn_weight1, cn_type1),
-             (cn_img2, cn_stop2, cn_weight2, cn_type2),
-             (cn_img3, cn_stop3, cn_weight3, cn_type3),
-             (cn_img4, cn_stop4, cn_weight4, cn_type4)]
-         image_prompts = image_prompt_parser(image_prompt_config)
-         style_selection_arr = style_selection_parser(style_selections)
-         loras_model = lora_parser(loras)
-         advanced_params_obj = advanced_params_parser(advanced_params)
-
-         return cls(
-             input_image=input_image, input_mask=input_mask, inpaint_additional_prompt=inpaint_additional_prompt, outpaint_selections=outpaint_selections_arr,
-             outpaint_distance_left=outpaint_distance_left, outpaint_distance_right=outpaint_distance_right, outpaint_distance_top=outpaint_distance_top, outpaint_distance_bottom=outpaint_distance_bottom,
-             image_prompts=image_prompts, prompt=prompt, negative_prompt=negative_prompt, style_selections=style_selection_arr,
-             performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
-             image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
-             base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
-             loras=loras_model, advanced_params=advanced_params_obj, save_meta=save_meta, meta_scheme=meta_scheme,
-             save_extension=save_extension, save_name=save_name, require_base64=require_base64,
-             read_wildcards_in_order=read_wildcards_in_order, async_process=async_process, webhook_url=webhook_url)

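The as_form classmethods above all follow the same FastAPI pattern: a pydantic model exposes a factory that FastAPI injects via Depends, so multipart form fields are validated like a JSON body. A minimal standalone sketch; the model and route names here are invented for illustration.

from fastapi import FastAPI, Depends, Form
from pydantic import BaseModel

class DemoRequest(BaseModel):
    prompt: str = ""
    image_number: int = 1

    @classmethod
    def as_form(
            cls,
            prompt: str = Form(''),
            image_number: int = Form(default=1, ge=1, le=32),
    ):
        # FastAPI resolves each Form(...) parameter, then we build the model
        return cls(prompt=prompt, image_number=image_number)

app = FastAPI()

@app.post("/demo")
def demo(req: DemoRequest = Depends(DemoRequest.as_form)):
    # req is a validated DemoRequest built from multipart form fields
    return req
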
fooocusapi/models/requests_v2.py DELETED
@@ -1,50 +0,0 @@
- """V2 API models"""
- from typing import List
- from pydantic import BaseModel, Field
- from fooocusapi.models.common.requests import CommonRequest
- from fooocusapi.models.common.base import (
-     ControlNetType,
-     OutpaintExpansion,
-     ImagePrompt,
-     UpscaleOrVaryMethod
- )
-
-
- class ImagePromptJson(BaseModel):
-     """Image prompt for V2 API"""
-     cn_img: str | None = Field(None, description="Input image for image prompt as base64")
-     cn_stop: float | None = Field(0, ge=0, le=1, description="Stop at for image prompt, 0 for default value")
-     cn_weight: float | None = Field(0, ge=0, le=2, description="Weight for image prompt, 0 for default value")
-     cn_type: ControlNetType = Field(default=ControlNetType.cn_ip, description="ControlNet type for image prompt")
-
-
- class ImgInpaintOrOutpaintRequestJson(CommonRequest):
-     """image inpaint or outpaint request"""
-     input_image: str = Field('', description="Init image for inpaint or outpaint as base64")
-     input_mask: str | None = Field('', description="Inpaint or outpaint mask as base64")
-     inpaint_additional_prompt: str | None = Field('', description="Describe what you want to inpaint")
-     outpaint_selections: List[OutpaintExpansion] = []
-     outpaint_distance_left: int | None = Field(-1, description="Set outpaint left distance")
-     outpaint_distance_right: int | None = Field(-1, description="Set outpaint right distance")
-     outpaint_distance_top: int | None = Field(-1, description="Set outpaint top distance")
-     outpaint_distance_bottom: int | None = Field(-1, description="Set outpaint bottom distance")
-     image_prompts: List[ImagePromptJson | ImagePrompt] = []
-
-
- class ImgPromptRequestJson(ImgInpaintOrOutpaintRequestJson):
-     """img prompt request json"""
-     input_image: str | None = Field(None, description="Init image for inpaint or outpaint as base64")
-     image_prompts: List[ImagePromptJson | ImagePrompt]
-
-
- class Text2ImgRequestWithPrompt(CommonRequest):
-     """text to image request with prompt"""
-     image_prompts: List[ImagePromptJson] = []
-
-
- class ImgUpscaleOrVaryRequestJson(CommonRequest):
-     """img upscale or vary request json"""
-     uov_method: UpscaleOrVaryMethod = UpscaleOrVaryMethod.upscale_2
-     upscale_value: float | None = Field(1.0, ge=1.0, le=5.0, description="Upscale custom value, 1.0 for default value")
-     input_image: str = Field(description="Init image for upscale or outpaint as base64")
-     image_prompts: List[ImagePromptJson | ImagePrompt] = []

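A hypothetical request body for ImgUpscaleOrVaryRequestJson above; the method string comes from UpscaleOrVaryMethod in the base.py diff, and the base64 payload is elided.

upscale_request = {
    "input_image": "<base64-encoded source image>",  # elided
    "uov_method": "Upscale (2x)",                    # UpscaleOrVaryMethod.upscale_2
    "upscale_value": 2.0,
    "prompt": "",
    "async_process": True,  # return a job id instead of blocking
}
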
fooocusapi/parameters.py DELETED
@@ -1,94 +0,0 @@
- from typing import Dict, List, Tuple
- import numpy as np
- import copy
-
- from fooocusapi.models.common.requests import AdvancedParams
-
-
- class ImageGenerationParams:
-     def __init__(
-             self,
-             prompt: str,
-             negative_prompt: str,
-             style_selections: List[str],
-             performance_selection: str,
-             aspect_ratios_selection: str,
-             image_number: int,
-             image_seed: int | None,
-             sharpness: float,
-             guidance_scale: float,
-             base_model_name: str,
-             refiner_model_name: str,
-             refiner_switch: float,
-             loras: List[Tuple[str, float]],
-             uov_input_image: np.ndarray | None,
-             uov_method: str,
-             upscale_value: float | None,
-             outpaint_selections: List[str],
-             outpaint_distance_left: int,
-             outpaint_distance_right: int,
-             outpaint_distance_top: int,
-             outpaint_distance_bottom: int,
-             inpaint_input_image: Dict[str, np.ndarray] | None,
-             inpaint_additional_prompt: str | None,
-             image_prompts: List[Tuple[np.ndarray, float, float, str]],
-             advanced_params: List[any] | None,
-             save_extension: str,
-             save_meta: bool,
-             meta_scheme: str,
-             save_name: str,
-             require_base64: bool,
-     ):
-         self.prompt = prompt
-         self.negative_prompt = negative_prompt
-         self.style_selections = style_selections
-         self.performance_selection = performance_selection
-         self.aspect_ratios_selection = aspect_ratios_selection
-         self.image_number = image_number
-         self.image_seed = image_seed
-         self.sharpness = sharpness
-         self.guidance_scale = guidance_scale
-         self.base_model_name = base_model_name
-         self.refiner_model_name = refiner_model_name
-         self.refiner_switch = refiner_switch
-         self.loras = loras
-         self.uov_input_image = uov_input_image
-         self.uov_method = uov_method
-         self.upscale_value = upscale_value
-         self.outpaint_selections = outpaint_selections
-         self.outpaint_distance_left = outpaint_distance_left
-         self.outpaint_distance_right = outpaint_distance_right
-         self.outpaint_distance_top = outpaint_distance_top
-         self.outpaint_distance_bottom = outpaint_distance_bottom
-         self.inpaint_input_image = inpaint_input_image
-         self.inpaint_additional_prompt = inpaint_additional_prompt
-         self.image_prompts = image_prompts
-         self.save_extension = save_extension
-         self.save_meta = save_meta
-         self.meta_scheme = meta_scheme
-         self.save_name = save_name
-         self.require_base64 = require_base64
-         self.advanced_params = advanced_params
-
-         if self.advanced_params is None:
-             self.advanced_params = AdvancedParams()
-
-         # Auto set mixing_image_prompt_and_inpaint to True
-         if len(self.image_prompts) > 0 and self.inpaint_input_image is not None:
-             print("Mixing Image Prompts and Inpaint Enabled")
-             self.advanced_params.mixing_image_prompt_and_inpaint = True
-         if len(self.image_prompts) > 0 and self.uov_input_image is not None:
-             print("Mixing Image Prompts and Vary Upscale Enabled")
-             self.advanced_params.mixing_image_prompt_and_vary_upscale = True
-
-     def to_dict(self):
-         """
-         Convert the ImageGenerationParams object to a dictionary.
-         Args:
-             self:
-
-         Returns:
-             self to dict
-         """
-         obj_dict = copy.deepcopy(self)
-         return obj_dict.__dict__

fooocusapi/routes/__init__.py DELETED
File without changes
fooocusapi/routes/generate_v1.py DELETED
@@ -1,186 +0,0 @@
- """Generate API V1 routes
-
- """
- from typing import List, Optional
- from fastapi import APIRouter, Depends, Header, Query, UploadFile
- from fastapi.params import File
-
- from modules.util import HWC3
-
- from fooocusapi.models.common.base import DescribeImageType
- from fooocusapi.utils.api_utils import api_key_auth
-
- from fooocusapi.models.common.requests import CommonRequest as Text2ImgRequest
- from fooocusapi.models.requests_v1 import (
-     ImgUpscaleOrVaryRequest,
-     ImgPromptRequest,
-     ImgInpaintOrOutpaintRequest
- )
- from fooocusapi.models.common.response import (
-     AsyncJobResponse,
-     GeneratedImageResult,
-     DescribeImageResponse,
-     StopResponse
- )
- from fooocusapi.utils.call_worker import call_worker
- from fooocusapi.utils.img_utils import read_input_image
- from fooocusapi.configs.default import img_generate_responses
- from fooocusapi.worker import process_stop
-
-
- secure_router = APIRouter(
-     dependencies=[Depends(api_key_auth)]
- )
-
-
- def stop_worker():
-     """Interrupt worker process"""
-     process_stop()
-
-
- @secure_router.post(
-     path="/v1/generation/text-to-image",
-     response_model=List[GeneratedImageResult] | AsyncJobResponse,
-     responses=img_generate_responses,
-     tags=["GenerateV1"])
- def text2img_generation(
-         req: Text2ImgRequest,
-         accept: str = Header(None),
-         accept_query: str | None = Query(
-             None, alias='accept',
-             description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
-     """\nText to Image Generation\n
-     A text to image generation endpoint
-     Arguments:
-         req {Text2ImgRequest} -- Text to image generation request
-         accept {str} -- Accept header
-         accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
-     returns:
-         Response -- img_generate_responses
-     """
-     if accept_query is not None and len(accept_query) > 0:
-         accept = accept_query
-
-     return call_worker(req, accept)
-
-
- @secure_router.post(
-     path="/v1/generation/image-upscale-vary",
-     response_model=List[GeneratedImageResult] | AsyncJobResponse,
-     responses=img_generate_responses,
-     tags=["GenerateV1"])
- def img_upscale_or_vary(
-         input_image: UploadFile,
-         req: ImgUpscaleOrVaryRequest = Depends(ImgUpscaleOrVaryRequest.as_form),
-         accept: str = Header(None),
-         accept_query: str | None = Query(
-             None, alias='accept',
-             description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
-     """\nImage upscale or vary\n
-     Image upscale or vary
-     Arguments:
-         input_image {UploadFile} -- Input image file
-         req {ImgUpscaleOrVaryRequest} -- Request body
-         accept {str} -- Accept header
-         accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
-     Returns:
-         Response -- img_generate_responses
-     """
-     if accept_query is not None and len(accept_query) > 0:
-         accept = accept_query
-
-     return call_worker(req, accept)
-
-
- @secure_router.post(
-     path="/v1/generation/image-inpaint-outpaint",
-     response_model=List[GeneratedImageResult] | AsyncJobResponse,
-     responses=img_generate_responses,
-     tags=["GenerateV1"])
- def img_inpaint_or_outpaint(
-         input_image: UploadFile,
-         req: ImgInpaintOrOutpaintRequest = Depends(ImgInpaintOrOutpaintRequest.as_form),
-         accept: str = Header(None),
-         accept_query: str | None = Query(
-             None, alias='accept',
-             description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
-     """\nInpaint or outpaint\n
-     Inpaint or outpaint
-     Arguments:
-         input_image {UploadFile} -- Input image file
-         req {ImgInpaintOrOutpaintRequest} -- Request body
-         accept {str} -- Accept header
-         accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
-     """
-     if accept_query is not None and len(accept_query) > 0:
-         accept = accept_query
-
-     return call_worker(req, accept)
-
-
- @secure_router.post(
-     path="/v1/generation/image-prompt",
-     response_model=List[GeneratedImageResult] | AsyncJobResponse,
-     responses=img_generate_responses,
-     tags=["GenerateV1"])
- def img_prompt(
-         cn_img1: Optional[UploadFile] = File(None),
-         req: ImgPromptRequest = Depends(ImgPromptRequest.as_form),
-         accept: str = Header(None),
-         accept_query: str | None = Query(
-             None, alias='accept',
-             description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
-     """\nImage Prompt\n
-     Image Prompt
-     A prompt-based image generation.
-     Arguments:
-         cn_img1 {UploadFile} -- Input image file
-         req {ImgPromptRequest} -- Request body
-         accept {str} -- Accept header
-         accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
-     Returns:
-         Response -- img_generate_responses
-     """
-     if accept_query is not None and len(accept_query) > 0:
-         accept = accept_query
-
-     return call_worker(req, accept)
-
-
- @secure_router.post(
-     path="/v1/tools/describe-image",
-     response_model=DescribeImageResponse,
153
- tags=["GenerateV1"])
154
- def describe_image(
155
- image: UploadFile,
156
- image_type: DescribeImageType = Query(
157
- DescribeImageType.photo,
158
- description="Image type, 'Photo' or 'Anime'")):
159
- """\nDescribe image\n
160
- Describe image, Get tags from an image
161
- Arguments:
162
- image {UploadFile} -- Image to get tags
163
- image_type {DescribeImageType} -- Image type, 'Photo' or 'Anime'
164
- Returns:
165
- DescribeImageResponse -- Describe image response, a string
166
- """
167
- if image_type == DescribeImageType.photo:
168
- from extras.interrogate import default_interrogator as default_interrogator_photo
169
- interrogator = default_interrogator_photo
170
- else:
171
- from extras.wd14tagger import default_interrogator as default_interrogator_anime
172
- interrogator = default_interrogator_anime
173
- img = HWC3(read_input_image(image))
174
- result = interrogator(img)
175
- return DescribeImageResponse(describe=result)
176
-
177
-
178
- @secure_router.post(
179
- path="/v1/generation/stop",
180
- response_model=StopResponse,
181
- description="Job stopping",
182
- tags=["Default"])
183
- def stop():
184
- """Interrupt worker"""
185
- stop_worker()
186
- return StopResponse(msg="success")
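
A hedged example of calling one of the removed V1 endpoints. The route, the X-API-KEY header name, and the accept override come from the code above; the host, port, and key value are assumptions (prompt and async_process are fields of the request models elsewhere in this diff):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8888/v1/generation/text-to-image",
        headers={"X-API-KEY": "my-key", "Accept": "application/json"},
        json={"prompt": "a forest cabin at dusk", "async_process": True},
        timeout=60,
    )
    print(resp.json())  # an AsyncJobResponse with a job_id when async_process is true

Passing ?accept=image/png (or the equivalent Accept header) instead would stream raw PNG bytes for a single image.
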
fooocusapi/routes/generate_v2.py DELETED
@@ -1,199 +0,0 @@
1
- """Generate API V2 routes
2
-
3
- """
4
- from typing import List
5
- from fastapi import APIRouter, Depends, Header, Query
6
-
7
- from fooocusapi.utils.api_utils import api_key_auth
8
- from fooocusapi.models.requests_v1 import ImagePrompt
9
- from fooocusapi.models.requests_v2 import (
10
- ImgInpaintOrOutpaintRequestJson,
11
- ImgPromptRequestJson,
12
- Text2ImgRequestWithPrompt,
13
- ImgUpscaleOrVaryRequestJson
14
- )
15
- from fooocusapi.models.common.response import (
16
- AsyncJobResponse,
17
- GeneratedImageResult
18
- )
19
- from fooocusapi.utils.call_worker import call_worker
20
- from fooocusapi.utils.img_utils import base64_to_stream
21
- from fooocusapi.configs.default import img_generate_responses
22
-
23
-
24
- secure_router = APIRouter(
25
- dependencies=[Depends(api_key_auth)]
26
- )
27
-
28
-
29
- @secure_router.post(
30
- path="/v2/generation/text-to-image-with-ip",
31
- response_model=List[GeneratedImageResult] | AsyncJobResponse,
32
- responses=img_generate_responses,
33
- tags=["GenerateV2"])
34
- def text_to_img_with_ip(
35
- req: Text2ImgRequestWithPrompt,
36
- accept: str = Header(None),
37
- accept_query: str | None = Query(
38
- default=None, alias='accept',
39
- description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
40
- """\nText to image with prompt\n
41
- Text to image with prompt
42
- Arguments:
43
- req {Text2ImgRequestWithPrompt} -- Text to image generation request
44
- accept {str} -- Accept header
45
- accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
46
- Returns:
47
- Response -- img_generate_responses
48
- """
49
- if accept_query is not None and len(accept_query) > 0:
50
- accept = accept_query
51
-
52
- default_image_prompt = ImagePrompt(cn_img=None)
53
- image_prompts_files: List[ImagePrompt] = []
54
- for image_prompt in req.image_prompts:
55
- image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
56
- image = ImagePrompt(
57
- cn_img=image_prompt.cn_img,
58
- cn_stop=image_prompt.cn_stop,
59
- cn_weight=image_prompt.cn_weight,
60
- cn_type=image_prompt.cn_type)
61
- image_prompts_files.append(image)
62
-
63
- while len(image_prompts_files) <= 4:
64
- image_prompts_files.append(default_image_prompt)
65
-
66
- req.image_prompts = image_prompts_files
67
-
68
- return call_worker(req, accept)
69
-
70
-
71
- @secure_router.post(
72
- path="/v2/generation/image-upscale-vary",
73
- response_model=List[GeneratedImageResult] | AsyncJobResponse,
74
- responses=img_generate_responses,
75
- tags=["GenerateV2"])
76
- def img_upscale_or_vary(
77
- req: ImgUpscaleOrVaryRequestJson,
78
- accept: str = Header(None),
79
- accept_query: str | None = Query(
80
- None, alias='accept', description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
81
- """\nImage upscale or vary\n
82
- Image upscale or vary
83
- Arguments:
84
- req {ImgUpscaleOrVaryRequestJson} -- Image upscale or vary request
85
- accept {str} -- Accept header
86
- accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
87
- Returns:
88
- Response -- img_generate_responses
89
- """
90
- if accept_query is not None and len(accept_query) > 0:
91
- accept = accept_query
92
-
93
- req.input_image = base64_to_stream(req.input_image)
94
-
95
- default_image_prompt = ImagePrompt(cn_img=None)
96
- image_prompts_files: List[ImagePrompt] = []
97
- for image_prompt in req.image_prompts:
98
- image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
99
- image = ImagePrompt(
100
- cn_img=image_prompt.cn_img,
101
- cn_stop=image_prompt.cn_stop,
102
- cn_weight=image_prompt.cn_weight,
103
- cn_type=image_prompt.cn_type)
104
- image_prompts_files.append(image)
105
- while len(image_prompts_files) <= 4:
106
- image_prompts_files.append(default_image_prompt)
107
- req.image_prompts = image_prompts_files
108
-
109
- return call_worker(req, accept)
110
-
111
-
112
- @secure_router.post(
113
- path="/v2/generation/image-inpaint-outpaint",
114
- response_model=List[GeneratedImageResult] | AsyncJobResponse,
115
- responses=img_generate_responses,
116
- tags=["GenerateV2"])
117
- def img_inpaint_or_outpaint(
118
- req: ImgInpaintOrOutpaintRequestJson,
119
- accept: str = Header(None),
120
- accept_query: str | None = Query(
121
- None, alias='accept',
122
- description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
123
- """\nInpaint or outpaint\n
124
- Inpaint or outpaint
125
- Arguments:
126
- req {ImgInpaintOrOutpaintRequestJson} -- Request body
127
- accept {str} -- Accept header
128
- accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
129
- Returns:
130
- Response -- img_generate_responses
131
- """
132
- if accept_query is not None and len(accept_query) > 0:
133
- accept = accept_query
134
-
135
- req.input_image = base64_to_stream(req.input_image)
136
- if req.input_mask is not None:
137
- req.input_mask = base64_to_stream(req.input_mask)
138
- default_image_prompt = ImagePrompt(cn_img=None)
139
- image_prompts_files: List[ImagePrompt] = []
140
- for image_prompt in req.image_prompts:
141
- image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
142
- image = ImagePrompt(
143
- cn_img=image_prompt.cn_img,
144
- cn_stop=image_prompt.cn_stop,
145
- cn_weight=image_prompt.cn_weight,
146
- cn_type=image_prompt.cn_type)
147
- image_prompts_files.append(image)
148
- while len(image_prompts_files) <= 4:
149
- image_prompts_files.append(default_image_prompt)
150
- req.image_prompts = image_prompts_files
151
-
152
- return call_worker(req, accept)
153
-
154
-
155
- @secure_router.post(
156
- path="/v2/generation/image-prompt",
157
- response_model=List[GeneratedImageResult] | AsyncJobResponse,
158
- responses=img_generate_responses,
159
- tags=["GenerateV2"])
160
- def img_prompt(
161
- req: ImgPromptRequestJson,
162
- accept: str = Header(None),
163
- accept_query: str | None = Query(
164
- None, alias='accept',
165
- description="Parameter to override 'Accept' header, 'image/png' for output bytes")):
166
- """\nImage prompt\n
167
- Image prompt generation
168
- Arguments:
169
- req {ImgPromptRequest} -- Request body
170
- accept {str} -- Accept header
171
- accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
172
- Returns:
173
- Response -- img_generate_responses
174
- """
175
- if accept_query is not None and len(accept_query) > 0:
176
- accept = accept_query
177
-
178
- if req.input_image is not None:
179
- req.input_image = base64_to_stream(req.input_image)
180
- if req.input_mask is not None:
181
- req.input_mask = base64_to_stream(req.input_mask)
182
-
183
- default_image_prompt = ImagePrompt(cn_img=None)
184
- image_prompts_files: List[ImagePrompt] = []
185
- for image_prompt in req.image_prompts:
186
- image_prompt.cn_img = base64_to_stream(image_prompt.cn_img)
187
- image = ImagePrompt(
188
- cn_img=image_prompt.cn_img,
189
- cn_stop=image_prompt.cn_stop,
190
- cn_weight=image_prompt.cn_weight,
191
- cn_type=image_prompt.cn_type)
192
- image_prompts_files.append(image)
193
-
194
- while len(image_prompts_files) <= 4:
195
- image_prompts_files.append(default_image_prompt)
196
-
197
- req.image_prompts = image_prompts_files
198
-
199
- return call_worker(req, accept)
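
A hedged example against one of the removed V2 JSON routes. The field names (cn_img, cn_stop, cn_weight, cn_type) come from the code above; the host, port, and the "ImagePrompt" type value are assumptions:

    import base64
    import requests

    with open("ref.png", "rb") as f:
        b64_img = base64.b64encode(f.read()).decode()

    payload = {
        "prompt": "product shot, studio lighting",
        "image_prompts": [
            {"cn_img": b64_img, "cn_stop": 0.6, "cn_weight": 0.5, "cn_type": "ImagePrompt"},
        ],
    }
    resp = requests.post(
        "http://127.0.0.1:8888/v2/generation/text-to-image-with-ip",
        json=payload,
        timeout=300,
    )
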
fooocusapi/routes/query.py DELETED
@@ -1,135 +0,0 @@
1
- """Query API"""
2
- from typing import List
3
- from fastapi import Depends, Response, APIRouter
4
-
5
- from fooocusapi.args import args
6
-
7
- from fooocusapi.models.common.requests import QueryJobRequest
8
- from fooocusapi.models.common.response import (
9
- AsyncJobResponse,
10
- JobHistoryInfo,
11
- JobQueueInfo,
12
- JobHistoryResponse,
13
- AllModelNamesResponse
14
- )
15
- from fooocusapi.models.common.task import AsyncJobStage
16
-
17
- from fooocusapi.utils.api_utils import generate_async_output, api_key_auth
18
- from fooocusapi.task_queue import TaskType
19
- from fooocusapi.worker import worker_queue
20
-
21
- secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
22
-
23
-
24
- @secure_router.get(path="/", tags=['Query'])
25
- def home():
26
- """Home page"""
27
- return Response(
28
- content='Swagger UI: <a href="/docs">/docs</a>',
29
- media_type="text/html"
30
- )
31
-
32
-
33
- @secure_router.get(
34
- path="/ping",
35
- description="Returns a simple 'pong'",
36
- tags=['Query'])
37
- async def ping():
38
- """\nPing\n
39
- Ping page, just to check if the fastapi is up.
40
- An immediate response only means FastAPI is up, not that generation is available.
41
- Returns:
42
- A simple string pong
43
- """
44
- return 'pong'
45
-
46
-
47
- @secure_router.get(
48
- path="/v1/generation/query-job",
49
- response_model=AsyncJobResponse,
50
- description="Query async generation job",
51
- tags=['Query'])
52
- def query_job(req: QueryJobRequest = Depends()):
53
- """query job info by id"""
54
- queue_task = worker_queue.get_task(req.job_id, True)
55
- if queue_task is None:
56
- result = AsyncJobResponse(
57
- job_id="",
58
- job_type=TaskType.not_found,
59
- job_stage=AsyncJobStage.error,
60
- job_progress=0,
61
- job_status="Job not found")
62
- content = result.model_dump_json()
63
- return Response(content=content, media_type='application/json', status_code=404)
64
- return generate_async_output(queue_task, req.require_step_preview)
65
-
66
-
67
- @secure_router.get(
68
- path="/v1/generation/job-queue",
69
- response_model=JobQueueInfo,
70
- description="Query job queue info",
71
- tags=['Query'])
72
- def job_queue():
73
- """Query job queue info"""
74
- queue = JobQueueInfo(
75
- running_size=len(worker_queue.queue),
76
- finished_size=len(worker_queue.history),
77
- last_job_id=worker_queue.last_job_id
78
- )
79
- return queue
80
-
81
-
82
- @secure_router.get(
83
- path="/v1/generation/job-history",
84
- response_model=JobHistoryResponse | dict,
85
- description="Query historical job data",
86
- tags=["Query"])
87
- def get_history(job_id: str = None, page: int = 0, page_size: int = 20):
88
- """Fetch and return the historical tasks"""
89
- queue = [
90
- JobHistoryInfo(
91
- job_id=item.job_id,
92
- is_finished=item.is_finished
93
- ) for item in worker_queue.queue
94
- ]
95
- if not args.persistent:
96
- history = [
97
- JobHistoryInfo(
98
- job_id=item.job_id,
99
- is_finished=item.is_finished
100
- ) for item in worker_queue.history
101
- ]
102
- return JobHistoryResponse(history=history, queue=queue)
103
-
104
- from fooocusapi.sql_client import query_history
105
- history = query_history(task_id=job_id, page=page, page_size=page_size)
106
- return {
107
- "history": history,
108
- "queue": queue
109
- }
110
-
111
-
112
- @secure_router.get(
113
- path="/v1/engines/all-models",
114
- response_model=AllModelNamesResponse,
115
- description="Get all filenames of base model and lora",
116
- tags=["Query"])
117
- def all_models():
118
- """Refresh and return all models"""
119
- from modules import config
120
- config.update_files()
121
- models = AllModelNamesResponse(
122
- model_filenames=config.model_filenames,
123
- lora_filenames=config.lora_filenames)
124
- return models
125
-
126
-
127
- @secure_router.get(
128
- path="/v1/engines/styles",
129
- response_model=List[str],
130
- description="Get all legal Fooocus styles",
131
- tags=['Query'])
132
- def all_styles():
133
- """Return all available styles"""
134
- from modules.sdxl_styles import legal_style_names
135
- return legal_style_names
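
A hedged polling sketch for the removed query endpoint. The route and the job_id / require_step_preview parameters come from the code above; the base URL and the serialized job_stage values are assumptions:

    import time
    import requests

    def wait_for_job(job_id: str, base: str = "http://127.0.0.1:8888") -> dict:
        while True:
            resp = requests.get(
                f"{base}/v1/generation/query-job",
                params={"job_id": job_id, "require_step_preview": False},
                timeout=10,
            )
            job = resp.json()
            if job["job_stage"] not in ("WAITING", "RUNNING"):  # assumed stage names
                return job  # finished: success or error
            time.sleep(1)
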
fooocusapi/sql_client.py DELETED
@@ -1,269 +0,0 @@
1
- """
2
- SQLite client for Fooocus API
3
- """
4
- import os
5
- import time
6
- import platform
7
- from datetime import datetime
8
- from typing import Optional
9
- import copy
10
-
11
- from sqlalchemy import Integer, Float, VARCHAR, Boolean, JSON, Text, create_engine
12
- from sqlalchemy.orm import declarative_base, Session, Mapped, mapped_column
13
-
14
-
15
- Base = declarative_base()
16
-
17
-
18
- if platform.system().lower() == "windows":
19
- default_sqlite_db_path = os.path.join(
20
- os.path.dirname(__file__), "../database.db"
21
- ).replace("\\", "/")
22
- else:
23
- default_sqlite_db_path = os.path.join(os.path.dirname(__file__), "../database.db")
24
-
25
- connection_uri = os.environ.get(
26
- "FOOOCUS_DB_CONF", f"sqlite:///{default_sqlite_db_path}"
27
- )
28
-
29
-
30
- class GenerateRecord(Base):
31
- """
32
- GenerateRecord
33
-
34
- __tablename__ = 'generate_record'
35
- """
36
-
37
- __tablename__ = "generate_record"
38
-
39
- task_id: Mapped[str] = mapped_column(VARCHAR(255), nullable=False, primary_key=True)
40
- task_type: Mapped[str] = mapped_column(Text, nullable=False)
41
- result_url: Mapped[str] = mapped_column(Text, nullable=True)
42
- finish_reason: Mapped[str] = mapped_column(Text, nullable=True)
43
- date_time: Mapped[int] = mapped_column(Integer, nullable=False)
44
-
45
- prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
46
- negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
47
- style_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
48
- performance_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
49
- aspect_ratios_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
50
- base_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
51
- refiner_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
52
- refiner_switch: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
53
- loras: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
54
- image_number: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
55
- image_seed: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
56
- sharpness: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
57
- guidance_scale: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
58
- advanced_params: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
59
-
60
- input_image: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
61
- input_mask: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
62
- image_prompts: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
63
- inpaint_additional_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
64
- outpaint_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
65
- outpaint_distance_left: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
66
- outpaint_distance_right: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
67
- outpaint_distance_top: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
68
- outpaint_distance_bottom: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
69
- uov_method: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
70
- upscale_value: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
71
-
72
- webhook_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
73
- require_base64: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
74
- async_process: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
75
-
76
- def __repr__(self) -> str:
77
- return f"GenerateRecord(task_id={self.task_id!r}, task_type={self.task_type!r}, \
78
- result_url={self.result_url!r}, finish_reason={self.finish_reason!r}, date_time={self.date_time!r}, \
79
- prompt={self.prompt!r}, negative_prompt={self.negative_prompt!r}, style_selections={self.style_selections!r}, performance_selection={self.performance_selection!r}, \
80
- aspect_ratios_selection={self.aspect_ratios_selection!r}, base_model_name={self.base_model_name!r}, \
81
- refiner_model_name={self.refiner_model_name!r}, refiner_switch={self.refiner_switch!r}, loras={self.loras!r}, \
82
- image_number={self.image_number!r}, image_seed={self.image_seed!r}, sharpness={self.sharpness!r}, \
83
- guidance_scale={self.guidance_scale!r}, advanced_params={self.advanced_params!r}, input_image={self.input_image!r}, \
84
- input_mask={self.input_mask!r}, image_prompts={self.image_prompts!r}, inpaint_additional_prompt={self.inpaint_additional_prompt!r}, \
85
- outpaint_selections={self.outpaint_selections!r}, outpaint_distance_left={self.outpaint_distance_left!r}, outpaint_distance_right={self.outpaint_distance_right!r}, \
86
- outpaint_distance_top={self.outpaint_distance_top!r}, outpaint_distance_bottom={self.outpaint_distance_bottom!r}, uov_method={self.uov_method!r}, \
87
- upscale_value={self.upscale_value!r}, webhook_url={self.webhook_url!r}, require_base64={self.require_base64!r}, \
88
- async_process={self.async_process!r})"
89
-
90
-
91
- engine = create_engine(connection_uri)
92
-
93
- session = Session(engine)
94
- Base.metadata.create_all(engine, checkfirst=True)
95
- session.close()
96
-
97
-
98
- def convert_to_dict_list(obj_list: list[object]) -> list[dict]:
99
- """
100
- Convert a list of objects to a list of dictionaries.
101
- Args:
102
- obj_list:
103
-
104
- Returns:
105
- dict_list:
106
- """
107
- dict_list = []
108
- for obj in obj_list:
109
- # Convert object attributes into dict key/value pairs
110
- dict_obj = {}
111
- for attr, value in vars(obj).items():
112
- if (
113
- not callable(value)
114
- and not attr.startswith("__")
115
- and not attr.startswith("_")
116
- ):
117
- dict_obj[attr] = value
118
- task_info = {
119
- "task_id": obj.task_id,
120
- "task_type": obj.task_type,
121
- "result_url": obj.result_url,
122
- "finish_reason": obj.finish_reason,
123
- "date_time": datetime.fromtimestamp(obj.date_time).strftime(
124
- "%Y-%m-%d %H:%M:%S"
125
- ),
126
- }
127
- del dict_obj["task_id"]
128
- del dict_obj["task_type"]
129
- del dict_obj["result_url"]
130
- del dict_obj["finish_reason"]
131
- del dict_obj["date_time"]
132
- dict_list.append({"params": dict_obj, "task_info": task_info})
133
- return dict_list
134
-
135
-
136
- class MySQLAlchemy:
137
- """
138
- MySQLAlchemy, a toolkit for managing SQLAlchemy connections and sessions.
139
-
140
- :param uri: SQLAlchemy connection URI
141
- """
142
-
143
- def __init__(self, uri: str):
144
- # 'mysql+pymysql://{username}:{password}@{host}:{port}/{database}'
145
- self.engine = create_engine(uri)
146
- self.session = Session(self.engine)
147
-
148
- def store_history(self, record: dict) -> None:
149
- """
150
- Store history to database
151
- :param record:
152
- :return:
153
- """
154
- self.session.add_all([GenerateRecord(**record)])
155
- self.session.commit()
156
-
157
- def get_history(
158
- self,
159
- task_id: str = None,
160
- page: int = 0,
161
- page_size: int = 20,
162
- order_by: str = "date_time",
163
- ) -> list:
164
- """
165
- Get history from database
166
- :param task_id:
167
- :param page:
168
- :param page_size:
169
- :param order_by:
170
- :return:
171
- """
172
- if task_id is not None:
173
- res = (
174
- self.session.query(GenerateRecord)
175
- .filter(GenerateRecord.task_id == task_id)
176
- .all()
177
- )
178
- if len(res) == 0:
179
- return []
180
- return convert_to_dict_list(res)
181
-
182
- res = (
183
- self.session.query(GenerateRecord)
184
- .order_by(getattr(GenerateRecord, order_by).desc())
185
- .offset(page * page_size)
186
- .limit(page_size)
187
- .all()
188
- )
189
- if len(res) == 0:
190
- return []
191
- return convert_to_dict_list(res)
192
-
193
-
194
- db = MySQLAlchemy(uri=connection_uri)
195
-
196
-
197
- def req_to_dict(req: dict) -> dict:
198
- """
199
- Convert request to dictionary
200
- Args:
201
- req:
202
-
203
- Returns:
204
-
205
- """
206
- req["loras"] = [{"model_name": lora[0], "weight": lora[1]} for lora in req["loras"]]
207
- # req["advanced_params"] = dict(zip(adv_params_keys, req["advanced_params"]))
208
- req["image_prompts"] = [
209
- {"cn_img": "", "cn_stop": image[1], "cn_weight": image[2], "cn_type": image[3]}
210
- for image in req["image_prompts"]
211
- ]
212
- del req["inpaint_input_image"]
213
- del req["uov_input_image"]
214
- return req
215
-
216
-
217
- def add_history(
218
- params: dict, task_type: str, task_id: str, result_url: str, finish_reason: str
219
- ) -> None:
220
- """
221
- Store history to database
222
- Args:
223
- params:
224
- task_type:
225
- task_id:
226
- result_url:
227
- finish_reason:
228
-
229
- Returns:
230
-
231
- """
232
- adv = copy.deepcopy(params["advanced_params"])
233
- params["advanced_params"] = adv.__dict__
234
- params["date_time"] = int(time.time())
235
- params["task_type"] = task_type
236
- params["task_id"] = task_id
237
- params["result_url"] = result_url
238
- params["finish_reason"] = finish_reason
239
-
240
- del params["inpaint_input_image"]
241
- del params["uov_input_image"]
242
- del params["save_extension"]
243
- del params["save_meta"]
244
- del params["save_name"]
245
- del params["meta_scheme"]
246
-
247
- db.store_history(params)
248
-
249
-
250
- def query_history(
251
- task_id: str = None,
252
- page: int = 0,
253
- page_size: int = 20,
254
- order_by: str = "date_time"
255
- ) -> list:
256
- """
257
- Query history from database
258
- Args:
259
- task_id:
260
- page:
261
- page_size:
262
- order_by:
263
-
264
- Returns:
265
-
266
- """
267
- return db.get_history(
268
- task_id=task_id, page=page, page_size=page_size, order_by=order_by
269
- )
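
A hedged sketch of pointing the removed persistence layer at MySQL instead of the default SQLite file. The FOOOCUS_DB_CONF variable and URI shape come from the code above; the credentials are placeholders. Note that connection_uri is read at import time, so the variable must be set before the module is imported:

    import os

    os.environ["FOOOCUS_DB_CONF"] = "mysql+pymysql://user:pass@localhost:3306/fooocus"

    from fooocusapi.sql_client import query_history  # removed in this commit

    for row in query_history(page=0, page_size=10):  # newest first by date_time
        print(row["task_info"]["task_id"], row["task_info"]["result_url"])
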
fooocusapi/task_queue.py DELETED
@@ -1,323 +0,0 @@
1
- """
2
- Task queue management
3
-
4
- This module provides classes and functions for managing the task queue.
5
-
6
- Classes:
7
- QueueTask: A class representing a task in the queue.
8
- TaskQueue: A class for managing the task queue.
9
- """
10
- import uuid
11
- import time
12
- from typing import List, Tuple
13
- import numpy as np
14
- import requests
15
-
16
- from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url
17
- from fooocusapi.utils.img_utils import narray_to_base64img
18
- from fooocusapi.utils.logger import logger
19
-
20
- from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason
21
- from fooocusapi.parameters import ImageGenerationParams
22
- from fooocusapi.models.common.task import TaskType
23
-
24
-
25
- class QueueTask:
26
- """
27
- A class representing a task in the queue.
28
-
29
- Attributes:
30
- job_id (str): The unique identifier for the task, generated by uuid.
31
- task_type (TaskType): The type of task.
32
- is_finished (bool): Indicates whether the task has been completed.
33
- finish_progress (int): The progress of the task completion.
34
- in_queue_mills (int): The time the task was added to the queue, in milliseconds.
35
- start_mills (int): The time the task started, in milliseconds.
36
- finish_mills (int): The time the task finished, in milliseconds.
37
- finish_with_error (bool): Indicates whether the task finished with an error.
38
- task_status (str): The status of the task.
39
- task_step_preview (str): A list of step previews for the task.
40
- task_result (List[ImageGenerationResult]): The result of the task.
41
- error_message (str): The error message, if any.
42
- webhook_url (str): The webhook URL, if any.
43
- """
44
-
45
- job_id: str
46
- task_type: TaskType
47
- req_param: ImageGenerationParams
48
- is_finished: bool = False
49
- finish_progress: int = 0
50
- in_queue_mills: int
51
- start_mills: int = 0
52
- finish_mills: int = 0
53
- finish_with_error: bool = False
54
- task_status: str | None = None
55
- task_step_preview: str | None = None
56
- task_result: List[ImageGenerationResult] = None
57
- error_message: str | None = None
58
- webhook_url: str | None = None # attribute for individual webhook_url
59
-
60
- def __init__(
61
- self,
62
- job_id: str,
63
- task_type: TaskType,
64
- req_param: ImageGenerationParams,
65
- webhook_url: str | None = None,
66
- ):
67
- self.job_id = job_id
68
- self.task_type = task_type
69
- self.req_param = req_param
70
- self.in_queue_mills = int(round(time.time() * 1000))
71
- self.webhook_url = webhook_url
72
-
73
- def set_progress(self, progress: int, status: str | None):
74
- """
75
- Set progress and status
76
- Arguments:
77
- progress {int} -- progress
78
- status {str} -- status
79
- """
80
- progress = min(progress, 100)
81
- self.finish_progress = progress
82
- self.task_status = status
83
-
84
- def set_step_preview(self, task_step_preview: str | None):
85
- """set step preview
86
- Set step preview
87
- Arguments:
88
- task_step_preview {str} -- step preview
89
- """
90
- self.task_step_preview = task_step_preview
91
-
92
- def set_result(
93
- self,
94
- task_result: List[ImageGenerationResult],
95
- finish_with_error: bool,
96
- error_message: str | None = None,
97
- ):
98
- """set result
99
- Set task result
100
- Arguments:
101
- task_result {List[ImageGenerationResult]} -- task result
102
- finish_with_error {bool} -- finish with error
103
- error_message {str} -- error message
104
- """
105
- if not finish_with_error:
106
- self.finish_progress = 100
107
- self.task_status = "Finished"
108
- self.task_result = task_result
109
- self.finish_with_error = finish_with_error
110
- self.error_message = error_message
111
-
112
- def __str__(self) -> str:
113
- return f"QueueTask(job_id={self.job_id}, task_type={self.task_type},\
114
- is_finished={self.is_finished}, finished_progress={self.finish_progress}, \
115
- in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, \
116
- finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, \
117
- error_message={self.error_message}, task_status={self.task_status}, \
118
- task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
119
-
120
-
121
- class TaskQueue:
122
- """
123
- TaskQueue is a queue of tasks that are waiting to be processed.
124
-
125
- Attributes:
126
- queue: List[QueueTask]
127
- history: List[QueueTask]
128
- last_job_id: str
129
- webhook_url: str
130
- persistent: bool
131
- """
132
-
133
- queue: List[QueueTask] = []
134
- history: List[QueueTask] = []
135
- last_job_id: str = None
136
- webhook_url: str | None = None
137
- persistent: bool = False
138
-
139
- def __init__(
140
- self,
141
- queue_size: int,
142
- history_size: int,
143
- webhook_url: str | None = None,
144
- persistent: bool | None = False,
145
- ):
146
- self.queue_size = queue_size
147
- self.history_size = history_size
148
- self.webhook_url = webhook_url
149
- self.persistent = False if persistent is None else persistent
150
-
151
- def add_task(
152
- self,
153
- task_type: TaskType,
154
- req_param: ImageGenerationParams,
155
- webhook_url: str | None = None,
156
- ) -> QueueTask | None:
157
- """
158
- Create and add task to queue
159
- :param task_type: task type
160
- :param req_param: request parameters
161
- :param webhook_url: webhook url
162
- :returns: The created QueueTask, or None if the queue size limit is reached
163
- """
164
- if len(self.queue) >= self.queue_size:
165
- return None
166
-
167
- if isinstance(req_param, dict):
168
- req_param = ImageGenerationParams(**req_param)
169
-
170
- job_id = str(uuid.uuid4())
171
- task = QueueTask(
172
- job_id=job_id,
173
- task_type=task_type,
174
- req_param=req_param,
175
- webhook_url=webhook_url,
176
- )
177
- self.queue.append(task)
178
- self.last_job_id = job_id
179
- return task
180
-
181
- def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
182
- """
183
- Get task by job_id
184
- :param job_id: job id
185
- :param include_history: whether to include history tasks
186
- :returns: The task with the given job_id, or None if not found
187
- """
188
- for task in self.queue:
189
- if task.job_id == job_id:
190
- return task
191
-
192
- if include_history:
193
- for task in self.history:
194
- if task.job_id == job_id:
195
- return task
196
-
197
- return None
198
-
199
- def is_task_ready_to_start(self, job_id: str) -> bool:
200
- """
201
- Check if the task is ready to start
202
- :param job_id: job id
203
- :returns: True if the task is ready to start, False otherwise
204
- """
205
- task = self.get_task(job_id)
206
- if task is None:
207
- return False
208
-
209
- return self.queue[0].job_id == job_id
210
-
211
- def is_task_finished(self, job_id: str) -> bool:
212
- """
213
- Check if the task is finished
214
- :param job_id: job id
215
- :returns: True if the task is finished, False otherwise
216
- """
217
- task = self.get_task(job_id, True)
218
- if task is None:
219
- return False
220
-
221
- return task.is_finished
222
-
223
- def start_task(self, job_id: str):
224
- """
225
- Start task by job_id
226
- :param job_id: job id
227
- """
228
- task = self.get_task(job_id)
229
- if task is not None:
230
- task.start_mills = int(round(time.time() * 1000))
231
-
232
- def finish_task(self, job_id: str):
233
- """
234
- Finish task by job_id
235
- :param job_id: job id
236
- """
237
- task = self.get_task(job_id)
238
- if task is not None:
239
- task.is_finished = True
240
- task.finish_mills = int(round(time.time() * 1000))
241
-
242
- # Use the task's webhook_url if available, else use the default
243
- webhook_url = task.webhook_url or self.webhook_url
244
-
245
- data = {"job_id": task.job_id, "job_result": []}
246
-
247
- if isinstance(task.task_result, List):
248
- for item in task.task_result:
249
- data["job_result"].append(
250
- {
251
- "url": get_file_serve_url(item.im) if item.im else None,
252
- "seed": item.seed if item.seed else "-1",
253
- }
254
- )
255
-
256
- # Send webhook
257
- if task.is_finished and webhook_url:
258
- try:
259
- res = requests.post(webhook_url, json=data, timeout=15)
260
- print(f"Call webhook response status: {res.status_code}")
261
- except Exception as e:
262
- print("Call webhook error:", e)
263
-
264
- # Move task to history
265
- self.queue.remove(task)
266
- self.history.append(task)
267
-
268
- # save history to database
269
- if self.persistent:
270
- from fooocusapi.sql_client import add_history
271
-
272
- add_history(
273
- params=task.req_param.to_dict(),
274
- task_type=task.task_type.value,
275
- task_id=task.job_id,
276
- result_url=",".join([job["url"] for job in data["job_result"]]),
277
- finish_reason=task.task_result[0].finish_reason.value,
278
- )
279
-
280
- # Clean history
281
- if len(self.history) > self.history_size != 0:
282
- removed_task = self.history.pop(0)
283
- if isinstance(removed_task.task_result, List):
284
- for item in removed_task.task_result:
285
- if (
286
- isinstance(item, ImageGenerationResult)
287
- and item.finish_reason == GenerationFinishReason.success
288
- and item.im is not None
289
- ):
290
- delete_output_file(item.im)
291
- logger.std_info(
292
- f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
293
- )
294
-
295
-
296
- class TaskOutputs:
297
- """
298
- TaskOutputs is a container for task outputs
299
- """
300
-
301
- outputs: List[any]
302
-
303
- def __init__(self, task: QueueTask):
304
- self.task = task
- self.outputs = []  # per-instance list; a class-level list would be shared across tasks
305
-
306
- def append(self, args: List[any]):
307
- """
308
- Append output to task outputs list
309
- :param args: output arguments
310
- """
311
- self.outputs.append(args)
312
- if len(args) >= 2:
313
- if (
314
- args[0] == "preview"
315
- and isinstance(args[1], Tuple)
316
- and len(args[1]) >= 2
317
- ):
318
- number = args[1][0]
319
- text = args[1][1]
320
- self.task.set_progress(number, text)
321
- if len(args[1]) >= 3 and isinstance(args[1][2], np.ndarray):
322
- base64_preview_img = narray_to_base64img(args[1][2])
323
- self.task.set_step_preview(base64_preview_img)
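
Taken together, these methods implement a simple produce/consume lifecycle. An illustrative sketch of how the worker drove it (not runnable here, since the module is removed in this commit; params stands in for an ImageGenerationParams instance):

    queue = TaskQueue(queue_size=100, history_size=6, webhook_url=None, persistent=False)
    task = queue.add_task(TaskType.text_2_img, params)
    if task is None:
        print("queue full")  # callers map this to GenerationFinishReason.queue_is_full
    elif queue.is_task_ready_to_start(task.job_id):
        queue.start_task(task.job_id)
        # ... generation runs, calling task.set_progress() / task.set_result() ...
        queue.finish_task(task.job_id)  # fires the webhook, moves the task to history
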
fooocusapi/utils/api_utils.py DELETED
@@ -1,291 +0,0 @@
1
- """some utils for api"""
2
- from typing import List
3
-
4
- from fastapi import Response
5
- from fastapi.security import APIKeyHeader
6
- from fastapi import HTTPException, Security
7
-
8
- from modules import flags
9
- from modules import config
10
- from modules.sdxl_styles import legal_style_names
11
-
12
- from fooocusapi.args import args
13
- from fooocusapi.utils.img_utils import read_input_image
14
- from fooocusapi.utils.file_utils import (
15
- get_file_serve_url,
16
- output_file_to_base64img,
17
- output_file_to_bytesimg
18
- )
19
- from fooocusapi.utils.logger import logger
20
- from fooocusapi.models.common.requests import (
21
- CommonRequest as Text2ImgRequest
22
- )
23
- from fooocusapi.models.common.response import (
24
- AsyncJobResponse,
25
- AsyncJobStage,
26
- GeneratedImageResult
27
- )
28
- from fooocusapi.models.requests_v1 import (
29
- ImgInpaintOrOutpaintRequest,
30
- ImgPromptRequest,
31
- ImgUpscaleOrVaryRequest
32
- )
33
- from fooocusapi.models.requests_v2 import (
34
- Text2ImgRequestWithPrompt,
35
- ImgInpaintOrOutpaintRequestJson,
36
- ImgUpscaleOrVaryRequestJson,
37
- ImgPromptRequestJson
38
- )
39
- from fooocusapi.models.common.task import (
40
- ImageGenerationResult,
41
- GenerationFinishReason
42
- )
43
- from fooocusapi.configs.default import (
44
- default_inpaint_engine_version,
45
- default_sampler,
46
- default_scheduler,
47
- default_base_model_name,
48
- default_refiner_model_name
49
- )
50
-
51
- from fooocusapi.parameters import ImageGenerationParams
52
- from fooocusapi.task_queue import QueueTask
53
-
54
-
55
- api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
56
-
57
-
58
- def api_key_auth(apikey: str = Security(api_key_header)):
59
- """
60
- Check if the API key is valid, API key is not required if no API key is set
61
- Args:
62
- apikey: API key
63
- returns:
64
- None if API key is not set, otherwise raise HTTPException
65
- """
66
- if args.apikey is None:
67
- return # Skip API key check if no API key is set
68
- if apikey != args.apikey:
69
- raise HTTPException(status_code=403, detail="Forbidden")
70
-
71
-
72
- def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
73
- """
74
- Convert Request to ImageGenerationParams
75
- Args:
76
- req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
77
- returns:
78
- ImageGenerationParams
79
- """
80
- config.update_files()
81
- if req.base_model_name is not None:
82
- if req.base_model_name not in config.model_filenames:
83
- logger.std_warn(f"[Warning] Wrong base_model_name input: {req.base_model_name}, using default")
84
- req.base_model_name = default_base_model_name
85
-
86
- if req.refiner_model_name is not None and req.refiner_model_name != 'None':
87
- if req.refiner_model_name not in config.model_filenames:
88
- logger.std_warn(f"[Warning] Wrong refiner_model_name input: {req.refiner_model_name}, using default")
89
- req.refiner_model_name = default_refiner_model_name
90
-
91
- for lora in req.loras:
92
- if lora.model_name != 'None' and lora.model_name not in config.lora_filenames:
93
- logger.std_warn(f"[Warning] Wrong lora model_name input: {lora.model_name}, using 'None'")
94
- lora.model_name = 'None'
95
-
96
- prompt = req.prompt
97
- negative_prompt = req.negative_prompt
98
- style_selections = [
99
- s for s in req.style_selections if s in legal_style_names]
100
- performance_selection = req.performance_selection.value
101
- aspect_ratios_selection = req.aspect_ratios_selection
102
- image_number = req.image_number
103
- image_seed = None if req.image_seed == -1 else req.image_seed
104
- sharpness = req.sharpness
105
- guidance_scale = req.guidance_scale
106
- base_model_name = req.base_model_name
107
- refiner_model_name = req.refiner_model_name
108
- refiner_switch = req.refiner_switch
109
- loras = [(lora.model_name, lora.weight) for lora in req.loras]
110
- uov_input_image = None
111
- if not isinstance(req, Text2ImgRequestWithPrompt):
112
- if isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)):
113
- uov_input_image = read_input_image(req.input_image)
114
- uov_method = flags.disabled if not isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)) else req.uov_method.value
115
- upscale_value = None if not isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)) else req.upscale_value
116
- outpaint_selections = [] if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else [
117
- s.value for s in req.outpaint_selections]
118
- outpaint_distance_left = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_left
119
- outpaint_distance_right = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_right
120
- outpaint_distance_top = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_top
121
- outpaint_distance_bottom = None if not isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) else req.outpaint_distance_bottom
122
-
123
- if refiner_model_name == '':
124
- refiner_model_name = 'None'
125
-
126
- inpaint_input_image = None
127
- inpaint_additional_prompt = None
128
- if isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)) and req.input_image is not None:
129
- inpaint_additional_prompt = req.inpaint_additional_prompt
130
- input_image = read_input_image(req.input_image)
131
- input_mask = None
132
- if req.input_mask is not None:
133
- input_mask = read_input_image(req.input_mask)
134
- inpaint_input_image = {
135
- 'image': input_image,
136
- 'mask': input_mask
137
- }
138
-
139
- image_prompts = []
140
- if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson, ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
141
- # Auto set mixing_image_prompt_and_inpaint to True
142
- if len(req.image_prompts) > 0 and uov_input_image is not None:
143
- print("[INFO] Mixing image prompt and vary upscale is set to True")
144
- req.advanced_params.mixing_image_prompt_and_vary_upscale = True
145
- elif len(req.image_prompts) > 0 and not isinstance(req, Text2ImgRequestWithPrompt) and req.input_image is not None:
146
- print("[INFO] Mixing image prompt and inpaint is set to True")
147
- req.advanced_params.mixing_image_prompt_and_inpaint = True
148
-
149
- for img_prompt in req.image_prompts:
150
- if img_prompt.cn_img is not None:
151
- cn_img = read_input_image(img_prompt.cn_img)
152
- if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
153
- img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
154
- if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
155
- img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
156
- image_prompts.append(
157
- (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))
158
-
159
- advanced_params = None
160
- if req.advanced_params is not None:
161
- adp = req.advanced_params
162
-
163
- if adp.refiner_swap_method not in ['joint', 'separate', 'vae']:
164
- print(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
165
- adp.refiner_swap_method = 'joint'
166
-
167
- if adp.sampler_name not in flags.sampler_list:
168
- print(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
169
- adp.sampler_name = default_sampler
170
-
171
- if adp.scheduler_name not in flags.scheduler_list:
172
- print(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
173
- adp.scheduler_name = default_scheduler
174
-
175
- if adp.inpaint_engine not in flags.inpaint_engine_versions:
176
- print(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
177
- adp.inpaint_engine = default_inpaint_engine_version
178
-
179
- advanced_params = adp
180
-
181
- return ImageGenerationParams(
182
- prompt=prompt,
183
- negative_prompt=negative_prompt,
184
- style_selections=style_selections,
185
- performance_selection=performance_selection,
186
- aspect_ratios_selection=aspect_ratios_selection,
187
- image_number=image_number,
188
- image_seed=image_seed,
189
- sharpness=sharpness,
190
- guidance_scale=guidance_scale,
191
- base_model_name=base_model_name,
192
- refiner_model_name=refiner_model_name,
193
- refiner_switch=refiner_switch,
194
- loras=loras,
195
- uov_input_image=uov_input_image,
196
- uov_method=uov_method,
197
- upscale_value=upscale_value,
198
- outpaint_selections=outpaint_selections,
199
- outpaint_distance_left=outpaint_distance_left,
200
- outpaint_distance_right=outpaint_distance_right,
201
- outpaint_distance_top=outpaint_distance_top,
202
- outpaint_distance_bottom=outpaint_distance_bottom,
203
- inpaint_input_image=inpaint_input_image,
204
- inpaint_additional_prompt=inpaint_additional_prompt,
205
- image_prompts=image_prompts,
206
- advanced_params=advanced_params,
207
- save_meta=req.save_meta,
208
- meta_scheme=req.meta_scheme,
209
- save_name=req.save_name,
210
- save_extension=req.save_extension,
211
- require_base64=req.require_base64,
212
- )
213
-
214
-
215
- def generate_async_output(
216
- task: QueueTask,
217
- require_step_preview: bool = False) -> AsyncJobResponse:
218
- """
219
- Generate output for async job
220
- Arguments:
221
- task: QueueTask
222
- require_step_preview: bool
223
- Returns:
224
- AsyncJobResponse
225
- """
226
- job_stage = AsyncJobStage.running
227
- job_result = None
228
-
229
- if task.start_mills == 0:
230
- job_stage = AsyncJobStage.waiting
231
-
232
- if task.is_finished:
233
- if task.finish_with_error:
234
- job_stage = AsyncJobStage.error
235
- elif task.task_result is not None:
236
- job_stage = AsyncJobStage.success
237
- job_result = generate_image_result_output(task.task_result, task.req_param.require_base64)
238
-
239
- result = AsyncJobResponse(
240
- job_id=task.job_id,
241
- job_type=task.task_type,
242
- job_stage=job_stage,
243
- job_progress=task.finish_progress,
244
- job_status=task.task_status,
245
- job_step_preview=task.task_step_preview if require_step_preview else None,
246
- job_result=job_result)
247
- return result
248
-
249
-
250
- def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
251
- """
252
- Generate streaming output for image generation results.
253
- Args:
254
- results (List[ImageGenerationResult]): List of image generation results.
255
- Returns:
256
- Response: Streaming response object, bytes image.
257
- """
258
- if len(results) == 0:
259
- return Response(status_code=500)
260
- result = results[0]
261
- if result.finish_reason == GenerationFinishReason.queue_is_full:
262
- return Response(status_code=409, content=result.finish_reason.value)
263
- if result.finish_reason == GenerationFinishReason.user_cancel:
264
- return Response(status_code=400, content=result.finish_reason.value)
265
- if result.finish_reason == GenerationFinishReason.error:
266
- return Response(status_code=500, content=result.finish_reason.value)
267
-
268
- img_bytes = output_file_to_bytesimg(results[0].im)
269
- return Response(img_bytes, media_type='image/png')
270
-
271
-
272
- def generate_image_result_output(
273
- results: List[ImageGenerationResult],
274
- require_base64: bool) -> List[GeneratedImageResult]:
275
- """
276
- Generate image result output
277
- Arguments:
278
- results: List[ImageGenerationResult]
279
- require_base64: bool
280
- Returns:
281
- List[GeneratedImageResult]
282
- """
283
- results = [
284
- GeneratedImageResult(
285
- base64=output_file_to_base64img(item.im) if require_base64 else None,
286
- url=get_file_serve_url(item.im),
287
- seed=str(item.seed),
288
- finish_reason=item.finish_reason
289
- ) for item in results
290
- ]
291
- return results
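
The api_key_auth guard above generalizes beyond this project. A minimal self-contained sketch of the same pattern (the key value and route are assumptions; the removed version also skipped the check entirely when args.apikey was unset):

    from fastapi import Depends, FastAPI, HTTPException, Security
    from fastapi.security import APIKeyHeader

    api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
    EXPECTED_KEY = "my-key"  # placeholder; the real key came from args.apikey

    def api_key_auth(apikey: str = Security(api_key_header)):
        if apikey != EXPECTED_KEY:
            raise HTTPException(status_code=403, detail="Forbidden")

    app = FastAPI()

    @app.get("/secure", dependencies=[Depends(api_key_auth)])
    def secure():
        return {"ok": True}
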
fooocusapi/utils/call_worker.py DELETED
@@ -1,97 +0,0 @@
1
- """function for call generate worker"""
2
- from typing import List
3
- from fastapi import Response
4
-
5
- from fooocusapi.models.common.requests import (
6
- CommonRequest as Text2ImgRequest
7
- )
8
- from fooocusapi.models.common.response import (
9
- AsyncJobResponse,
10
- GeneratedImageResult
11
- )
12
- from fooocusapi.models.common.task import (
13
- GenerationFinishReason,
14
- ImageGenerationResult,
15
- AsyncJobStage,
16
- TaskType
17
- )
18
- from fooocusapi.utils.api_utils import (
19
- req_to_params,
20
- generate_async_output,
21
- generate_streaming_output,
22
- generate_image_result_output
23
- )
24
- from fooocusapi.models.requests_v1 import (
25
- ImgUpscaleOrVaryRequest,
26
- ImgPromptRequest,
27
- ImgInpaintOrOutpaintRequest
28
- )
29
- from fooocusapi.models.requests_v2 import (
30
- ImgInpaintOrOutpaintRequestJson,
31
- ImgPromptRequestJson,
32
- ImgUpscaleOrVaryRequestJson
33
- )
34
- from fooocusapi.worker import worker_queue, blocking_get_task_result
35
-
36
-
37
- def get_task_type(req: Text2ImgRequest) -> TaskType:
38
- """return task type"""
39
- if isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson)):
40
- return TaskType.img_uov
41
- if isinstance(req, (ImgPromptRequest, ImgPromptRequestJson)):
42
- return TaskType.img_prompt
43
- if isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson)):
44
- return TaskType.img_inpaint_outpaint
45
- return TaskType.text_2_img
46
-
47
-
48
- def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
49
- """call generation worker"""
50
- if accept == 'image/png':
51
- streaming_output = True
52
- # image_number auto set to 1 in streaming mode
53
- req.image_number = 1
54
- else:
55
- streaming_output = False
56
-
57
- task_type = get_task_type(req)
58
- params = req_to_params(req)
59
- async_task = worker_queue.add_task(task_type, params, req.webhook_url)
60
-
61
- if async_task is None:
62
- # add to worker queue failed
63
- failure_results = [
64
- ImageGenerationResult(
65
- im=None,
66
- seed='',
67
- finish_reason=GenerationFinishReason.queue_is_full
68
- )]
69
-
70
- if streaming_output:
71
- return generate_streaming_output(failure_results)
72
- if req.async_process:
73
- return AsyncJobResponse(
74
- job_id='',
75
- job_type=get_task_type(req),
76
- job_stage=AsyncJobStage.error,
77
- job_progress=0,
78
- job_status=None,
79
- job_step_preview=None,
80
- job_result=[GeneratedImageResult(
81
- base64=None,
82
- url=None,
83
- seed='',
84
- finish_reason=GenerationFinishReason.queue_is_full
85
- )])
86
- return generate_image_result_output(failure_results, False)
87
-
88
- if req.async_process:
89
- # return async response directly
90
- return generate_async_output(async_task)
91
-
92
- # blocking get generation result
93
- results = blocking_get_task_result(async_task.job_id)
94
-
95
- if streaming_output:
96
- return generate_streaming_output(results)
97
- return generate_image_result_output(results, req.require_base64)
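
A hedged example of the streaming path: requesting image/png makes call_worker coerce image_number to 1 and return raw PNG bytes instead of JSON (host and port are assumptions):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8888/v1/generation/text-to-image",
        headers={"Accept": "image/png"},
        json={"prompt": "an isometric castle", "image_number": 4},  # coerced to 1
        timeout=300,
    )
    with open("result.png", "wb") as f:
        f.write(resp.content)
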
fooocusapi/utils/file_utils.py DELETED
@@ -1,143 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- """ File utils
4
-
5
- Used for managing generated files
6
-
7
- @file: file_utils.py
8
- @author: Konie
9
- @update: 2024-03-22
10
- """
11
- import base64
12
- import datetime
13
- from io import BytesIO
14
- import os
15
- import json
16
- from pathlib import Path
17
- import numpy as np
18
- from PIL import Image
19
- from PIL.PngImagePlugin import PngInfo
20
-
21
- from fooocusapi.utils.logger import logger
22
-
23
-
24
- output_dir = os.path.abspath(os.path.join(
25
- os.path.dirname(__file__), '../..', 'outputs', 'files'))
26
- os.makedirs(output_dir, exist_ok=True)
27
-
28
- STATIC_SERVER_BASE = 'http://127.0.0.1:8888/files/'
29
-
30
-
31
- def save_output_file(
32
- img: np.ndarray,
33
- image_meta: dict = None,
34
- image_name: str = '',
35
- extension: str = 'png') -> str:
36
- """
37
- Save np image to file
38
- Args:
39
- img: np.ndarray image to save
40
- image_meta: dict of image metadata
41
- image_name: str of image name
42
- extension: str of image extension
43
- Returns:
44
- str of file name
45
- """
46
- current_time = datetime.datetime.now()
47
- date_string = current_time.strftime("%Y-%m-%d")
48
-
49
- filename = os.path.join(date_string, image_name + '.' + extension)
50
- file_path = os.path.join(output_dir, filename)
51
-
52
- if extension not in ['png', 'jpg', 'webp']:
53
- extension = 'png'
54
- image_format = Image.registered_extensions()['.'+extension]
55
-
56
- if image_meta is None:
57
- image_meta = {}
58
-
59
- meta = None
60
- if extension == 'png'and image_meta != {}:
61
- meta = PngInfo()
62
- meta.add_text("parameters", json.dumps(image_meta))
63
- meta.add_text("fooocus_scheme", image_meta['metadata_scheme'])
64
-
65
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
66
- Image.fromarray(img).save(
67
- file_path,
68
- format=image_format,
69
- pnginfo=meta,
70
- optimize=True)
71
- return Path(filename).as_posix()
72
-
73
-
74
- def delete_output_file(filename: str):
75
- """
76
- Delete files specified in the output directory
77
- Args:
78
- filename: str of file name
79
- """
80
- file_path = os.path.join(output_dir, filename)
81
- if not os.path.exists(file_path) or not os.path.isfile(file_path):
82
- logger.std_warn(f'[Fooocus API] {filename} does not exist or is not a file')
- return
83
- try:
84
- os.remove(file_path)
85
- logger.std_info(f'[Fooocus API] Delete output file: {filename}')
86
- except OSError:
87
- logger.std_error(f'[Fooocus API] Delete output file failed: {filename}')
88
-
89
-
90
- def output_file_to_base64img(filename: str | None) -> str | None:
91
- """
92
- Convert an image file to a base64 string.
93
- Args:
94
- filename: str of file name
95
- return: str of base64 string
96
- """
97
- if filename is None:
98
- return None
99
- file_path = os.path.join(output_dir, filename)
100
- if not os.path.exists(file_path) or not os.path.isfile(file_path):
101
- return None
102
-
103
- ext = filename.split('.')[-1]
104
- if ext.lower() not in ['png', 'jpg', 'webp', 'jpeg']:
105
- ext = 'png'
106
- img = Image.open(file_path)
107
- output_buffer = BytesIO()
108
- img.save(output_buffer, format=ext.upper())
109
- byte_data = output_buffer.getvalue()
110
- base64_str = base64.b64encode(byte_data).decode('utf-8')
111
- return f"data:image/{ext};base64," + base64_str
112
-
113
-
114
- def output_file_to_bytesimg(filename: str | None) -> bytes | None:
115
- """
116
- Convert an image file to a bytes string.
117
- Args:
118
- filename: str of file name
119
- return: bytes of image data
120
- """
121
- if filename is None:
122
- return None
123
- file_path = os.path.join(output_dir, filename)
124
- if not os.path.exists(file_path) or not os.path.isfile(file_path):
125
- return None
126
-
127
- img = Image.open(file_path)
128
- output_buffer = BytesIO()
129
- img.save(output_buffer, format='PNG')
130
- byte_data = output_buffer.getvalue()
131
- return byte_data
132
-
133
-
134
- def get_file_serve_url(filename: str | None) -> str | None:
135
- """
136
- Get the static serve url of an image file.
137
- Args:
138
- filename: str of file name
139
- return: str of static serve url
140
- """
141
- if filename is None:
142
- return None
143
- return STATIC_SERVER_BASE + filename.replace('\\', '/')
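A minimal usage sketch of the module removed above, assuming it is still importable; the 2x2 array is a stand-in for a real render:

import numpy as np
from fooocusapi.utils.file_utils import (
    save_output_file, get_file_serve_url, output_file_to_base64img)

img = np.zeros((2, 2, 3), dtype=np.uint8)   # placeholder image
name = save_output_file(img, image_meta={}, image_name='demo', extension='png')
print(get_file_serve_url(name))             # http://127.0.0.1:8888/files/<date>/demo.png
print(output_file_to_base64img(name)[:30])  # data:image/png;base64,...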
 
fooocusapi/utils/img_utils.py DELETED
@@ -1,198 +0,0 @@
1
- """
2
- Image processing utils. Used to verify, convert, and store images.
3
-
4
- @file: img_utils.py
5
- @author: Konie
6
- @update: 2024-03-23
7
- """
8
- import base64
9
- from io import BytesIO
10
- from fastapi import UploadFile
11
- from PIL import Image
12
-
13
- import requests
14
- import numpy as np
15
-
16
-
17
- def upload2base64(image: UploadFile) -> str | None:
18
- """
19
- Convert UploadFile obj to base64 string
20
- Args:
21
- image (UploadFile): UploadFile obj
22
- Returns:
23
- str: base64 string, None for None
24
- """
25
- if image is None:
26
- return None
27
- image_bytes = image.file.read()
28
- image_base64 = base64.b64encode(image_bytes).decode("utf-8")
29
- return image_base64
30
-
31
-
32
- def narray_to_base64img(narray: np.ndarray) -> str | None:
33
- """
34
- Convert numpy array to base64 image string.
35
- Args:
36
- narray: numpy array
37
- Returns:
38
- base64 image string
39
- """
40
- if narray is None:
41
- return None
42
-
43
- img = Image.fromarray(narray)
44
- output_buffer = BytesIO()
45
- img.save(output_buffer, format='PNG')
46
- byte_data = output_buffer.getvalue()
47
- base64_str = base64.b64encode(byte_data).decode('utf-8')
48
- return base64_str
49
-
50
-
51
- def narray_to_bytesimg(narray) -> bytes | None:
52
- """
53
- Convert numpy array to bytes image.
54
- Args:
55
- narray: numpy array
56
- Returns:
57
- bytes image
58
- """
59
- if narray is None:
60
- return None
61
-
62
- img = Image.fromarray(narray)
63
- output_buffer = BytesIO()
64
- img.save(output_buffer, format='PNG')
65
- byte_data = output_buffer.getvalue()
66
- return byte_data
67
-
68
-
69
- def read_input_image(input_image: UploadFile | str | None) -> np.ndarray | None:
70
- """
71
- Read input image from UploadFile or base64 string.
72
- Args:
73
- input_image: UploadFile, or base64 image string, or None
74
- Returns:
75
- numpy array of image
76
- """
77
- if input_image is None or input_image == '':
78
- return None
79
- if isinstance(input_image, str):
80
- input_image_bytes = base64.b64decode(input_image)
81
- else:
82
- input_image_bytes = input_image.file.read()
83
- pil_image = Image.open(BytesIO(input_image_bytes))
84
- image = np.array(pil_image)
85
- return image
86
-
87
-
88
- def base64_to_stream(image: str) -> UploadFile | None:
89
- """
90
- Convert base64 image string to UploadFile.
91
- Args:
92
- image: base64 image string
93
- Returns:
94
- UploadFile or None
95
- """
96
- if image in ['', None, 'None', 'none', 'string', 'null']:
97
- return None
98
- if image.startswith('http'):
99
- return get_check_image(url=image)
100
- if image.startswith('data:image'):
101
- image = image.split(sep=',', maxsplit=1)[1]
102
- image_bytes = base64.b64decode(image)
103
- byte_stream = BytesIO()
104
- byte_stream.write(image_bytes)
105
- byte_stream.seek(0)
106
- return UploadFile(file=byte_stream)
107
-
108
-
109
- def get_check_image(url: str) -> UploadFile | None:
110
- """
111
- Get image from url and check if it's valid.
112
- Args:
113
- url: image url
114
- Returns:
115
- UploadFile or None
116
- """
117
- if url == '':
118
- return None
119
- headers = {
120
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
121
- }
122
- try:
123
- response = requests.get(url, headers=headers, timeout=10)
124
- binary_image = response.content
125
- except Exception:
126
- return None
127
- try:
128
- buffer = BytesIO(binary_image)
129
- Image.open(buffer) # This validates the image
130
- except Exception:
131
- return None
132
- byte_stream = BytesIO()
133
- byte_stream.write(binary_image)
134
- byte_stream.seek(0)
135
- return UploadFile(file=byte_stream)
136
-
137
-
138
- def bytes_image_to_io(binary_image: bytes) -> BytesIO | None:
139
- """
140
- Convert bytes image to BytesIO.
141
- Args:
142
- binary_image: bytes image
143
- Returns:
144
- BytesIO or None
145
- """
146
- try:
147
- buffer = BytesIO(binary_image)
148
- Image.open(buffer)
149
- except Exception:
150
- return None
151
- byte_stream = BytesIO()
152
- byte_stream.write(binary_image)
153
- byte_stream.seek(0)
154
- return byte_stream
155
-
156
-
157
- def bytes_to_base64img(byte_data: bytes) -> str | None:
158
- """
159
- Convert bytes image to base64 image string.
160
- Args:
161
- byte_data: bytes image
162
- Returns:
163
- base64 image string or None
164
- """
165
- if byte_data is None:
166
- return None
167
-
168
- base64_str = base64.b64encode(byte_data).decode('utf-8')
169
- return base64_str
170
-
171
-
172
- def base64_to_bytesimg(base64_str: str) -> bytes | None:
173
- """
174
- Convert base64 image string to bytes image.
175
- Args:
176
- base64_str: base64 image string
177
- Returns:
178
- bytes image or None
179
- """
180
- if base64_str == '':
181
- return None
182
- bytes_image = base64.b64decode(base64_str)
183
- return bytes_image
184
-
185
-
186
- def base64_to_narray(base64_str: str) -> np.ndarray | None:
187
- """
188
- Convert base64 image string to numpy array.
189
- Args:
190
- base64_str: base64 image string
191
- Returns:
192
- 1-D uint8 numpy array of the raw encoded bytes (not decoded pixels), or None
193
- """
194
- if base64_str == '':
195
- return None
196
- bytes_image = base64.b64decode(base64_str)
197
- image = np.frombuffer(bytes_image, np.uint8)
198
- return image
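A round-trip sketch for the helpers removed above (assumes the module is importable; the shape is arbitrary):

import numpy as np
from fooocusapi.utils.img_utils import narray_to_base64img, read_input_image

src = np.full((4, 4, 3), 255, dtype=np.uint8)   # tiny white RGB image
b64 = narray_to_base64img(src)                  # PNG bytes, base64-encoded
back = read_input_image(b64)                    # decoded back to an ndarray
assert back.shape == src.shape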
 
fooocusapi/utils/logger.py DELETED
@@ -1,132 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- """ A simply logger.
4
-
5
- This module is used to log the program.
6
-
7
- @file: logger.py
8
- @author: mrhan1993
9
- @update: 2024-03-22
10
- """
11
- import logging
12
- import os
13
- import sys
14
-
15
- try:
16
- from colorlog import ColoredFormatter
17
- except ImportError:
18
- from fooocusapi.utils.tools import run_pip
19
- run_pip(
20
- command="install colorlog",
21
- desc="Install colorlog for logger.",
22
- live=True
23
- )
24
- finally:
25
- from colorlog import ColoredFormatter
26
-
27
-
28
- own_path = os.path.dirname(os.path.abspath(__file__))
29
- log_dir = "logs"
30
- default_log_path = os.path.join(own_path, '../../', log_dir)
31
-
32
- std_formatter = ColoredFormatter(
33
- fmt="%(log_color)s[%(asctime)s] %(levelname)-8s%(reset)s %(blue)s%(message)s",
34
- datefmt='%Y-%m-%d %H:%M:%S',
35
- reset=True,
36
- log_colors={
37
- 'DEBUG': 'cyan',
38
- 'INFO': 'green',
39
- 'WARNING': 'yellow',
40
- 'ERROR': 'red',
41
- 'CRITICAL': 'red,bg_white',
42
- },
43
- secondary_log_colors={},
44
- style='%'
45
- )
46
-
47
- file_formatter = ColoredFormatter(
48
- fmt="[%(asctime)s] %(levelname)-8s%(reset)s %(message)s",
49
- datefmt='%Y-%m-%d %H:%M:%S',
50
- reset=True,
51
- no_color=True,
52
- style='%'
53
- )
54
-
55
-
56
- class ConfigLogger:
57
- """
58
- Configure logger.
59
- :param log_path: log file path, better absolute path
60
- :param std_format: stdout log format
61
- :param file_format: file log format
62
- """
63
- def __init__(self,
64
- log_path: str = default_log_path,
65
- std_format: ColoredFormatter = std_formatter,
66
- file_format: ColoredFormatter = file_formatter) -> None:
67
- self.log_path = log_path
68
- self.std_format = std_format
69
- self.file_format = file_format
70
-
71
-
72
- class Logger:
73
- """
74
- A simple logger.
75
- :param log_name: log name
76
- :param config: config logger
77
- """
78
- def __init__(self, log_name, config: ConfigLogger = ConfigLogger()):
79
- log_path = config.log_path
80
- err_log_path = os.path.join(str(log_path), f"{log_name}_error.log")
81
- info_log_path = os.path.join(str(log_path), f"{log_name}_info.log")
82
- if not os.path.exists(log_path):
83
- os.makedirs(log_path, exist_ok=True)
84
-
85
- self._file_logger = logging.getLogger(log_name)
86
- self._file_logger.setLevel("INFO")
87
-
88
- self._std_logger = logging.getLogger()
89
- self._std_logger.setLevel("INFO")
90
-
91
- # 创建一个ERROR级别的handler,将日志记录到error.log文件中
92
- error_handler = logging.FileHandler(err_log_path, encoding='utf-8')
93
- error_handler.setLevel(logging.ERROR)
94
-
95
- # 创建一个INFO级别的handler,将日志记录到info.log文件中
96
- info_handler = logging.FileHandler(info_log_path, encoding='utf-8')
97
- info_handler.setLevel(logging.INFO)
98
-
99
- # 创建一个 stream handler
100
- stream_handler = logging.StreamHandler(sys.stdout)
101
-
102
- error_handler.setFormatter(config.file_format)
103
- info_handler.setFormatter(config.file_format)
104
- stream_handler.setFormatter(config.std_format)
105
-
106
- # 将handler添加到logger中
107
- self._file_logger.addHandler(error_handler)
108
- self._file_logger.addHandler(info_handler)
109
- self._std_logger.addHandler(stream_handler)
110
-
111
- def file_error(self, message):
112
- """file error log"""
113
- self._file_logger.error(message)
114
-
115
- def file_info(self, message):
116
- """file info log"""
117
- self._file_logger.info(message)
118
-
119
- def std_info(self, message):
120
- """std info log"""
121
- self._std_logger.info(message)
122
-
123
- def std_warn(self, message):
124
- """std warn log"""
125
- self._std_logger.warning(message)
126
-
127
- def std_error(self, message):
128
- """std error log"""
129
- self._std_logger.error(message)
130
-
131
-
132
- logger = Logger(log_name="fooocus_api")
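Usage went through the module-level instance created above; a brief sketch:

from fooocusapi.utils.logger import logger

logger.std_info('[Fooocus API] Server started')   # colorized stdout
logger.file_error('worker crashed')               # written to logs/fooocus_api_error.log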
 
fooocusapi/utils/lora_manager.py DELETED
@@ -1,71 +0,0 @@
1
- import hashlib
2
- import os
3
- import requests
4
- import tarfile
5
-
6
- def _hash_url(url):
7
- """Generates a hash value for a given URL."""
8
- return hashlib.md5(url.encode('utf-8')).hexdigest()
9
-
10
- class LoraManager:
11
- """
12
- Manage LoRA files downloaded from URLs
13
- """
14
- def __init__(self):
15
- self.cache_dir = os.path.join(
16
- os.path.dirname(os.path.realpath(__file__)),
17
- '../../',
18
- 'repositories/Fooocus/models/loras')
19
-
20
- def _download_lora(self, url):
21
- """
22
- Downloads a LoRA from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file.
23
- """
24
- url_hash = _hash_url(url)
25
- file_ext = url.split('.')[-1]
26
- filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")
27
-
28
- if not os.path.exists(filepath):
29
- print(f"Start download for: {url}")
30
-
31
- try:
32
- response = requests.get(url, timeout=10, stream=True)
33
- response.raise_for_status()
34
- with open(filepath, 'wb') as f:
35
- for chunk in response.iter_content(chunk_size=8192):
36
- f.write(chunk)
37
-
38
- if file_ext == "tar":
39
- print("Extracting the tar file...")
40
- with tarfile.open(filepath, 'r:*') as tar:
41
- tar.extractall(path=self.cache_dir)
42
- print("Extraction completed.")
43
- return self._find_safetensors_file(self.cache_dir)
44
-
45
- print(f"Download successfully, saved as {filepath}")
46
- except Exception as e:
47
- raise Exception(f"Error downloading {url}: {e}") from e
48
-
49
- else:
50
- print(f"LoRa already downloaded {url}")
51
-
52
- return filepath
53
-
54
- def _find_safetensors_file(self, directory):
55
- """
56
- Finds the first .safetensors file in the specified directory.
57
- """
58
- print("Searching for .safetensors file.")
59
- for root, dirs, files in os.walk(directory):
60
- for file in files:
61
- if file.endswith('.safetensors'):
62
- return os.path.join(root, file)
63
- raise FileNotFoundError("No .safetensors file found in the extracted files.")
64
-
65
- def check(self, urls):
66
- """Manages the specified LoRAs: downloads missing ones and returns their file names."""
67
- paths = []
68
- for url in urls:
69
- path = self._download_lora(url)
70
- paths.append(path)
71
- return paths
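A sketch of the intended call pattern; the URL below is a hypothetical placeholder, not a real endpoint:

from fooocusapi.utils.lora_manager import LoraManager

manager = LoraManager()
paths = manager.check(['https://example.com/loras/my-style.safetensors'])
# -> ['.../repositories/Fooocus/models/loras/<md5-of-url>.safetensors']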
 
fooocusapi/utils/model_loader.py DELETED
@@ -1,46 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- """
4
- Download models from URLs
5
-
6
- @file: model_loader.py
7
- @author: Konie
8
- @update: 2024-03-22
9
- """
10
- from modules.model_loader import load_file_from_url
11
-
12
-
13
- def download_models():
14
- """
15
- Download models from config
16
- """
17
- vae_approx_filenames = [
18
- ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
19
- ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
20
- ('xl-to-v1_interposer-v3.1.safetensors', 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
21
- ]
22
-
23
- from modules.config import (
24
- paths_checkpoints as modelfile_path,
25
- paths_loras as lorafile_path,
26
- path_vae_approx as vae_approx_path,
27
- path_fooocus_expansion as fooocus_expansion_path,
28
- path_embeddings as embeddings_path,
29
- checkpoint_downloads,
30
- embeddings_downloads,
31
- lora_downloads)
32
-
33
- for file_name, url in checkpoint_downloads.items():
34
- load_file_from_url(url=url, model_dir=modelfile_path[0], file_name=file_name)
35
- for file_name, url in embeddings_downloads.items():
36
- load_file_from_url(url=url, model_dir=embeddings_path, file_name=file_name)
37
- for file_name, url in lora_downloads.items():
38
- load_file_from_url(url=url, model_dir=lorafile_path[0], file_name=file_name)
39
- for file_name, url in vae_approx_filenames:
40
- load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
41
-
42
- load_file_from_url(
43
- url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
44
- model_dir=fooocus_expansion_path,
45
- file_name='pytorch_model.bin'
46
- )
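Presumably invoked once at startup, after the Fooocus repository is on sys.path so that modules.config resolves; a sketch:

from fooocusapi.utils.model_loader import download_models

download_models()  # fetch checkpoints, LoRAs, embeddings and VAE approximations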
 
fooocusapi/utils/tools.py DELETED
@@ -1,159 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- """ Some tools
4
-
5
- @file: tools.py
6
- @author: Konie
7
- @update: 2024-03-22
8
- """
9
- # pylint: disable=line-too-long
10
- # pylint: disable=broad-exception-caught
11
- import os
12
- import sys
13
- import re
14
- import subprocess
15
- from importlib.util import find_spec
16
- from importlib import metadata
17
- from packaging import version
18
-
19
-
20
- PYTHON_EXEC = sys.executable
21
- INDEX_URL = os.environ.get('INDEX_URL', "")
22
- PATTERN = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
23
-
24
-
25
- # This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
26
- def run_command(command: str,
27
- desc: str = None,
28
- error_desc: str = None,
29
- custom_env: dict = None,
30
- live: bool = True) -> str:
31
- """
32
- Run a command and return the output
33
- Args:
34
- command: Command to run
35
- desc: Description of the command
36
- error_desc: Description of the error
37
- custom_env: Custom environment variables
38
- live: Whether to print the output
39
- Returns:
40
- The output of the command
41
- """
42
- if desc is not None:
43
- print(desc)
44
-
45
- run_kwargs = {
46
- "args": command,
47
- "shell": True,
48
- "env": os.environ if custom_env is None else custom_env,
49
- "encoding": 'utf8',
50
- "errors": 'ignore'
51
- }
52
-
53
- if not live:
54
- run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
55
-
56
- result = subprocess.run(check=False, **run_kwargs)
57
-
58
- if result.returncode != 0:
59
- error_bits = [
60
- f"{error_desc or 'Error running command'}.",
61
- f"Command: {command}",
62
- f"Error code: {result.returncode}",
63
- ]
64
- if result.stdout:
65
- error_bits.append(f"stdout: {result.stdout}")
66
- if result.stderr:
67
- error_bits.append(f"stderr: {result.stderr}")
68
- raise RuntimeError("\n".join(error_bits))
69
-
70
- return result.stdout or ""
71
-
72
-
73
- # This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
74
- def run_pip(command, desc=None, live=True):
75
- """
76
- Run a pip command
77
- Args:
78
- command: Command to run
79
- desc: Description of the command
80
- live: Whether to print the output
81
- Returns:
82
- The output of the command
83
- """
84
- try:
85
- index_url_line = f' --index-url {INDEX_URL}' if INDEX_URL != '' else ''
86
- return run_command(
87
- command=f'"{PYTHON_EXEC}" -m pip {command} --prefer-binary{index_url_line}',
88
- desc=f"Installing {desc}",
89
- error_desc=f"Couldn't install {desc}",
90
- live=live
91
- )
92
- except Exception as e:
93
- print(f'CMD Failed {command}: {e}')
94
- return None
95
-
96
-
97
- def is_installed(package: str) -> bool:
98
- """
99
- Check if a package is installed
100
- Args:
101
- package: Package name
102
- Returns:
103
- Whether the package is installed
104
- """
105
- try:
106
- spec = find_spec(package)
107
- except ModuleNotFoundError:
108
- return False
109
-
110
- return spec is not None
111
-
112
-
113
- def check_torch_cuda() -> bool:
114
- """
115
- Check if torch and CUDA is available
116
- Returns:
117
- Whether CUDA is available
118
- """
119
- try:
120
- import torch
121
- return torch.cuda.is_available()
122
- except ImportError:
123
- return False
124
-
125
-
126
- def requirements_check(requirements_file: str = 'requirements.txt',
127
- pattern: re.Pattern = PATTERN) -> bool:
128
- """
129
- Check if the requirements file is satisfied
130
- Args:
131
- requirements_file: Path to the requirements file
132
- pattern: Pattern to match the requirements
133
- Returns:
134
- Whether the requirements file is satisfied
135
- """
136
- with open(requirements_file, "r", encoding="utf8") as file:
137
- for line in file:
138
- if line.strip() == "":
139
- continue
140
-
141
- m = re.match(pattern, line)
142
- if m is None:
143
- return False
144
-
145
- package = m.group(1).strip()
146
- version_required = (m.group(2) or "").strip()
147
-
148
- if version_required == "":
149
- continue
150
-
151
- try:
152
- version_installed = metadata.version(package)
153
- except Exception:
154
- return False
155
-
156
- if version.parse(version_required) != version.parse(version_installed):
157
- return False
158
-
159
- return True
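A plausible pre-launch check built from the helpers above (an illustrative pattern, not the project's actual launcher code):

from fooocusapi.utils.tools import requirements_check, check_torch_cuda, run_pip

if not requirements_check('requirements.txt'):
    run_pip('install -r requirements.txt', desc='pinned dependencies')
print('CUDA available:', check_torch_cuda())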
 
fooocusapi/worker.py DELETED
@@ -1,1044 +0,0 @@
1
- """
2
- Worker, modified from https://github.com/lllyasviel/Fooocus/blob/main/modules/async_worker.py
3
- """
4
- import copy
5
- import os
6
- import random
7
- import time
8
- from typing import List
9
- import logging
10
- import numpy as np
11
- import torch
12
-
13
- from fooocusapi.models.common.image_meta import image_parse
14
- from modules.patch import PatchSettings, patch_settings, patch_all
15
- from modules.flags import Performance
16
-
17
- from fooocusapi.utils.file_utils import save_output_file
18
- from fooocusapi.models.common.task import (
19
- GenerationFinishReason,
20
- ImageGenerationResult
21
- )
22
- from fooocusapi.utils.logger import logger
23
- from fooocusapi.task_queue import (
24
- QueueTask,
25
- TaskQueue,
26
- TaskOutputs
27
- )
28
-
29
- patch_all()
30
-
31
- worker_queue: TaskQueue | None = None
32
- last_model_name = None
33
-
34
-
35
- def process_stop():
36
- """Stop process"""
37
- import ldm_patched.modules.model_management
38
- ldm_patched.modules.model_management.interrupt_current_processing()
39
-
40
-
41
- @torch.no_grad()
42
- @torch.inference_mode()
43
- def task_schedule_loop():
44
- """Task schedule loop"""
45
- while True:
46
- if len(worker_queue.queue) == 0:
47
- time.sleep(0.05)
48
- continue
49
-
50
- current_task = worker_queue.queue[0]
51
- if current_task.start_mills == 0:
52
- process_generate(current_task)
53
-
54
-
55
- @torch.no_grad()
56
- @torch.inference_mode()
57
- def blocking_get_task_result(job_id: str) -> List[ImageGenerationResult]:
58
- """
59
- Get the task result; used when async_process is false
60
- :param job_id:
61
- :return:
62
- """
63
- waiting_sleep_steps: int = 0
64
- waiting_start_time = time.perf_counter()
65
- while not worker_queue.is_task_finished(job_id):
66
- if waiting_sleep_steps == 0:
67
- logger.std_info(f"[Task Queue] Waiting for task finished, job_id={job_id}")
68
- delay = 0.05
69
- time.sleep(delay)
70
- waiting_sleep_steps += 1
71
- if waiting_sleep_steps % int(10 / delay) == 0:
72
- waiting_time = time.perf_counter() - waiting_start_time
73
- logger.std_info(f"[Task Queue] Already waiting for {round(waiting_time, 1)} seconds, job_id={job_id}")
74
-
75
- task = worker_queue.get_task(job_id, True)
76
- return task.task_result
77
-
78
-
79
- @torch.no_grad()
80
- @torch.inference_mode()
81
- def process_generate(async_task: QueueTask):
82
- """Generate image"""
83
- try:
84
- import modules.default_pipeline as pipeline
85
- except Exception as e:
86
- logger.std_error(f'[Task Queue] Import default pipeline error: {e}')
87
- if not async_task.is_finished:
88
- worker_queue.finish_task(async_task.job_id)
89
- async_task.set_result([], True, str(e))
90
- logger.std_error(f"[Task Queue] Finish task with error, seq={async_task.job_id}")
91
- return []
92
-
93
- import modules.flags as flags
94
- import modules.core as core
95
- import modules.inpaint_worker as inpaint_worker
96
- import modules.config as config
97
- import modules.constants as constants
98
- import extras.preprocessors as preprocessors
99
- import extras.ip_adapter as ip_adapter
100
- import extras.face_crop as face_crop
101
- import ldm_patched.modules.model_management as model_management
102
- from modules.util import (
103
- remove_empty_str, HWC3, resize_image,
104
- get_image_shape_ceil, set_image_shape_ceil,
105
- get_shape_ceil, resample_image, erode_or_dilate,
106
- get_enabled_loras, parse_lora_references_from_prompt, apply_wildcards,
107
- remove_performance_lora
108
- )
109
-
110
- from modules.upscaler import perform_upscale
111
- from extras.expansion import safe_str
112
- from extras.censor import default_censor
113
- from modules.sdxl_styles import (
114
- apply_style, get_random_style,
115
- fooocus_expansion, apply_arrays, random_style_name
116
- )
117
-
118
- pid = os.getpid()
119
-
120
- outputs = TaskOutputs(async_task)
121
- results = []
122
-
123
- def refresh_seed(seed_string: int | str | None) -> int:
124
- """
125
- Refresh and check seed number.
126
- :param seed_string: seed, str or int. None means random
127
- :return: seed number
128
- """
129
- if seed_string is None or seed_string == -1:
130
- return random.randint(constants.MIN_SEED, constants.MAX_SEED)
131
-
132
- try:
133
- seed_value = int(seed_string)
134
- if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
135
- return seed_value
136
- except ValueError:
137
- pass
138
- return random.randint(constants.MIN_SEED, constants.MAX_SEED)
139
-
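# Behaviour of refresh_seed above, for reference: refresh_seed('42') == 42,
# while refresh_seed(-1), refresh_seed(None) and refresh_seed('abc') each
# draw a fresh random seed from [MIN_SEED, MAX_SEED].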
140
- def progressbar(_, number, text):
141
- """progress bar"""
142
- logger.std_info(f'[Fooocus] {text}')
143
- outputs.append(['preview', (number, text, None)])
144
-
145
- def yield_result(_, images, tasks, extension='png',
146
- blockout_nsfw=False, censor=True):
147
- """
148
- Yield result
149
- :param _: async task object
150
- :param images: list for generated image
151
- :param tasks: the image was generated one by one, when image number is not one, it will be a task list
152
- :param extension: extension for saved image
153
- :param blockout_nsfw: blockout nsfw image
154
- :param censor: censor image
155
- :return:
156
- """
157
- if not isinstance(images, list):
158
- images = [images]
159
-
160
- if censor and (config.default_black_out_nsfw or black_out_nsfw):
161
- images = default_censor(images)
162
-
163
- results = []
164
- for index, im in enumerate(images):
165
- if async_task.req_param.save_name == '':
166
- image_name = f"{async_task.job_id}-{str(index)}"
167
- else:
168
- image_name = f"{async_task.req_param.save_name}-{str(index)}"
169
- if len(tasks) == 0:
170
- img_seed = -1
171
- img_meta = {}
172
- else:
173
- img_seed = tasks[index]['task_seed']
174
- img_meta = image_parse(
175
- async_tak=async_task,
176
- task=tasks[index])
177
- img_filename = save_output_file(
178
- img=im,
179
- image_name=image_name,
180
- image_meta=img_meta,
181
- extension=extension)
182
- results.append(ImageGenerationResult(
183
- im=img_filename,
184
- seed=str(img_seed),
185
- finish_reason=GenerationFinishReason.success))
186
- async_task.set_result(results, False)
187
- worker_queue.finish_task(async_task.job_id)
188
- logger.std_info(f"[Task Queue] Finish task, job_id={async_task.job_id}")
189
-
190
- outputs.append(['results', images])
191
- pipeline.prepare_text_encoder(async_call=True)
192
-
193
- try:
194
- logger.std_info(f"[Task Queue] Task queue start task, job_id={async_task.job_id}")
195
- # clear memory
196
- global last_model_name
197
-
198
- if last_model_name is None:
199
- last_model_name = async_task.req_param.base_model_name
200
- if last_model_name != async_task.req_param.base_model_name:
201
- model_management.cleanup_models() # key1
202
- model_management.unload_all_models()
203
- model_management.soft_empty_cache() # key2
204
- last_model_name = async_task.req_param.base_model_name
205
-
206
- worker_queue.start_task(async_task.job_id)
207
-
208
- execution_start_time = time.perf_counter()
209
-
210
- # Transform parameters
211
- params = async_task.req_param
212
- prompt = params.prompt
213
- negative_prompt = params.negative_prompt
214
- style_selections = params.style_selections
215
- performance_selection = Performance(params.performance_selection)
216
- aspect_ratios_selection = params.aspect_ratios_selection
217
- image_number = params.image_number
218
- save_metadata_to_images = params.save_meta
219
- metadata_scheme = params.meta_scheme
220
- save_extension = params.save_extension
221
- save_name = params.save_name
222
- image_seed = refresh_seed(params.image_seed)
223
- read_wildcards_in_order = False
224
- sharpness = params.sharpness
225
- guidance_scale = params.guidance_scale
226
- base_model_name = params.base_model_name
227
- refiner_model_name = params.refiner_model_name
228
- refiner_switch = params.refiner_switch
229
- loras = params.loras
230
- input_image_checkbox = params.uov_input_image is not None or params.inpaint_input_image is not None or len(params.image_prompts) > 0
231
- current_tab = 'uov' if params.uov_method != flags.disabled else 'ip' if len(params.image_prompts) > 0 else 'inpaint' if params.inpaint_input_image is not None else None
232
- uov_method = params.uov_method
233
- upscale_value = params.upscale_value
234
- uov_input_image = params.uov_input_image
235
- outpaint_selections = params.outpaint_selections
236
- outpaint_distance_left = params.outpaint_distance_left
237
- outpaint_distance_top = params.outpaint_distance_top
238
- outpaint_distance_right = params.outpaint_distance_right
239
- outpaint_distance_bottom = params.outpaint_distance_bottom
240
- inpaint_input_image = params.inpaint_input_image
241
- inpaint_additional_prompt = '' if params.inpaint_additional_prompt is None else params.inpaint_additional_prompt
242
- inpaint_mask_image_upload = None
243
-
244
- adp = params.advanced_params
245
- disable_preview = adp.disable_preview
246
- disable_intermediate_results = adp.disable_intermediate_results
247
- disable_seed_increment = adp.disable_seed_increment
248
- adm_scaler_positive = adp.adm_scaler_positive
249
- adm_scaler_negative = adp.adm_scaler_negative
250
- adm_scaler_end = adp.adm_scaler_end
251
- adaptive_cfg = adp.adaptive_cfg
252
- sampler_name = adp.sampler_name
253
- scheduler_name = adp.scheduler_name
254
- overwrite_step = adp.overwrite_step
255
- overwrite_switch = adp.overwrite_switch
256
- overwrite_width = adp.overwrite_width
257
- overwrite_height = adp.overwrite_height
258
- overwrite_vary_strength = adp.overwrite_vary_strength
259
- overwrite_upscale_strength = adp.overwrite_upscale_strength
260
- mixing_image_prompt_and_vary_upscale = adp.mixing_image_prompt_and_vary_upscale
261
- mixing_image_prompt_and_inpaint = adp.mixing_image_prompt_and_inpaint
262
- debugging_cn_preprocessor = adp.debugging_cn_preprocessor
263
- skipping_cn_preprocessor = adp.skipping_cn_preprocessor
264
- canny_low_threshold = adp.canny_low_threshold
265
- canny_high_threshold = adp.canny_high_threshold
266
- refiner_swap_method = adp.refiner_swap_method
267
- controlnet_softness = adp.controlnet_softness
268
- freeu_enabled = adp.freeu_enabled
269
- freeu_b1 = adp.freeu_b1
270
- freeu_b2 = adp.freeu_b2
271
- freeu_s1 = adp.freeu_s1
272
- freeu_s2 = adp.freeu_s2
273
- debugging_inpaint_preprocessor = adp.debugging_inpaint_preprocessor
274
- inpaint_disable_initial_latent = adp.inpaint_disable_initial_latent
275
- inpaint_engine = adp.inpaint_engine
276
- inpaint_strength = adp.inpaint_strength
277
- inpaint_respective_field = adp.inpaint_respective_field
278
- inpaint_mask_upload_checkbox = adp.inpaint_mask_upload_checkbox
279
- invert_mask_checkbox = adp.invert_mask_checkbox
280
- inpaint_erode_or_dilate = adp.inpaint_erode_or_dilate
281
- black_out_nsfw = adp.black_out_nsfw
282
- vae_name = adp.vae_name
283
- clip_skip = adp.clip_skip
284
-
285
- cn_tasks = {x: [] for x in flags.ip_list}
286
- for img_prompt in params.image_prompts:
287
- cn_img, cn_stop, cn_weight, cn_type = img_prompt
288
- cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
289
-
290
- if inpaint_input_image is not None and inpaint_input_image['image'] is not None:
291
- inpaint_image_size = inpaint_input_image['image'].shape[:2]
292
- if inpaint_input_image['mask'] is None:
293
- inpaint_input_image['mask'] = np.zeros(inpaint_image_size, dtype=np.uint8)
294
- else:
295
- inpaint_mask_upload_checkbox = True
296
-
297
- inpaint_input_image['mask'] = HWC3(inpaint_input_image['mask'])
298
- inpaint_mask_image_upload = inpaint_input_image['mask']
299
-
300
- # Fooocus async_worker.py code start
301
-
302
- outpaint_selections = [o.lower() for o in outpaint_selections]
303
- base_model_additional_loras = []
304
- raw_style_selections = copy.deepcopy(style_selections)
305
- uov_method = uov_method.lower()
306
-
307
- if fooocus_expansion in style_selections:
308
- use_expansion = True
309
- style_selections.remove(fooocus_expansion)
310
- else:
311
- use_expansion = False
312
-
313
- use_style = len(style_selections) > 0
314
-
315
- if base_model_name == refiner_model_name:
316
- logger.std_warn('[Fooocus] Refiner disabled because base model and refiner are the same.')
317
- refiner_model_name = 'None'
318
-
319
- steps = performance_selection.steps()
320
-
321
- performance_loras = []
322
-
323
- if performance_selection == Performance.EXTREME_SPEED:
324
- logger.std_warn('[Fooocus] Enter LCM mode.')
325
- progressbar(async_task, 1, 'Downloading LCM components ...')
326
- performance_loras += [(config.downloading_sdxl_lcm_lora(), 1.0)]
327
-
328
- if refiner_model_name != 'None':
329
- logger.std_info('[Fooocus] Refiner disabled in LCM mode.')
330
-
331
- refiner_model_name = 'None'
332
- sampler_name = 'lcm'
333
- scheduler_name = 'lcm'
334
- sharpness = 0.0
335
- guidance_scale = 1.0
336
- adaptive_cfg = 1.0
337
- refiner_switch = 1.0
338
- adm_scaler_positive = 1.0
339
- adm_scaler_negative = 1.0
340
- adm_scaler_end = 0.0
341
-
342
- elif performance_selection == Performance.LIGHTNING:
343
- logger.std_info('[Fooocus] Enter Lightning mode.')
344
- progressbar(async_task, 1, 'Downloading Lightning components ...')
345
- performance_loras += [(config.downloading_sdxl_lightning_lora(), 1.0)]
346
-
347
- if refiner_model_name != 'None':
348
- logger.std_info('[Fooocus] Refiner disabled in Lightning mode.')
349
-
350
- refiner_model_name = 'None'
351
- sampler_name = 'euler'
352
- scheduler_name = 'sgm_uniform'
353
- sharpness = 0.0
354
- guidance_scale = 1.0
355
- adaptive_cfg = 1.0
356
- refiner_switch = 1.0
357
- adm_scaler_positive = 1.0
358
- adm_scaler_negative = 1.0
359
- adm_scaler_end = 0.0
360
-
361
- elif performance_selection == Performance.HYPER_SD:
362
- print('Enter Hyper-SD mode.')
363
- progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
364
- performance_loras += [(config.downloading_sdxl_hyper_sd_lora(), 0.8)]
365
-
366
- if refiner_model_name != 'None':
367
- logger.std_info('[Fooocus] Refiner disabled in Hyper-SD mode.')
368
-
369
- refiner_model_name = 'None'
370
- sampler_name = 'dpmpp_sde_gpu'
371
- scheduler_name = 'karras'
372
- sharpness = 0.0
373
- guidance_scale = 1.0
374
- adaptive_cfg = 1.0
375
- refiner_switch = 1.0
376
- adm_scaler_positive = 1.0
377
- adm_scaler_negative = 1.0
378
- adm_scaler_end = 0.0
379
-
380
- logger.std_info(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
381
- logger.std_info(f'[Parameters] CLIP Skip = {clip_skip}')
382
- logger.std_info(f'[Parameters] Sharpness = {sharpness}')
383
- logger.std_info(f'[Parameters] ControlNet Softness = {controlnet_softness}')
384
- logger.std_info(f'[Parameters] ADM Scale = '
385
- f'{adm_scaler_positive} : '
386
- f'{adm_scaler_negative} : '
387
- f'{adm_scaler_end}')
388
-
389
- patch_settings[pid] = PatchSettings(
390
- sharpness,
391
- adm_scaler_end,
392
- adm_scaler_positive,
393
- adm_scaler_negative,
394
- controlnet_softness,
395
- adaptive_cfg
396
- )
397
-
398
- cfg_scale = float(guidance_scale)
399
- logger.std_info(f'[Parameters] CFG = {cfg_scale}')
400
-
401
- initial_latent = None
402
- denoising_strength = 1.0
403
- tiled = False
404
-
405
- width, height = aspect_ratios_selection.replace('×', ' ').replace('*', ' ').split(' ')[:2]
406
- width, height = int(width), int(height)
407
-
408
- skip_prompt_processing = False
409
-
410
- inpaint_worker.current_task = None
411
- inpaint_parameterized = inpaint_engine != 'None'
412
- inpaint_image = None
413
- inpaint_mask = None
414
- inpaint_head_model_path = None
415
-
416
- use_synthetic_refiner = False
417
-
418
- controlnet_canny_path = None
419
- controlnet_cpds_path = None
420
- clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
421
-
422
- seed = int(image_seed)
423
- logger.std_info(f'[Parameters] Seed = {seed}')
424
-
425
- goals = []
426
- tasks = []
427
-
428
- if input_image_checkbox:
429
- if (current_tab == 'uov' or (
430
- current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
431
- and uov_method != flags.disabled and uov_input_image is not None:
432
- uov_input_image = HWC3(uov_input_image)
433
- if 'vary' in uov_method:
434
- goals.append('vary')
435
- elif 'upscale' in uov_method:
436
- goals.append('upscale')
437
- if 'fast' in uov_method:
438
- skip_prompt_processing = True
439
- else:
440
- steps = performance_selection.steps_uov()
441
-
442
- progressbar(async_task, 1, 'Downloading upscale models ...')
443
- config.downloading_upscale_model()
444
- if (current_tab == 'inpaint' or (
445
- current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
446
- and isinstance(inpaint_input_image, dict):
447
- inpaint_image = inpaint_input_image['image']
448
- inpaint_mask = inpaint_input_image['mask'][:, :, 0]
449
-
450
- if inpaint_mask_upload_checkbox:
451
- if isinstance(inpaint_mask_image_upload, np.ndarray):
452
- if inpaint_mask_image_upload.ndim == 3:
453
- H, W, C = inpaint_image.shape
454
- inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
455
- inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
456
- inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
457
- inpaint_mask = np.maximum(np.zeros(shape=(H, W), dtype=np.uint8), inpaint_mask_image_upload)
458
-
459
- if int(inpaint_erode_or_dilate) != 0:
460
- inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
461
-
462
- if invert_mask_checkbox:
463
- inpaint_mask = 255 - inpaint_mask
464
-
465
- inpaint_image = HWC3(inpaint_image)
466
- if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
467
- and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
468
- progressbar(async_task, 1, 'Downloading upscale models ...')
469
- config.downloading_upscale_model()
470
- if inpaint_parameterized:
471
- progressbar(async_task, 1, 'Downloading inpainter ...')
472
- inpaint_head_model_path, inpaint_patch_model_path = config.downloading_inpaint_models(
473
- inpaint_engine)
474
- base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
475
- logger.std_info(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
476
- if refiner_model_name == 'None':
477
- use_synthetic_refiner = True
478
- refiner_switch = 0.8
479
- else:
480
- inpaint_head_model_path, inpaint_patch_model_path = None, None
481
- logger.std_info('[Inpaint] Parameterized inpaint is disabled.')
482
- if inpaint_additional_prompt != '':
483
- if prompt == '':
484
- prompt = inpaint_additional_prompt
485
- else:
486
- prompt = inpaint_additional_prompt + '\n' + prompt
487
- goals.append('inpaint')
488
- if current_tab == 'ip' or \
489
- mixing_image_prompt_and_vary_upscale or \
490
- mixing_image_prompt_and_inpaint:
491
- goals.append('cn')
492
- progressbar(async_task, 1, 'Downloading control models ...')
493
- if len(cn_tasks[flags.cn_canny]) > 0:
494
- controlnet_canny_path = config.downloading_controlnet_canny()
495
- if len(cn_tasks[flags.cn_cpds]) > 0:
496
- controlnet_cpds_path = config.downloading_controlnet_cpds()
497
- if len(cn_tasks[flags.cn_ip]) > 0:
498
- clip_vision_path, ip_negative_path, ip_adapter_path = config.downloading_ip_adapters('ip')
499
- if len(cn_tasks[flags.cn_ip_face]) > 0:
500
- clip_vision_path, ip_negative_path, ip_adapter_face_path = config.downloading_ip_adapters(
501
- 'face')
502
- progressbar(async_task, 1, 'Loading control models ...')
503
-
504
- # Load or unload CNs
505
- pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
506
- ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
507
- ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
508
-
509
- if overwrite_step > 0:
510
- steps = overwrite_step
511
-
512
- switch = int(round(steps * refiner_switch))
513
-
514
- if overwrite_switch > 0:
515
- switch = overwrite_switch
516
-
517
- if overwrite_width > 0:
518
- width = overwrite_width
519
-
520
- if overwrite_height > 0:
521
- height = overwrite_height
522
-
523
- logger.std_info(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
524
- logger.std_info(f'[Parameters] Steps = {steps} - {switch}')
525
-
526
- progressbar(async_task, 1, 'Initializing ...')
527
-
528
- if not skip_prompt_processing:
529
-
530
- prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
531
- negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
532
-
533
- prompt = prompts[0]
534
- negative_prompt = negative_prompts[0]
535
-
536
- if prompt == '':
537
- # disable expansion when empty since it is not meaningful and influences image prompt
538
- use_expansion = False
539
-
540
- extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
541
- extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
542
-
543
- progressbar(async_task, 3, 'Loading models ...')
544
- lora_filenames = remove_performance_lora(config.lora_filenames, performance_selection)
545
- loras, prompt = parse_lora_references_from_prompt(prompt, loras, config.default_max_lora_number, lora_filenames=lora_filenames)
546
- loras += performance_loras
547
-
548
- pipeline.refresh_everything(
549
- refiner_model_name=refiner_model_name,
550
- base_model_name=base_model_name,
551
- loras=loras,
552
- base_model_additional_loras=base_model_additional_loras,
553
- use_synthetic_refiner=use_synthetic_refiner)
554
-
555
- pipeline.set_clip_skip(clip_skip)
556
-
557
- progressbar(async_task, 3, 'Processing prompts ...')
558
- tasks = []
559
-
560
- for i in range(image_number):
561
- if disable_seed_increment:
562
- task_seed = seed % (constants.MAX_SEED + 1)
563
- else:
564
- task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
565
-
566
- task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
567
- task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
568
- task_prompt = apply_arrays(task_prompt, i)
569
- task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
570
- task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
571
- extra_positive_prompts]
572
- task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
573
- extra_negative_prompts]
574
-
575
- positive_basic_workloads = []
576
- negative_basic_workloads = []
577
-
578
- task_styles = style_selections.copy()
579
- if use_style:
580
- for index, style in enumerate(task_styles):
581
- if style == random_style_name:
582
- style = get_random_style(task_rng)
583
- task_styles[index] = style
584
- p, n = apply_style(style, positive=task_prompt)
585
- positive_basic_workloads = positive_basic_workloads + p
586
- negative_basic_workloads = negative_basic_workloads + n
587
- else:
588
- positive_basic_workloads.append(task_prompt)
589
-
590
- negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative.
591
-
592
- positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
593
- negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
594
-
595
- positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
596
- negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
597
-
598
- tasks.append(dict(
599
- task_seed=task_seed,
600
- task_prompt=task_prompt,
601
- task_negative_prompt=task_negative_prompt,
602
- positive=positive_basic_workloads,
603
- negative=negative_basic_workloads,
604
- expansion='',
605
- c=None,
606
- uc=None,
607
- positive_top_k=len(positive_basic_workloads),
608
- negative_top_k=len(negative_basic_workloads),
609
- log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
610
- log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
611
- styles=task_styles
612
- ))
613
-
614
- if use_expansion:
615
- for i, t in enumerate(tasks):
616
- progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
617
- expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
618
- logger.std_info(f'[Prompt Expansion] {expansion}')
619
- t['expansion'] = expansion
620
- t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
621
-
622
- for i, t in enumerate(tasks):
623
- progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
624
- t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
625
-
626
- for i, t in enumerate(tasks):
627
- if abs(float(cfg_scale) - 1.0) < 1e-4:
628
- t['uc'] = pipeline.clone_cond(t['c'])
629
- else:
630
- progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
631
- t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
632
-
633
- if len(goals) > 0:
634
- progressbar(async_task, 7, 'Image processing ...')
635
-
636
- if 'vary' in goals:
637
- if 'subtle' in uov_method:
638
- denoising_strength = 0.5
639
- if 'strong' in uov_method:
640
- denoising_strength = 0.85
641
- if overwrite_vary_strength > 0:
642
- denoising_strength = overwrite_vary_strength
643
-
644
- shape_ceil = get_image_shape_ceil(uov_input_image)
645
- if shape_ceil < 1024:
646
- logger.std_warn('[Vary] Image is resized because it is too small.')
647
- shape_ceil = 1024
648
- elif shape_ceil > 2048:
649
- logger.std_warn('[Vary] Image is resized because it is too big.')
650
- shape_ceil = 2048
651
-
652
- uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
653
-
654
- initial_pixels = core.numpy_to_pytorch(uov_input_image)
655
- progressbar(async_task, 8, 'VAE encoding ...')
656
-
657
- candidate_vae, _ = pipeline.get_candidate_vae(
658
- steps=steps,
659
- switch=switch,
660
- denoise=denoising_strength,
661
- refiner_swap_method=refiner_swap_method
662
- )
663
-
664
- initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
665
- B, C, H, W = initial_latent['samples'].shape
666
- width = W * 8
667
- height = H * 8
668
- logger.std_info(f'[Vary] Final resolution is {str((height, width))}.')
669
-
670
- if 'upscale' in goals:
671
- H, W, C = uov_input_image.shape
672
- progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
673
- uov_input_image = perform_upscale(uov_input_image)
674
- logger.std_info('[Upscale] Image upscale.')
675
-
676
- if upscale_value is not None and upscale_value > 1.0:
677
- f = upscale_value
678
- else:
679
- if '1.5x' in uov_method:
680
- f = 1.5
681
- elif '2x' in uov_method:
682
- f = 2.0
683
- else:
684
- f = 1.0
685
-
686
- shape_ceil = get_shape_ceil(H * f, W * f)
687
-
688
- if shape_ceil < 1024:
689
- logger.std_info('[Upscale] Image is resized because it is too small.')
690
- uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
691
- shape_ceil = 1024
692
- else:
693
- uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
694
-
695
- image_is_super_large = shape_ceil > 2800
696
-
697
- if 'fast' in uov_method:
698
- direct_return = True
699
- elif image_is_super_large:
700
- logger.std_info('[Upscale] Image is too large. Directly returned the SR image. '
701
- 'Usually directly return SR image at 4K resolution '
702
- 'yields better results than SDXL diffusion.')
703
- direct_return = True
704
- else:
705
- direct_return = False
706
-
707
- if direct_return:
708
- # d = [('Upscale (Fast)', '2x')]
709
- # log(uov_input_image, d, output_format=save_extension)
710
- if config.default_black_out_nsfw or black_out_nsfw:
711
- uov_input_image = default_censor(uov_input_image)
712
- yield_result(async_task, uov_input_image, tasks, save_extension, False, False)
713
- return
714
-
715
- tiled = True
716
- denoising_strength = 0.382
717
-
718
- if overwrite_upscale_strength > 0:
719
- denoising_strength = overwrite_upscale_strength
720
-
721
- initial_pixels = core.numpy_to_pytorch(uov_input_image)
722
- progressbar(async_task, 10, 'VAE encoding ...')
723
-
724
- candidate_vae, _ = pipeline.get_candidate_vae(
725
- steps=steps,
726
- switch=switch,
727
- denoise=denoising_strength,
728
- refiner_swap_method=refiner_swap_method
729
- )
730
-
731
- initial_latent = core.encode_vae(
732
- vae=candidate_vae,
733
- pixels=initial_pixels, tiled=True)
734
- B, C, H, W = initial_latent['samples'].shape
735
- width = W * 8
736
- height = H * 8
737
- logger.std_info(f'[Upscale] Final resolution is {str((height, width))}.')
738
-
739
- if 'inpaint' in goals:
740
- if len(outpaint_selections) > 0:
741
- H, W, C = inpaint_image.shape
742
- if 'top' in outpaint_selections:
743
- distance_top = int(H * 0.3)
744
- if outpaint_distance_top > 0:
745
- distance_top = outpaint_distance_top
746
-
747
- inpaint_image = np.pad(inpaint_image, [[distance_top, 0], [0, 0], [0, 0]], mode='edge')
748
- inpaint_mask = np.pad(inpaint_mask, [[distance_top, 0], [0, 0]], mode='constant',
749
- constant_values=255)
750
-
751
- if 'bottom' in outpaint_selections:
752
- distance_bottom = int(H * 0.3)
753
- if outpaint_distance_bottom > 0:
754
- distance_bottom = outpaint_distance_bottom
755
-
756
- inpaint_image = np.pad(inpaint_image, [[0, distance_bottom], [0, 0], [0, 0]], mode='edge')
757
- inpaint_mask = np.pad(inpaint_mask, [[0, distance_bottom], [0, 0]], mode='constant',
758
- constant_values=255)
759
-
760
- H, W, C = inpaint_image.shape
761
- if 'left' in outpaint_selections:
762
- distance_left = int(W * 0.3)
763
- if outpaint_distance_left > 0:
764
- distance_left = outpaint_distance_left
765
-
766
- inpaint_image = np.pad(inpaint_image, [[0, 0], [distance_left, 0], [0, 0]], mode='edge')
767
- inpaint_mask = np.pad(inpaint_mask, [[0, 0], [distance_left, 0]], mode='constant',
768
- constant_values=255)
769
-
770
- if 'right' in outpaint_selections:
771
- distance_right = int(W * 0.3)
772
- if outpaint_distance_right > 0:
773
- distance_right = outpaint_distance_right
774
-
775
- inpaint_image = np.pad(inpaint_image, [[0, 0], [0, distance_right], [0, 0]], mode='edge')
776
- inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, distance_right]], mode='constant',
777
- constant_values=255)
778
-
779
- inpaint_image = np.ascontiguousarray(inpaint_image.copy())
780
- inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
781
- inpaint_strength = 1.0
782
- inpaint_respective_field = 1.0
783
-
784
- denoising_strength = inpaint_strength
785
-
786
- inpaint_worker.current_task = inpaint_worker.InpaintWorker(
787
- image=inpaint_image,
788
- mask=inpaint_mask,
789
- use_fill=denoising_strength > 0.99,
790
- k=inpaint_respective_field
791
- )
792
-
793
- if debugging_inpaint_preprocessor:
794
- yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), tasks,
795
- black_out_nsfw)
796
- return
797
-
798
- progressbar(async_task, 11, 'VAE Inpaint encoding ...')
799
-
800
- inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
801
- inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
802
- inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
803
-
804
- candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
805
- steps=steps,
806
- switch=switch,
807
- denoise=denoising_strength,
808
- refiner_swap_method=refiner_swap_method
809
- )
810
-
811
- latent_inpaint, latent_mask = core.encode_vae_inpaint(
812
- mask=inpaint_pixel_mask,
813
- vae=candidate_vae,
814
- pixels=inpaint_pixel_image)
815
-
816
- latent_swap = None
817
- if candidate_vae_swap is not None:
818
- progressbar(async_task, 12, 'VAE SD15 encoding ...')
819
- latent_swap = core.encode_vae(
820
- vae=candidate_vae_swap,
821
- pixels=inpaint_pixel_fill)['samples']
822
-
823
- progressbar(async_task, 13, 'VAE encoding ...')
824
- latent_fill = core.encode_vae(
825
- vae=candidate_vae,
826
- pixels=inpaint_pixel_fill)['samples']
827
-
828
- inpaint_worker.current_task.load_latent(
829
- latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
830
-
831
- if inpaint_parameterized:
832
- pipeline.final_unet = inpaint_worker.current_task.patch(
833
- inpaint_head_model_path=inpaint_head_model_path,
834
- inpaint_latent=latent_inpaint,
835
- inpaint_latent_mask=latent_mask,
836
- model=pipeline.final_unet
837
- )
838
-
839
- if not inpaint_disable_initial_latent:
840
- initial_latent = {'samples': latent_fill}
841
-
842
-        B, C, H, W = latent_fill.shape
-        height, width = H * 8, W * 8
-        final_height, final_width = inpaint_worker.current_task.image.shape[:2]
-        logger.std_info(f'[Inpaint] Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
-
-        if 'cn' in goals:
-            for task in cn_tasks[flags.cn_canny]:
-                cn_img, cn_stop, cn_weight = task
-                cn_img = resize_image(HWC3(cn_img), width=width, height=height)
-
-                if not skipping_cn_preprocessor:
-                    cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
-
-                cn_img = HWC3(cn_img)
-                task[0] = core.numpy_to_pytorch(cn_img)
-                if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
-                    return
-            for task in cn_tasks[flags.cn_cpds]:
-                cn_img, cn_stop, cn_weight = task
-                cn_img = resize_image(HWC3(cn_img), width=width, height=height)
-
-                if not skipping_cn_preprocessor:
-                    cn_img = preprocessors.cpds(cn_img)
-
-                cn_img = HWC3(cn_img)
-                task[0] = core.numpy_to_pytorch(cn_img)
-                if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
-                    return
-            for task in cn_tasks[flags.cn_ip]:
-                cn_img, cn_stop, cn_weight = task
-                cn_img = HWC3(cn_img)
-
-                # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
-                cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
-
-                task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
-                if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
-                    return
-            for task in cn_tasks[flags.cn_ip_face]:
-                cn_img, cn_stop, cn_weight = task
-                cn_img = HWC3(cn_img)
-
-                if not skipping_cn_preprocessor:
-                    cn_img = face_crop.crop_image(cn_img)
-
-                # https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
-                cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
-
-                task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
-                if debugging_cn_preprocessor:
-                    yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
-                    return
-
-            all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
-
-            if len(all_ip_tasks) > 0:
-                pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
-
-        if freeu_enabled:
-            logger.std_info('[Fooocus] FreeU is enabled!')
-            pipeline.final_unet = core.apply_freeu(
-                pipeline.final_unet,
-                freeu_b1,
-                freeu_b2,
-                freeu_s1,
-                freeu_s2
-            )
-
-        all_steps = steps * image_number
-
-        logger.std_info(f'[Parameters] Denoising Strength = {denoising_strength}')
-
-        if isinstance(initial_latent, dict) and 'samples' in initial_latent:
-            log_shape = initial_latent['samples'].shape
-        else:
-            log_shape = f'Image Space {(height, width)}'
-
-        logger.std_info(f'[Parameters] Initial Latent shape: {log_shape}')
-
-        preparation_time = time.perf_counter() - execution_start_time
-        logger.std_info(f'[Fooocus] Preparation time: {preparation_time:.2f} seconds')
-
-        final_sampler_name = sampler_name
-        final_scheduler_name = scheduler_name
-
-        if scheduler_name in ['lcm', 'tcd']:
-            final_scheduler_name = 'sgm_uniform'
-
-            def patch_discrete(unet):
-                return core.opModelSamplingDiscrete.patch(
-                    unet,
-                    sampling=scheduler_name,
-                    zsnr=False)[0]
-
-            if pipeline.final_unet is not None:
-                pipeline.final_unet = patch_discrete(pipeline.final_unet)
-            if pipeline.final_refiner_unet is not None:
-                pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
-            logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
-        elif scheduler_name == 'edm_playground_v2.5':
-            final_scheduler_name = 'karras'
-
-            def patch_edm(unet):
-                return core.opModelSamplingContinuousEDM.patch(
-                    unet,
-                    sampling=scheduler_name,
-                    sigma_max=120.0,
-                    sigma_min=0.002)[0]
-
-            if pipeline.final_unet is not None:
-                pipeline.final_unet = patch_edm(pipeline.final_unet)
-            if pipeline.final_refiner_unet is not None:
-                pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)
-
-            logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
-
-        outputs.append(['preview', (13, 'Moving model to GPU ...', None)])
-
-        def callback(step, x0, x, total_steps, y):
-            """callback, used for progress and preview"""
-            done_steps = current_task_id * steps + step
-            outputs.append(['preview', (
-                int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
-                f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
-                y)])
-
-        for current_task_id, task in enumerate(tasks):
-            execution_start_time = time.perf_counter()
-
-            try:
-                positive_cond, negative_cond = task['c'], task['uc']
-
-                if 'cn' in goals:
-                    for cn_flag, cn_path in [
-                        (flags.cn_canny, controlnet_canny_path),
-                        (flags.cn_cpds, controlnet_cpds_path)
-                    ]:
-                        for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
-                            positive_cond, negative_cond = core.apply_controlnet(
-                                positive_cond, negative_cond,
-                                pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
-
-                imgs = pipeline.process_diffusion(
-                    positive_cond=positive_cond,
-                    negative_cond=negative_cond,
-                    steps=steps,
-                    switch=switch,
-                    width=width,
-                    height=height,
-                    image_seed=task['task_seed'],
-                    callback=callback,
-                    sampler_name=final_sampler_name,
-                    scheduler_name=final_scheduler_name,
-                    latent=initial_latent,
-                    denoise=denoising_strength,
-                    tiled=tiled,
-                    cfg_scale=cfg_scale,
-                    refiner_swap_method=refiner_swap_method,
-                    disable_preview=disable_preview
-                )
-
-                del task['c'], task['uc'], positive_cond, negative_cond  # Save memory
-
-                if inpaint_worker.current_task is not None:
-                    imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
-
-                # Fooocus async_worker.py code end
-
-                results += imgs
-            except model_management.InterruptProcessingException as e:
-                logger.std_warn("[Fooocus] User stopped")
-                results = []
-                results.append(ImageGenerationResult(
-                    im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.user_cancel))
-                async_task.set_result(results, True, str(e))
-                break
-            except Exception as e:
-                logger.std_error(f'[Fooocus] Process error: {e}')
-                logging.exception(e)
-                results = []
-                results.append(ImageGenerationResult(
-                    im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.error))
-                async_task.set_result(results, True, str(e))
-                break
-
-            execution_time = time.perf_counter() - execution_start_time
-            logger.std_info(f'[Fooocus] Generating and saving time: {execution_time:.2f} seconds')
-
-        if async_task.finish_with_error:
-            worker_queue.finish_task(async_task.job_id)
-            return async_task.task_result
-        yield_result(None, results, tasks, save_extension, black_out_nsfw)
-        return
-    except Exception as e:
-        logger.std_error(f'[Fooocus] Worker error: {e}')
-
-        if not async_task.is_finished:
-            async_task.set_result([], True, str(e))
-            worker_queue.finish_task(async_task.job_id)
-            logger.std_info(f"[Task Queue] Finish task with error, job_id={async_task.job_id}")

predict.py DELETED
@@ -1,316 +0,0 @@
- """
- Prediction interface for Cog ⚙️
- https://github.com/replicate/cog/blob/main/docs/python.md
- """
-
- import copy
- import os
- from typing import List
- import numpy as np
-
- from PIL import Image
- from cog import BasePredictor, BaseModel, Input, Path
- from fooocusapi.utils.lora_manager import LoraManager
- from fooocusapi.utils.file_utils import output_dir
- from fooocusapi.models.common.task import GenerationFinishReason
- from fooocusapi.configs.default import (
-     available_aspect_ratios,
-     uov_methods,
-     outpaint_expansions,
-     default_styles,
-     default_base_model_name,
-     default_refiner_model_name,
-     default_loras,
-     default_refiner_switch,
-     default_cfg_scale,
-     default_prompt_negative
- )
-
- from fooocusapi.parameters import ImageGenerationParams
- from fooocusapi.task_queue import TaskType
-
-
- class Output(BaseModel):
-     """
-     Output model
-     """
-     seeds: List[str]
-     paths: List[Path]
-
-
- class Predictor(BasePredictor):
-     """Predictor"""
-     def setup(self) -> None:
-         """
-         Load the model into memory to make running multiple predictions efficient
-         """
-         from main import pre_setup
-         pre_setup()
-
-     def predict(
-         self,
-         prompt: str = Input(
-             default='',
-             description="Prompt for image generation"),
-         negative_prompt: str = Input(
-             default=default_prompt_negative,
-             description="Negative prompt for image generation"),
-         style_selections: str = Input(
-             default=','.join(default_styles),
-             description="Fooocus styles applied for image generation, separated by comma"),
-         performance_selection: str = Input(
-             default='Speed',
-             choices=['Speed', 'Quality', 'Extreme Speed', 'Lightning'],
-             description="Performance selection"),
-         aspect_ratios_selection: str = Input(
-             default='1152*896',
-             choices=available_aspect_ratios,
-             description="The generated image's size"),
-         image_number: int = Input(
-             default=1,
-             ge=1, le=8,
-             description="How many images to generate"),
-         image_seed: int = Input(
-             default=-1,
-             description="Seed to generate image, -1 for random"),
-         use_default_loras: bool = Input(
-             default=True,
-             description="Use default LoRAs"),
-         loras_custom_urls: str = Input(
-             default="",
-             description="Custom LoRA URLs in the format 'url,weight'; provide multiple separated by ';' (example: 'url1,0.3;url2,0.1')"),
-         sharpness: float = Input(
-             default=2.0,
-             ge=0.0, le=30.0),
-         guidance_scale: float = Input(
-             default=default_cfg_scale,
-             ge=1.0, le=30.0),
-         refiner_switch: float = Input(
-             default=default_refiner_switch,
-             ge=0.1, le=1.0),
-         uov_input_image: Path = Input(
-             default=None,
-             description="Input image for upscale or variation; keep None to skip upscale or variation"),
-         uov_method: str = Input(
-             default='Disabled',
-             choices=uov_methods),
-         uov_upscale_value: float = Input(
-             default=0,
-             description="Only used with 'Upscale (Custom)'"),
-         inpaint_additional_prompt: str = Input(
-             default='',
-             description="Additional prompt for inpainting"),
-         inpaint_input_image: Path = Input(
-             default=None,
-             description="Input image for inpaint or outpaint; keep None to skip inpaint or outpaint. Note: `uov_input_image` takes priority when it is not None."),
-         inpaint_input_mask: Path = Input(
-             default=None,
-             description="Input mask for inpaint"),
-         outpaint_selections: str = Input(
-             default='',
-             description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
-         outpaint_distance_left: int = Input(
-             default=0,
-             description="Outpaint expansion distance from Left of the image"),
-         outpaint_distance_top: int = Input(
-             default=0,
-             description="Outpaint expansion distance from Top of the image"),
-         outpaint_distance_right: int = Input(
-             default=0,
-             description="Outpaint expansion distance from Right of the image"),
-         outpaint_distance_bottom: int = Input(
-             default=0,
-             description="Outpaint expansion distance from Bottom of the image"),
-         cn_img1: Path = Input(
-             default=None,
-             description="Input image for image prompt. If all cn_img[n] are None, image prompt will not be applied."),
-         cn_stop1: float = Input(
-             default=None,
-             ge=0, le=1,
-             description="Stop-at value for image prompt, None for default value"),
-         cn_weight1: float = Input(
-             default=None,
-             ge=0, le=2,
-             description="Weight for image prompt, None for default value"),
-         cn_type1: str = Input(
-             default='ImagePrompt',
-             choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
-             description="ControlNet type for image prompt"),
-         cn_img2: Path = Input(
-             default=None,
-             description="Input image for image prompt. If all cn_img[n] are None, image prompt will not be applied."),
-         cn_stop2: float = Input(
-             default=None,
-             ge=0, le=1,
-             description="Stop-at value for image prompt, None for default value"),
-         cn_weight2: float = Input(
-             default=None,
-             ge=0, le=2,
-             description="Weight for image prompt, None for default value"),
-         cn_type2: str = Input(
-             default='ImagePrompt',
-             choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
-             description="ControlNet type for image prompt"),
-         cn_img3: Path = Input(
-             default=None,
-             description="Input image for image prompt. If all cn_img[n] are None, image prompt will not be applied."),
-         cn_stop3: float = Input(
-             default=None,
-             ge=0, le=1,
-             description="Stop-at value for image prompt, None for default value"),
-         cn_weight3: float = Input(
-             default=None,
-             ge=0, le=2,
-             description="Weight for image prompt, None for default value"),
-         cn_type3: str = Input(
-             default='ImagePrompt',
-             choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
-             description="ControlNet type for image prompt"),
-         cn_img4: Path = Input(
-             default=None,
-             description="Input image for image prompt. If all cn_img[n] are None, image prompt will not be applied."),
-         cn_stop4: float = Input(
-             default=None,
-             ge=0, le=1,
-             description="Stop-at value for image prompt, None for default value"),
-         cn_weight4: float = Input(
-             default=None,
-             ge=0, le=2,
-             description="Weight for image prompt, None for default value"),
-         cn_type4: str = Input(
-             default='ImagePrompt',
-             choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
-             description="ControlNet type for image prompt")
-     ) -> Output:
-         """Run a single prediction on the model"""
-         from modules import flags
-         from modules.sdxl_styles import legal_style_names
-         from fooocusapi.worker import blocking_get_task_result, worker_queue
-
-         base_model_name = default_base_model_name
-         refiner_model_name = default_refiner_model_name
-
-         lora_manager = LoraManager()
-
-         # Use default loras if selected
-         loras = copy.copy(default_loras) if use_default_loras else []
-
-         # add custom user loras if provided
-         if loras_custom_urls:
-             urls = [url.strip() for url in loras_custom_urls.split(';')]
-
-             loras_with_weights = [url.split(',') for url in urls]
-
-             custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
-             custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
-                             zip(custom_lora_paths, loras_with_weights)]
-
-             loras.extend(custom_loras)
-
-         style_selections_arr = []
-         for s in style_selections.strip().split(','):
-             style = s.strip()
-             if style in legal_style_names:
-                 style_selections_arr.append(style)
-
-         if uov_input_image is not None:
-             im = Image.open(str(uov_input_image))
-             uov_input_image = np.array(im)
-
-         inpaint_input_image_dict = None
-         if inpaint_input_image is not None:
-             im = Image.open(str(inpaint_input_image))
-             inpaint_input_image = np.array(im)
-
-             if inpaint_input_mask is not None:
-                 im = Image.open(str(inpaint_input_mask))
-                 inpaint_input_mask = np.array(im)
-
-             inpaint_input_image_dict = {
-                 'image': inpaint_input_image,
-                 'mask': inpaint_input_mask
-             }
-
-         outpaint_selections_arr = []
-         for e in outpaint_selections.strip().split(','):
-             expansion = e.strip()
-             if expansion in outpaint_expansions:
-                 outpaint_selections_arr.append(expansion)
-
-         image_prompts = []
-         image_prompt_config = [
-             (cn_img1, cn_stop1, cn_weight1, cn_type1),
-             (cn_img2, cn_stop2, cn_weight2, cn_type2),
-             (cn_img3, cn_stop3, cn_weight3, cn_type3),
-             (cn_img4, cn_stop4, cn_weight4, cn_type4)]
-         for config in image_prompt_config:
-             cn_img, cn_stop, cn_weight, cn_type = config
-             if cn_img is not None:
-                 im = Image.open(str(cn_img))
-                 cn_img = np.array(im)
-                 if cn_stop is None:
-                     cn_stop = flags.default_parameters[cn_type][0]
-                 if cn_weight is None:
-                     cn_weight = flags.default_parameters[cn_type][1]
-                 image_prompts.append((cn_img, cn_stop, cn_weight, cn_type))
-
-         advanced_params = None
-
-         params = ImageGenerationParams(
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             style_selections=style_selections_arr,
-             performance_selection=performance_selection,
-             aspect_ratios_selection=aspect_ratios_selection,
-             image_number=image_number,
-             image_seed=image_seed,
-             sharpness=sharpness,
-             guidance_scale=guidance_scale,
-             base_model_name=base_model_name,
-             refiner_model_name=refiner_model_name,
-             refiner_switch=refiner_switch,
-             loras=loras,
-             uov_input_image=uov_input_image,
-             uov_method=uov_method,
-             upscale_value=uov_upscale_value,
-             outpaint_selections=outpaint_selections_arr,
-             inpaint_input_image=inpaint_input_image_dict,
-             image_prompts=image_prompts,
-             advanced_params=advanced_params,
-             inpaint_additional_prompt=inpaint_additional_prompt,
-             outpaint_distance_left=outpaint_distance_left,
-             outpaint_distance_top=outpaint_distance_top,
-             outpaint_distance_right=outpaint_distance_right,
-             outpaint_distance_bottom=outpaint_distance_bottom,
-             save_meta=True,
-             meta_scheme='fooocus',
-             save_extension='png',
-             save_name='',
-             require_base64=False,
-         )
-
-         print(f"[Predictor Predict] Params: {params.__dict__}")
-
-         async_task = worker_queue.add_task(
-             TaskType.text_2_img,
-             params)
-
-         if async_task is None:
-             print("[Task Queue] The task queue has reached limit")
-             raise Exception("The task queue has reached limit.")
-
-         results = blocking_get_task_result(async_task.job_id)
-
-         output_paths: List[Path] = []
-         output_seeds: List[str] = []
-         for r in results:
-             if r.finish_reason == GenerationFinishReason.success and r.im is not None:
-                 output_seeds.append(r.seed)
-                 output_paths.append(Path(os.path.join(output_dir, r.im)))
-
-         print(f"[Predictor Predict] Finished with {len(output_paths)} images")
-
-         if len(output_paths) == 0:
-             raise Exception("Process failed.")
-
-         return Output(seeds=output_seeds, paths=output_paths)

repositories/Fooocus/__init__.py DELETED
@@ -1,4 +0,0 @@
- """
- Created By: ishwor subedi
- Date: 2024-07-19
- """

repositories/Fooocus/args_manager.py DELETED
@@ -1,55 +0,0 @@
- import ldm_patched.modules.args_parser as args_parser
-
- args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
-
- args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
- args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
-                                 help="Disables preset selection in Gradio.")
-
- args_parser.parser.add_argument("--language", type=str, default='default',
-                                 help="Translate UI using json files in [language] folder. "
-                                      "For example, [--language example] will use [language/example.json] for translation.")
-
- # For example, https://github.com/lllyasviel/Fooocus/issues/849
- args_parser.parser.add_argument("--disable-offload-from-vram", action="store_true",
-                                 help="Force loading models to vram when the unload can be avoided. "
-                                      "Some Mac users may need this.")
-
- args_parser.parser.add_argument("--theme", type=str, help="Launches the UI with light or dark theme", default=None)
- args_parser.parser.add_argument("--disable-image-log", action='store_true',
-                                 help="Prevent writing images and logs to hard drive.")
-
- args_parser.parser.add_argument("--disable-analytics", action='store_true',
-                                 help="Disables analytics for Gradio.")
-
- args_parser.parser.add_argument("--disable-metadata", action='store_true',
-                                 help="Disables saving metadata to images.")
-
- args_parser.parser.add_argument("--disable-preset-download", action='store_true',
-                                 help="Disables downloading models for presets", default=False)
-
- args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
-                                 help="Enables automatic description of uov images when prompt is empty", default=False)
-
- args_parser.parser.add_argument("--always-download-new-model", action='store_true',
-                                 help="Always download newer models", default=False)
-
- args_parser.parser.set_defaults(
-     disable_cuda_malloc=True,
-     in_browser=True,
-     port=None
- )
-
- args_parser.args = args_parser.parser.parse_args()
-
- # (Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
- args_parser.args.always_offload_from_vram = not args_parser.args.disable_offload_from_vram
-
- if args_parser.args.disable_analytics:
-     import os
-     os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
-
- if args_parser.args.disable_in_browser:
-     args_parser.args.in_browser = False
-
- args = args_parser.args

repositories/Fooocus/extras/BLIP/configs/bert_config.json DELETED
@@ -1,21 +0,0 @@
- {
-   "architectures": [
-     "BertModel"
-   ],
-   "attention_probs_dropout_prob": 0.1,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 512,
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "type_vocab_size": 2,
-   "vocab_size": 30522,
-   "encoder_width": 768,
-   "add_cross_attention": true
- }

repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml DELETED
@@ -1,33 +0,0 @@
- image_root: '/export/share/datasets/vision/coco/images/'
- ann_root: 'annotation'
- coco_gt_root: 'annotation/coco_gt'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
-
- # size of vit model; base or large
- vit: 'base'
- vit_grad_ckpt: False
- vit_ckpt_layer: 0
- batch_size: 32
- init_lr: 1e-5
-
- # vit: 'large'
- # vit_grad_ckpt: True
- # vit_ckpt_layer: 5
- # batch_size: 16
- # init_lr: 2e-6
-
- image_size: 384
-
- # generation configs
- max_length: 20
- min_length: 5
- num_beams: 3
- prompt: 'a picture of '
-
- # optimizer
- weight_decay: 0.05
- min_lr: 0
- max_epoch: 5
-
repositories/Fooocus/extras/BLIP/configs/med_config.json DELETED
@@ -1,21 +0,0 @@
- {
-   "architectures": [
-     "BertModel"
-   ],
-   "attention_probs_dropout_prob": 0.1,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 512,
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "type_vocab_size": 2,
-   "vocab_size": 30524,
-   "encoder_width": 768,
-   "add_cross_attention": true
- }

repositories/Fooocus/extras/BLIP/configs/nlvr.yaml DELETED
@@ -1,21 +0,0 @@
- image_root: '/export/share/datasets/vision/NLVR2/'
- ann_root: 'annotation'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
-
- # size of vit model; base or large
- vit: 'base'
- batch_size_train: 16
- batch_size_test: 64
- vit_grad_ckpt: False
- vit_ckpt_layer: 0
- max_epoch: 15
-
- image_size: 384
-
- # optimizer
- weight_decay: 0.05
- init_lr: 3e-5
- min_lr: 0
-
repositories/Fooocus/extras/BLIP/configs/nocaps.yaml DELETED
@@ -1,15 +0,0 @@
- image_root: '/export/share/datasets/vision/nocaps/'
- ann_root: 'annotation'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
-
- vit: 'base'
- batch_size: 32
-
- image_size: 384
-
- max_length: 20
- min_length: 5
- num_beams: 3
- prompt: 'a picture of '

repositories/Fooocus/extras/BLIP/configs/pretrain.yaml DELETED
@@ -1,27 +0,0 @@
- train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
-              '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
-             ]
- laion_path: ''
-
- # size of vit model; base or large
- vit: 'base'
- vit_grad_ckpt: False
- vit_ckpt_layer: 0
-
- image_size: 224
- batch_size: 75
-
- queue_size: 57600
- alpha: 0.4
-
- # optimizer
- weight_decay: 0.05
- init_lr: 3e-4
- min_lr: 1e-6
- warmup_lr: 1e-6
- lr_decay_rate: 0.9
- max_epoch: 20
- warmup_steps: 3000
-
repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml DELETED
@@ -1,34 +0,0 @@
- image_root: '/export/share/datasets/vision/coco/images/'
- ann_root: 'annotation'
- dataset: 'coco'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
-
- # size of vit model; base or large
-
- vit: 'base'
- batch_size_train: 32
- batch_size_test: 64
- vit_grad_ckpt: True
- vit_ckpt_layer: 4
- init_lr: 1e-5
-
- # vit: 'large'
- # batch_size_train: 16
- # batch_size_test: 32
- # vit_grad_ckpt: True
- # vit_ckpt_layer: 12
- # init_lr: 5e-6
-
- image_size: 384
- queue_size: 57600
- alpha: 0.4
- k_test: 256
- negative_all_rank: True
-
- # optimizer
- weight_decay: 0.05
- min_lr: 0
- max_epoch: 6
-
repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml DELETED
@@ -1,34 +0,0 @@
- image_root: '/export/share/datasets/vision/flickr30k/'
- ann_root: 'annotation'
- dataset: 'flickr'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
-
- # size of vit model; base or large
-
- vit: 'base'
- batch_size_train: 32
- batch_size_test: 64
- vit_grad_ckpt: True
- vit_ckpt_layer: 4
- init_lr: 1e-5
-
- # vit: 'large'
- # batch_size_train: 16
- # batch_size_test: 32
- # vit_grad_ckpt: True
- # vit_ckpt_layer: 10
- # init_lr: 5e-6
-
- image_size: 384
- queue_size: 57600
- alpha: 0.4
- k_test: 128
- negative_all_rank: False
-
- # optimizer
- weight_decay: 0.05
- min_lr: 0
- max_epoch: 6
-
repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml DELETED
@@ -1,12 +0,0 @@
- video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
- ann_root: 'annotation'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
-
- # size of vit model; base or large
- vit: 'base'
- batch_size: 64
- k_test: 128
- image_size: 384
- num_frm_test: 8

repositories/Fooocus/extras/BLIP/configs/vqa.yaml DELETED
@@ -1,25 +0,0 @@
- vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/'  # followed by train2014/
- vg_root: '/export/share/datasets/vision/visual-genome/'  # followed by image/
- train_files: ['vqa_train','vqa_val','vg_qa']
- ann_root: 'annotation'
-
- # set pretrained as a file path or an url
- pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
-
- # size of vit model; base or large
- vit: 'base'
- batch_size_train: 16
- batch_size_test: 32
- vit_grad_ckpt: False
- vit_ckpt_layer: 0
- init_lr: 2e-5
-
- image_size: 480
-
- k_test: 128
- inference: 'rank'
-
- # optimizer
- weight_decay: 0.05
- min_lr: 0
- max_epoch: 10

repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json DELETED
@@ -1,23 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.1,
-   "gradient_checkpointing": false,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "intermediate_size": 3072,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 512,
-   "model_type": "bert",
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "transformers_version": "4.6.0.dev0",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 30522
- }

repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json DELETED
@@ -1,3 +0,0 @@
- {
-   "do_lower_case": true
- }

repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt DELETED
The diff for this file is too large to render. See raw diff
 
repositories/Fooocus/extras/BLIP/models/blip.py DELETED
@@ -1,239 +0,0 @@
- '''
-  * Copyright (c) 2022, salesforce.com, inc.
-  * All rights reserved.
-  * SPDX-License-Identifier: BSD-3-Clause
-  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-  * By Junnan Li
- '''
- import warnings
- warnings.filterwarnings("ignore")
-
- from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
- from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
- from transformers import BertTokenizer
-
- import torch
- from torch import nn
- import torch.nn.functional as F
-
- import os
- from urllib.parse import urlparse
- from timm.models.hub import download_cached_file
-
- class BLIP_Base(nn.Module):
-     def __init__(self,
-                  med_config = 'configs/med_config.json',
-                  image_size = 224,
-                  vit = 'base',
-                  vit_grad_ckpt = False,
-                  vit_ckpt_layer = 0,
-                  ):
-         """
-         Args:
-             med_config (str): path for the mixture of encoder-decoder model's configuration file
-             image_size (int): input image size
-             vit (str): model size of vision transformer
-         """
-         super().__init__()
-
-         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
-         self.tokenizer = init_tokenizer()
-         med_config = BertConfig.from_json_file(med_config)
-         med_config.encoder_width = vision_width
-         self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
-
-     def forward(self, image, caption, mode):
-
-         assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
-         text = self.tokenizer(caption, return_tensors="pt").to(image.device)
-
-         if mode=='image':
-             # return image features
-             image_embeds = self.visual_encoder(image)
-             return image_embeds
-
-         elif mode=='text':
-             # return text features
-             text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
-                                             return_dict = True, mode = 'text')
-             return text_output.last_hidden_state
-
-         elif mode=='multimodal':
-             # return multimodal features
-             image_embeds = self.visual_encoder(image)
-             image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-
-             text.input_ids[:,0] = self.tokenizer.enc_token_id
-             output = self.text_encoder(text.input_ids,
-                                        attention_mask = text.attention_mask,
-                                        encoder_hidden_states = image_embeds,
-                                        encoder_attention_mask = image_atts,
-                                        return_dict = True,
-                                        )
-             return output.last_hidden_state
-
-
- class BLIP_Decoder(nn.Module):
-     def __init__(self,
-                  med_config = 'configs/med_config.json',
-                  image_size = 384,
-                  vit = 'base',
-                  vit_grad_ckpt = False,
-                  vit_ckpt_layer = 0,
-                  prompt = 'a picture of ',
-                  ):
-         """
-         Args:
-             med_config (str): path for the mixture of encoder-decoder model's configuration file
-             image_size (int): input image size
-             vit (str): model size of vision transformer
-         """
-         super().__init__()
-
-         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
-         self.tokenizer = init_tokenizer()
-         med_config = BertConfig.from_json_file(med_config)
-         med_config.encoder_width = vision_width
-         self.text_decoder = BertLMHeadModel(config=med_config)
-
-         self.prompt = prompt
-         self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
-
-     def forward(self, image, caption):
-
-         image_embeds = self.visual_encoder(image)
-         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-
-         text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
-
-         text.input_ids[:,0] = self.tokenizer.bos_token_id
-
-         decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
-         decoder_targets[:,:self.prompt_length] = -100
-
-         decoder_output = self.text_decoder(text.input_ids,
-                                            attention_mask = text.attention_mask,
-                                            encoder_hidden_states = image_embeds,
-                                            encoder_attention_mask = image_atts,
-                                            labels = decoder_targets,
-                                            return_dict = True,
-                                            )
-         loss_lm = decoder_output.loss
-
-         return loss_lm
-
-     def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
-         image_embeds = self.visual_encoder(image)
-
-         if not sample:
-             image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
-
-         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-         model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
-
-         prompt = [self.prompt] * image.size(0)
-         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
-         input_ids[:,0] = self.tokenizer.bos_token_id
-         input_ids = input_ids[:, :-1]
-
-         if sample:
-             # nucleus sampling
-             outputs = self.text_decoder.generate(input_ids=input_ids,
-                                                  max_length=max_length,
-                                                  min_length=min_length,
-                                                  do_sample=True,
-                                                  top_p=top_p,
-                                                  num_return_sequences=1,
-                                                  eos_token_id=self.tokenizer.sep_token_id,
-                                                  pad_token_id=self.tokenizer.pad_token_id,
-                                                  repetition_penalty=1.1,
-                                                  **model_kwargs)
-         else:
-             # beam search
-             outputs = self.text_decoder.generate(input_ids=input_ids,
-                                                  max_length=max_length,
-                                                  min_length=min_length,
-                                                  num_beams=num_beams,
-                                                  eos_token_id=self.tokenizer.sep_token_id,
-                                                  pad_token_id=self.tokenizer.pad_token_id,
-                                                  repetition_penalty=repetition_penalty,
-                                                  **model_kwargs)
-
-         captions = []
-         for output in outputs:
-             caption = self.tokenizer.decode(output, skip_special_tokens=True)
-             captions.append(caption[len(self.prompt):])
-         return captions
-
-
- def blip_decoder(pretrained='',**kwargs):
-     model = BLIP_Decoder(**kwargs)
-     if pretrained:
-         model,msg = load_checkpoint(model,pretrained)
-         assert(len(msg.missing_keys)==0)
-     return model
-
- def blip_feature_extractor(pretrained='',**kwargs):
-     model = BLIP_Base(**kwargs)
-     if pretrained:
-         model,msg = load_checkpoint(model,pretrained)
-         assert(len(msg.missing_keys)==0)
-     return model
-
- def init_tokenizer():
-     tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bert_tokenizer")
-     tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
-     tokenizer.add_special_tokens({'bos_token':'[DEC]'})
-     tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
-     tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
-     return tokenizer
-
-
- def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
-
-     assert vit in ['base', 'large'], "vit parameter must be base or large"
-     if vit=='base':
-         vision_width = 768
-         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
-                                            num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
-                                            drop_path_rate=0 or drop_path_rate
-                                            )
-     elif vit=='large':
-         vision_width = 1024
-         visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
-                                            num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
-                                            drop_path_rate=0.1 or drop_path_rate
-                                            )
-     return visual_encoder, vision_width
-
- def is_url(url_or_filename):
-     parsed = urlparse(url_or_filename)
-     return parsed.scheme in ("http", "https")
-
- def load_checkpoint(model,url_or_filename):
-     if is_url(url_or_filename):
-         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-         checkpoint = torch.load(cached_file, map_location='cpu')
-     elif os.path.isfile(url_or_filename):
-         checkpoint = torch.load(url_or_filename, map_location='cpu')
-     else:
-         raise RuntimeError('checkpoint url or path is invalid')
-
-     state_dict = checkpoint['model']
-
-     state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
-     if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
-         state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
-                                                                          model.visual_encoder_m)
-     for key in model.state_dict().keys():
-         if key in state_dict.keys():
-             if state_dict[key].shape!=model.state_dict()[key].shape:
-                 del state_dict[key]
-
-     msg = model.load_state_dict(state_dict,strict=False)
-     print('load checkpoint from %s'%url_or_filename)
-     return model,msg

repositories/Fooocus/extras/BLIP/models/blip_itm.py DELETED
@@ -1,76 +0,0 @@
- from extras.BLIP.models.med import BertConfig, BertModel
- from transformers import BertTokenizer
-
- import torch
- from torch import nn
- import torch.nn.functional as F
-
- from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
-
- class BLIP_ITM(nn.Module):
-     def __init__(self,
-                  med_config = 'configs/med_config.json',
-                  image_size = 384,
-                  vit = 'base',
-                  vit_grad_ckpt = False,
-                  vit_ckpt_layer = 0,
-                  embed_dim = 256,
-                  ):
-         """
-         Args:
-             med_config (str): path for the mixture of encoder-decoder model's configuration file
-             image_size (int): input image size
-             vit (str): model size of vision transformer
-         """
-         super().__init__()
-
-         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
-         self.tokenizer = init_tokenizer()
-         med_config = BertConfig.from_json_file(med_config)
-         med_config.encoder_width = vision_width
-         self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
-
-         text_width = self.text_encoder.config.hidden_size
-
-         self.vision_proj = nn.Linear(vision_width, embed_dim)
-         self.text_proj = nn.Linear(text_width, embed_dim)
-
-         self.itm_head = nn.Linear(text_width, 2)
-
-     def forward(self, image, caption, match_head='itm'):
-
-         image_embeds = self.visual_encoder(image)
-         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-
-         text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
-                               return_tensors="pt").to(image.device)
-
-         if match_head=='itm':
-             output = self.text_encoder(text.input_ids,
-                                        attention_mask = text.attention_mask,
-                                        encoder_hidden_states = image_embeds,
-                                        encoder_attention_mask = image_atts,
-                                        return_dict = True,
-                                        )
-             itm_output = self.itm_head(output.last_hidden_state[:,0,:])
-             return itm_output
-
-         elif match_head=='itc':
-             text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
-                                             return_dict = True, mode = 'text')
-             image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
-             text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
-
-             sim = image_feat @ text_feat.t()
-             return sim
-
-
- def blip_itm(pretrained='',**kwargs):
-     model = BLIP_ITM(**kwargs)
-     if pretrained:
-         model,msg = load_checkpoint(model,pretrained)
-         assert(len(msg.missing_keys)==0)
-     return model

repositories/Fooocus/extras/BLIP/models/blip_nlvr.py DELETED
@@ -1,105 +0,0 @@
- from extras.BLIP.models.med import BertConfig
- from extras.BLIP.models.nlvr_encoder import BertModel
- from extras.BLIP.models.vit import interpolate_pos_embed
- from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url
-
- from timm.models.hub import download_cached_file
-
- import torch
- from torch import nn
- import torch.nn.functional as F
- from transformers import BertTokenizer
- import numpy as np
- import os
-
-
- class BLIP_NLVR(nn.Module):
-     def __init__(self,
-                  med_config = 'configs/med_config.json',
-                  image_size = 480,
-                  vit = 'base',
-                  vit_grad_ckpt = False,
-                  vit_ckpt_layer = 0,
-                  ):
-         """
-         Args:
-             med_config (str): path for the mixture of encoder-decoder model's configuration file
-             image_size (int): input image size
-             vit (str): model size of vision transformer
-         """
-         super().__init__()
-
-         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
-         self.tokenizer = init_tokenizer()
-         med_config = BertConfig.from_json_file(med_config)
-         med_config.encoder_width = vision_width
-         self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
-
-         self.cls_head = nn.Sequential(
-             nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
-             nn.ReLU(),
-             nn.Linear(self.text_encoder.config.hidden_size, 2)
-         )
-
-     def forward(self, image, text, targets, train=True):
-
-         image_embeds = self.visual_encoder(image)
-         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
-         image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
-
-         text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
-         text.input_ids[:,0] = self.tokenizer.enc_token_id
-
-         output = self.text_encoder(text.input_ids,
-                                    attention_mask = text.attention_mask,
-                                    encoder_hidden_states = [image0_embeds,image1_embeds],
-                                    encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
-                                                              image_atts[image0_embeds.size(0):]],
-                                    return_dict = True,
-                                    )
-         hidden_state = output.last_hidden_state[:,0,:]
-         prediction = self.cls_head(hidden_state)
-
-         if train:
-             loss = F.cross_entropy(prediction, targets)
-             return loss
-         else:
-             return prediction
-
- def blip_nlvr(pretrained='',**kwargs):
-     model = BLIP_NLVR(**kwargs)
-     if pretrained:
-         model,msg = load_checkpoint(model,pretrained)
-         print("missing keys:")
-         print(msg.missing_keys)
-     return model
-
-
- def load_checkpoint(model,url_or_filename):
-     if is_url(url_or_filename):
-         cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-         checkpoint = torch.load(cached_file, map_location='cpu')
-     elif os.path.isfile(url_or_filename):
-         checkpoint = torch.load(url_or_filename, map_location='cpu')
-     else:
-         raise RuntimeError('checkpoint url or path is invalid')
-     state_dict = checkpoint['model']
-
-     state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
-
-     for key in list(state_dict.keys()):
-         if 'crossattention.self.' in key:
-             new_key0 = key.replace('self','self0')
-             new_key1 = key.replace('self','self1')
-             state_dict[new_key0] = state_dict[key]
-             state_dict[new_key1] = state_dict[key]
-         elif 'crossattention.output.dense.' in key:
-             new_key0 = key.replace('dense','dense0')
-             new_key1 = key.replace('dense','dense1')
-             state_dict[new_key0] = state_dict[key]
-             state_dict[new_key1] = state_dict[key]
-
-     msg = model.load_state_dict(state_dict,strict=False)
-     print('load checkpoint from %s'%url_or_filename)
-     return model,msg

repositories/Fooocus/extras/BLIP/models/blip_pretrain.py DELETED
@@ -1,339 +0,0 @@
1
- '''
2
- * Copyright (c) 2022, salesforce.com, inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: BSD-3-Clause
5
- * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
- * By Junnan Li
7
- '''
8
- from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
9
- from transformers import BertTokenizer
10
- import transformers
11
- transformers.logging.set_verbosity_error()
12
-
13
- import torch
14
- from torch import nn
15
- import torch.nn.functional as F
16
-
17
- from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
18
-
19
- class BLIP_Pretrain(nn.Module):
20
- def __init__(self,
21
- med_config = 'configs/bert_config.json',
22
- image_size = 224,
23
- vit = 'base',
24
- vit_grad_ckpt = False,
25
- vit_ckpt_layer = 0,
26
- embed_dim = 256,
27
- queue_size = 57600,
28
- momentum = 0.995,
29
- ):
30
- """
31
- Args:
32
- med_config (str): path for the mixture of encoder-decoder model's configuration file
33
- image_size (int): input image size
34
- vit (str): model size of vision transformer
35
- """
36
- super().__init__()
37
-
38
- self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
39
-
40
- if vit=='base':
41
- checkpoint = torch.hub.load_state_dict_from_url(
42
- url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
43
- map_location="cpu", check_hash=True)
44
- state_dict = checkpoint["model"]
45
- msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
46
- elif vit=='large':
47
- from timm.models.helpers import load_custom_pretrained
48
- from timm.models.vision_transformer import default_cfgs
49
- load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
50
-
51
- self.tokenizer = init_tokenizer()
52
- encoder_config = BertConfig.from_json_file(med_config)
53
- encoder_config.encoder_width = vision_width
54
- self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
55
- self.text_encoder.resize_token_embeddings(len(self.tokenizer))
56
-
57
- text_width = self.text_encoder.config.hidden_size
58
-
59
- self.vision_proj = nn.Linear(vision_width, embed_dim)
60
- self.text_proj = nn.Linear(text_width, embed_dim)
61
-
62
- self.itm_head = nn.Linear(text_width, 2)
63
-
64
- # create momentum encoders
65
- self.visual_encoder_m, vision_width = create_vit(vit,image_size)
66
- self.vision_proj_m = nn.Linear(vision_width, embed_dim)
67
- self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
68
- self.text_proj_m = nn.Linear(text_width, embed_dim)
69
-
70
- self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
71
- [self.vision_proj,self.vision_proj_m],
72
- [self.text_encoder,self.text_encoder_m],
73
- [self.text_proj,self.text_proj_m],
74
- ]
75
- self.copy_params()
76
-
77
- # create the queue
78
- self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
79
- self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
80
- self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
81
-
82
- self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
83
- self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
84
-
85
- self.queue_size = queue_size
86
- self.momentum = momentum
87
- self.temp = nn.Parameter(0.07*torch.ones([]))
88
-
89
- # create the decoder
90
- decoder_config = BertConfig.from_json_file(med_config)
91
- decoder_config.encoder_width = vision_width
92
- self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
93
- self.text_decoder.resize_token_embeddings(len(self.tokenizer))
94
- tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
95
-
96
-
97
- def forward(self, image, caption, alpha):
98
- with torch.no_grad():
99
- self.temp.clamp_(0.001,0.5)
100
-
101
- image_embeds = self.visual_encoder(image)
102
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
103
- image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
104
-
105
- text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
106
- return_tensors="pt").to(image.device)
107
- text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
108
- return_dict = True, mode = 'text')
109
- text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
110
-
111
- # get momentum features
112
- with torch.no_grad():
113
- self._momentum_update()
114
- image_embeds_m = self.visual_encoder_m(image)
115
- image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
116
- image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
117
-
118
- text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
119
- return_dict = True, mode = 'text')
120
- text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
121
- text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
122
-
123
- sim_i2t_m = image_feat_m @ text_feat_all / self.temp
124
- sim_t2i_m = text_feat_m @ image_feat_all / self.temp
125
-
126
- sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
127
- sim_targets.fill_diagonal_(1)
128
-
129
- sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
130
- sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
131
-
132
- sim_i2t = image_feat @ text_feat_all / self.temp
133
- sim_t2i = text_feat @ image_feat_all / self.temp
134
-
135
- loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
136
- loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
137
-
138
- loss_ita = (loss_i2t+loss_t2i)/2
139
-
140
- self._dequeue_and_enqueue(image_feat_m, text_feat_m)
141
-
142
- ###============== Image-text Matching ===================###
143
- encoder_input_ids = text.input_ids.clone()
144
- encoder_input_ids[:,0] = self.tokenizer.enc_token_id
145
-
146
- # forward the positve image-text pair
147
- bs = image.size(0)
148
- output_pos = self.text_encoder(encoder_input_ids,
149
- attention_mask = text.attention_mask,
150
- encoder_hidden_states = image_embeds,
151
- encoder_attention_mask = image_atts,
152
- return_dict = True,
153
- )
154
- with torch.no_grad():
155
- weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
156
- weights_t2i.fill_diagonal_(0)
157
-        weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
-        weights_i2t.fill_diagonal_(0)
-
-        # select a negative image for each text
-        image_embeds_neg = []
-        for b in range(bs):
-            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
-            image_embeds_neg.append(image_embeds[neg_idx])
-        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
-
-        # select a negative text for each image
-        text_ids_neg = []
-        text_atts_neg = []
-        for b in range(bs):
-            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
-            text_ids_neg.append(encoder_input_ids[neg_idx])
-            text_atts_neg.append(text.attention_mask[neg_idx])
-
-        text_ids_neg = torch.stack(text_ids_neg,dim=0)
-        text_atts_neg = torch.stack(text_atts_neg,dim=0)
-
-        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
-        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
-
-        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
-        image_atts_all = torch.cat([image_atts,image_atts],dim=0)
-
-        output_neg = self.text_encoder(text_ids_all,
-                                       attention_mask = text_atts_all,
-                                       encoder_hidden_states = image_embeds_all,
-                                       encoder_attention_mask = image_atts_all,
-                                       return_dict = True,
-                                      )
-
-        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
-        vl_output = self.itm_head(vl_embeddings)
-
-        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
-                               dim=0).to(image.device)
-        loss_itm = F.cross_entropy(vl_output, itm_labels)
-
-        ##================= LM ========================##
-        decoder_input_ids = text.input_ids.clone()
-        decoder_input_ids[:,0] = self.tokenizer.bos_token_id
-        decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
-
-        decoder_output = self.text_decoder(decoder_input_ids,
-                                           attention_mask = text.attention_mask,
-                                           encoder_hidden_states = image_embeds,
-                                           encoder_attention_mask = image_atts,
-                                           labels = decoder_targets,
-                                           return_dict = True,
-                                          )
-
-        loss_lm = decoder_output.loss
-        return loss_ita, loss_itm, loss_lm
-
-
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not updated by gradient
-
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
-                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
-
-
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.queue_ptr)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-
-        self.queue_ptr[0] = ptr
-
-
-def blip_pretrain(**kwargs):
-    model = BLIP_Pretrain(**kwargs)
-    return model
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs an all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [torch.ones_like(tensor)
-                      for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-from typing import List
-def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str):
-    uninitialized_encoder_weights: List[str] = []
-    if decoder.__class__ != encoder.__class__:
-        print(
-            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
-        )
-
-    def tie_encoder_to_decoder_recursively(
-        decoder_pointer: nn.Module,
-        encoder_pointer: nn.Module,
-        module_name: str,
-        uninitialized_encoder_weights: List[str],
-        skip_key: str,
-        depth=0,
-    ):
-        assert isinstance(decoder_pointer, nn.Module) and isinstance(
-            encoder_pointer, nn.Module
-        ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
-        if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
-            assert hasattr(encoder_pointer, "weight")
-            encoder_pointer.weight = decoder_pointer.weight
-            if hasattr(decoder_pointer, "bias"):
-                assert hasattr(encoder_pointer, "bias")
-                encoder_pointer.bias = decoder_pointer.bias
-            print(module_name + ' is tied')
-            return
-
-        encoder_modules = encoder_pointer._modules
-        decoder_modules = decoder_pointer._modules
-        if len(decoder_modules) > 0:
-            assert (
-                len(encoder_modules) > 0
-            ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
-
-            all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
-            encoder_layer_pos = 0
-            for name, module in decoder_modules.items():
-                if name.isdigit():
-                    encoder_name = str(int(name) + encoder_layer_pos)
-                    decoder_name = name
-                    if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
-                        encoder_modules
-                    ) != len(decoder_modules):
-                        # this can happen if the name corresponds to the position in a module list of layers;
-                        # here the decoder has added a cross-attention layer that the encoder does not have,
-                        # so skip this step and subtract one layer position from the encoder
                        encoder_layer_pos -= 1
-                        continue
-                elif name not in encoder_modules:
-                    continue
-                elif depth > 500:
-                    raise ValueError(
-                        "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
-                    )
-                else:
-                    decoder_name = encoder_name = name
-                tie_encoder_to_decoder_recursively(
-                    decoder_modules[decoder_name],
-                    encoder_modules[encoder_name],
-                    module_name + "/" + name,
-                    uninitialized_encoder_weights,
-                    skip_key,
-                    depth=depth + 1,
-                )
-                all_encoder_weights.remove(module_name + "/" + encoder_name)
-
-            uninitialized_encoder_weights += list(all_encoder_weights)
-
-    # tie weights recursively
-    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
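The `copy_params` / `_momentum_update` / `_dequeue_and_enqueue` trio deleted above is the MoCo-style core of BLIP pretraining: each encoder keeps a frozen momentum copy that is updated as an exponential moving average of the online weights, and the momentum features are pushed into a fixed-size circular buffer that serves as a large pool of contrastive negatives. A minimal single-process sketch of that pattern follows; the names `MomentumQueue`, `online`, `target`, and `enqueue` are illustrative stand-ins, not identifiers from the deleted file.

import copy
import torch
import torch.nn.functional as F

class MomentumQueue:
    """Sketch of the EMA-encoder + circular feature queue pattern (illustrative)."""

    def __init__(self, online: torch.nn.Module, embed_dim: int = 256,
                 queue_size: int = 1024, momentum: float = 0.995):
        self.online = online
        self.target = copy.deepcopy(online)    # momentum ("_m") copy of the encoder
        for p in self.target.parameters():
            p.requires_grad = False            # updated by EMA, never by backprop
        self.momentum = momentum
        self.queue_size = queue_size
        self.queue = F.normalize(torch.randn(embed_dim, queue_size), dim=0)
        self.ptr = 0

    @torch.no_grad()
    def momentum_update(self):
        # target <- m * target + (1 - m) * online, as in _momentum_update above
        for p, p_m in zip(self.online.parameters(), self.target.parameters()):
            p_m.data.mul_(self.momentum).add_(p.data, alpha=1.0 - self.momentum)

    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor):
        # feats: (batch, embed_dim); like the original, the batch must divide queue_size
        bs = feats.shape[0]
        assert self.queue_size % bs == 0
        self.queue[:, self.ptr:self.ptr + bs] = feats.T
        self.ptr = (self.ptr + bs) % self.queue_size

Each training step calls `momentum_update()`, encodes the batch with the target copy, and enqueues those features so later batches can reuse them as extra negatives in the contrastive (ITA) loss; in the deleted file the same roles are played by `visual_encoder_m`, `image_queue`/`text_queue`, and `queue_ptr`.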
repositories/Fooocus/extras/BLIP/models/blip_retrieval.py DELETED
@@ -1,319 +0,0 @@
-from extras.BLIP.models.med import BertConfig, BertModel
-from transformers import BertTokenizer
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
-
-class BLIP_Retrieval(nn.Module):
-    def __init__(self,
-                 med_config = 'configs/med_config.json',
-                 image_size = 384,
-                 vit = 'base',
-                 vit_grad_ckpt = False,
-                 vit_ckpt_layer = 0,
-                 embed_dim = 256,
-                 queue_size = 57600,
-                 momentum = 0.995,
-                 negative_all_rank = False,
-                 ):
-        """
-        Args:
-            med_config (str): path to the configuration file of the mixture of encoder-decoder model
-            image_size (int): input image size
-            vit (str): model size of the vision transformer
-        """
-        super().__init__()
-
-        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
-        self.tokenizer = init_tokenizer()
-        med_config = BertConfig.from_json_file(med_config)
-        med_config.encoder_width = vision_width
-        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
-
-        text_width = self.text_encoder.config.hidden_size
-
-        self.vision_proj = nn.Linear(vision_width, embed_dim)
-        self.text_proj = nn.Linear(text_width, embed_dim)
-
-        self.itm_head = nn.Linear(text_width, 2)
-
-        # create momentum encoders
-        self.visual_encoder_m, vision_width = create_vit(vit, image_size)
-        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
-        self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
-        self.text_proj_m = nn.Linear(text_width, embed_dim)
-
-        self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
-                            [self.vision_proj, self.vision_proj_m],
-                            [self.text_encoder, self.text_encoder_m],
-                            [self.text_proj, self.text_proj_m],
-                           ]
-        self.copy_params()
-
-        # create the queue
-        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
-        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
-        self.register_buffer("idx_queue", torch.full((1, queue_size), -100))
-        self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
-
-        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
-        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
-
-        self.queue_size = queue_size
-        self.momentum = momentum
-        self.temp = nn.Parameter(0.07 * torch.ones([]))
-
-        self.negative_all_rank = negative_all_rank
-
-    def forward(self, image, caption, alpha, idx):
-        with torch.no_grad():
-            self.temp.clamp_(0.001, 0.5)
-
-        image_embeds = self.visual_encoder(image)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
-        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
-
-        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
-                              return_tensors="pt").to(image.device)
-
-        text_output = self.text_encoder(text.input_ids, attention_mask=text.attention_mask,
-                                        return_dict=True, mode='text')
-        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
-
-        ###============== Image-text Contrastive Learning ===================###
-        idx = idx.view(-1, 1)
-        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
-        pos_idx = torch.eq(idx, idx_all).float()
-        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
-
-        # get momentum features
-        with torch.no_grad():
-            self._momentum_update()
-            image_embeds_m = self.visual_encoder_m(image)
-            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
-            image_feat_m_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)
-
-            text_output_m = self.text_encoder_m(text.input_ids, attention_mask=text.attention_mask,
-                                                return_dict=True, mode='text')
-            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), dim=-1)
-            text_feat_m_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)
-
-            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
-            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
-
-            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
-            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
-
-        sim_i2t = image_feat @ text_feat_m_all / self.temp
-        sim_t2i = text_feat @ image_feat_m_all / self.temp
-
-        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
-        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
-
-        loss_ita = (loss_i2t + loss_t2i) / 2
-
-        idxs = concat_all_gather(idx)
-        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
-
-        ###============== Image-text Matching ===================###
-        encoder_input_ids = text.input_ids.clone()
-        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
-
-        # forward the positive image-text pair
-        bs = image.size(0)
-        output_pos = self.text_encoder(encoder_input_ids,
-                                       attention_mask=text.attention_mask,
-                                       encoder_hidden_states=image_embeds,
-                                       encoder_attention_mask=image_atts,
-                                       return_dict=True,
-                                      )
-
-        if self.negative_all_rank:
-            # compute sample similarity
-            with torch.no_grad():
-                mask = torch.eq(idx, idxs.t())
-
-                image_feat_world = concat_all_gather(image_feat)
-                text_feat_world = concat_all_gather(text_feat)
-
-                sim_i2t = image_feat @ text_feat_world.t() / self.temp
-                sim_t2i = text_feat @ image_feat_world.t() / self.temp
-
-                weights_i2t = F.softmax(sim_i2t, dim=1)
-                weights_i2t.masked_fill_(mask, 0)
-
-                weights_t2i = F.softmax(sim_t2i, dim=1)
-                weights_t2i.masked_fill_(mask, 0)
-
-            image_embeds_world = all_gather_with_grad(image_embeds)
-
-            # select a negative image (from all ranks) for each text
-            image_embeds_neg = []
-            for b in range(bs):
-                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
-                image_embeds_neg.append(image_embeds_world[neg_idx])
-            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
-
-            # select a negative text (from all ranks) for each image
-            input_ids_world = concat_all_gather(encoder_input_ids)
-            att_mask_world = concat_all_gather(text.attention_mask)
-
-            text_ids_neg = []
-            text_atts_neg = []
-            for b in range(bs):
-                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
-                text_ids_neg.append(input_ids_world[neg_idx])
-                text_atts_neg.append(att_mask_world[neg_idx])
-
-        else:
-            with torch.no_grad():
-                mask = torch.eq(idx, idx.t())
-
-                sim_i2t = image_feat @ text_feat.t() / self.temp
-                sim_t2i = text_feat @ image_feat.t() / self.temp
-
-                weights_i2t = F.softmax(sim_i2t, dim=1)
-                weights_i2t.masked_fill_(mask, 0)
-
-                weights_t2i = F.softmax(sim_t2i, dim=1)
-                weights_t2i.masked_fill_(mask, 0)
-
-            # select a negative image (from same rank) for each text
-            image_embeds_neg = []
-            for b in range(bs):
-                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
-                image_embeds_neg.append(image_embeds[neg_idx])
-            image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
-
-            # select a negative text (from same rank) for each image
-            text_ids_neg = []
-            text_atts_neg = []
-            for b in range(bs):
-                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
-                text_ids_neg.append(encoder_input_ids[neg_idx])
-                text_atts_neg.append(text.attention_mask[neg_idx])
-
-        text_ids_neg = torch.stack(text_ids_neg, dim=0)
-        text_atts_neg = torch.stack(text_atts_neg, dim=0)
-
-        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)
-        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)
-
-        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
-        image_atts_all = torch.cat([image_atts, image_atts], dim=0)
-
-        output_neg = self.text_encoder(text_ids_all,
-                                       attention_mask=text_atts_all,
-                                       encoder_hidden_states=image_embeds_all,
-                                       encoder_attention_mask=image_atts_all,
-                                       return_dict=True,
-                                      )
-
-        vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)
-        vl_output = self.itm_head(vl_embeddings)
-
-        itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
-                               dim=0).to(image.device)
-        loss_itm = F.cross_entropy(vl_output, itm_labels)
-
-        return loss_ita, loss_itm
-
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not updated by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
-                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
-
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.ptr_queue)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
-        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-
-        self.ptr_queue[0] = ptr
-
-
-def blip_retrieval(pretrained='', **kwargs):
-    model = BLIP_Retrieval(**kwargs)
-    if pretrained:
-        model, msg = load_checkpoint(model, pretrained)
-        print("missing keys:")
-        print(msg.missing_keys)
-    return model
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs an all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [torch.ones_like(tensor)
-                      for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-class GatherLayer(torch.autograd.Function):
-    """
-    Gathers tensors from all workers with support for backward propagation:
-    this implementation does not cut the gradients as torch.distributed.all_gather does.
-    """
-
-    @staticmethod
-    def forward(ctx, x):
-        output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(output, x)
-        return tuple(output)
-
-    @staticmethod
-    def backward(ctx, *grads):
-        all_gradients = torch.stack(grads)
-        torch.distributed.all_reduce(all_gradients)
-        return all_gradients[torch.distributed.get_rank()]
-
-
-def all_gather_with_grad(tensors):
-    """
-    Performs an all_gather operation on the provided tensors.
-    The graph remains connected for backward grad computation.
-    """
-    # no gathering is needed in the single-process case
-    world_size = torch.distributed.get_world_size()
-    if world_size == 1:
-        return tensors
-
-    tensor_all = GatherLayer.apply(tensors)
-
-    return torch.cat(tensor_all, dim=0)
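The `GatherLayer` / `all_gather_with_grad` pair that closes this deleted file exists because plain `torch.distributed.all_gather` returns tensors detached from the autograd graph, which would silently cut gradients when hard negatives are mined across ranks (`negative_all_rank=True`). Wrapping the gather in a custom `autograd.Function` restores the gradient path: `backward` all-reduces the incoming gradients and returns this rank's slice. Below is a small single-process sanity check of that contract, assuming the `gloo` backend is available and port 29501 is free; `GatherLayerSketch` mirrors the deleted class.

import os
import torch
import torch.distributed as dist

class GatherLayerSketch(torch.autograd.Function):
    """all_gather whose backward keeps the graph connected (mirrors GatherLayer)."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)           # sum gradient contributions from every rank
        return all_grads[dist.get_rank()]    # gradient for this rank's input slice

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    gathered = torch.cat(GatherLayerSketch.apply(x), dim=0)
    gathered.sum().backward()
    assert x.grad is not None    # gradient flows, unlike with plain dist.all_gather

    dist.destroy_process_group()

Note the `world_size == 1` shortcut in `all_gather_with_grad`: single-GPU runs skip the gather entirely, so this machinery only matters for multi-rank training.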