Spaces:
Sleeping
Sleeping
Commit
·
45ac234
1
Parent(s):
070b382
refactor: remove fooocus api
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- fooocus_api_version.py +0 -1
- fooocusapi/api.py +0 -41
- fooocusapi/args.py +0 -20
- fooocusapi/base_args.py +0 -27
- fooocusapi/configs/default.py +0 -92
- fooocusapi/models/common/base.py +0 -189
- fooocusapi/models/common/image_meta.py +0 -118
- fooocusapi/models/common/requests.py +0 -132
- fooocusapi/models/common/response.py +0 -90
- fooocusapi/models/common/task.py +0 -60
- fooocusapi/models/requests_v1.py +0 -274
- fooocusapi/models/requests_v2.py +0 -50
- fooocusapi/parameters.py +0 -94
- fooocusapi/routes/__init__.py +0 -0
- fooocusapi/routes/generate_v1.py +0 -186
- fooocusapi/routes/generate_v2.py +0 -199
- fooocusapi/routes/query.py +0 -135
- fooocusapi/sql_client.py +0 -269
- fooocusapi/task_queue.py +0 -323
- fooocusapi/utils/api_utils.py +0 -291
- fooocusapi/utils/call_worker.py +0 -97
- fooocusapi/utils/file_utils.py +0 -143
- fooocusapi/utils/img_utils.py +0 -198
- fooocusapi/utils/logger.py +0 -132
- fooocusapi/utils/lora_manager.py +0 -71
- fooocusapi/utils/model_loader.py +0 -46
- fooocusapi/utils/tools.py +0 -159
- fooocusapi/worker.py +0 -1044
- predict.py +0 -316
- repositories/Fooocus/__init__.py +0 -4
- repositories/Fooocus/args_manager.py +0 -55
- repositories/Fooocus/extras/BLIP/configs/bert_config.json +0 -21
- repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml +0 -33
- repositories/Fooocus/extras/BLIP/configs/med_config.json +0 -21
- repositories/Fooocus/extras/BLIP/configs/nlvr.yaml +0 -21
- repositories/Fooocus/extras/BLIP/configs/nocaps.yaml +0 -15
- repositories/Fooocus/extras/BLIP/configs/pretrain.yaml +0 -27
- repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml +0 -34
- repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml +0 -34
- repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml +0 -12
- repositories/Fooocus/extras/BLIP/configs/vqa.yaml +0 -25
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json +0 -23
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json +0 -0
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json +0 -3
- repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt +0 -0
- repositories/Fooocus/extras/BLIP/models/blip.py +0 -239
- repositories/Fooocus/extras/BLIP/models/blip_itm.py +0 -76
- repositories/Fooocus/extras/BLIP/models/blip_nlvr.py +0 -105
- repositories/Fooocus/extras/BLIP/models/blip_pretrain.py +0 -339
- repositories/Fooocus/extras/BLIP/models/blip_retrieval.py +0 -319
fooocus_api_version.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
version = '0.4.1.1'
|
|
|
|
|
|
fooocusapi/api.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Entry for startup fastapi server
|
| 3 |
-
"""
|
| 4 |
-
from fastapi import FastAPI
|
| 5 |
-
from fastapi.staticfiles import StaticFiles
|
| 6 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
-
|
| 8 |
-
import uvicorn
|
| 9 |
-
|
| 10 |
-
from fooocusapi.utils import file_utils
|
| 11 |
-
from fooocusapi.routes.generate_v1 import secure_router as generate_v1
|
| 12 |
-
from fooocusapi.routes.generate_v2 import secure_router as generate_v2
|
| 13 |
-
from fooocusapi.routes.query import secure_router as query
|
| 14 |
-
from mannequin_to_model import secure_router as mannequin_to_model
|
| 15 |
-
|
| 16 |
-
# Application instance that the middleware, static mount and routers below
# are attached to.
app = FastAPI()

# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True. FastAPI's CORS documentation advises against
# pairing wildcards with credentials -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow access from all sources
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all request headers
)

# Serve the contents of file_utils.output_dir under the "/files" URL prefix.
app.mount("/files", StaticFiles(directory=file_utils.output_dir), name="files")

# Register the routers imported at the top of the module.
app.include_router(query)
app.include_router(generate_v1)
app.include_router(generate_v2)
app.include_router(mannequin_to_model)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def start_app(args):
    """Start the FastAPI application.

    Args:
        args: parsed command-line namespace; ``base_url`` and ``log_level``
            are the only attributes read here.
    """
    # Uses the same "/files/" prefix as the static mount on ``app``.
    # NOTE(review): --base-url defaults to None in base_args.py, which would
    # make this concatenation fail -- presumably the caller fills it in
    # beforehand; confirm.
    file_utils.STATIC_SERVER_BASE = args.base_url + "/files/"
    # NOTE(review): host and port are fixed here; args.host / args.port are
    # not consulted.
    uvicorn.run(
        app="fooocusapi.api:app",
        host="0.0.0.0",
        port=8000,
        log_level=args.log_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/args.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Do not modify the import order
|
| 3 |
-
"""
|
| 4 |
-
from fooocusapi.base_args import add_base_args
import ldm_patched.modules.args_parser as args_parser

# Add Fooocus-API args to parser
# before_prepared=False: add_base_args only registers --port when it is True,
# so --port is not added by this call.
add_base_args(args_parser.parser, False)

# Apply Fooocus args
# This import rebinds the name ``args_parser`` to the object exported by
# args_manager, which is why the import order must not change.
from args_manager import args_parser

# Override the port default value
args_parser.parser.set_defaults(
    port=8888
)

# Execute args parse again
# The result is stored back on args_parser and also exposed as ``args``.
args_parser.args = args_parser.parser.parse_args()
args = args_parser.args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/base_args.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
base_args.py
|
| 3 |
-
"""
|
| 4 |
-
from argparse import ArgumentParser
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def add_base_args(parser: ArgumentParser, before_prepared: bool):
    """
    Register the Fooocus-API command-line options on ``parser``.

    Args:
        parser: ArgumentParser that receives the options
        before_prepared: when true, ``--port`` is registered as well
    Returns:
        None, the parser is modified in place
    """
    register = parser.add_argument

    # --port is only registered on the "before prepared" pass.
    if before_prepared:
        register("--port", type=int, default=8888, help="Set the listen port, default: 8888")

    # Network / server options.
    register("--host", type=str, default='127.0.0.1', help="Set the listen host, default: 127.0.0.1")
    register("--base-url", type=str, default=None, help="Set base url for outside visit, default is http://host:port")
    register("--log-level", type=str, default='info', help="Log info for Uvicorn, default: info")

    # Startup behaviour.
    register("--skip-pip", default=False, action="store_true", help="Skip automatic pip install when setup")
    register("--preload-pipeline", default=False, action="store_true", help="Preload pipeline before start http server")

    # Queue, history and notification options.
    register("--queue-size", type=int, default=100, help="Working queue size, default: 100, generation requests exceeding working queue size will return failure")
    register("--queue-history", type=int, default=0, help="Finished jobs reserve size, tasks exceeding the limit will be deleted, including output image files, default: 0, means no limit")
    register('--webhook-url', type=str, default=None, help='The URL to send a POST request when a job is finished')
    register('--persistent', default=False, action="store_true", help="Store history to db")
    register("--apikey", type=str, default=None, help="API key for authenticating requests")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/configs/default.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Static variables for Fooocus API
|
| 3 |
-
"""
|
| 4 |
-
# Example payloads for a "200" response, keyed by content type.
# NOTE(review): presumably passed as ``responses=`` to the generate routes
# for the OpenAPI docs -- the consumers are not visible here; confirm.
img_generate_responses = {
    "200": {
        "description": "PNG bytes if request's 'Accept' header is 'image/png', otherwise JSON",
        "content": {
            # Example for a JSON result list.
            "application/json": {
                "example": [{
                    "base64": "...very long string...",
                    "seed": "1050625087",
                    "finish_reason": "SUCCESS",
                }]
            },
            # Example for a job-info style result.
            "application/json async": {
                "example": {
                    "job_id": 1,
                    "job_type": "Text to Image"
                }
            },
            "image/png": {
                "example": "PNG bytes, what did you expect?"
            },
        },
    }
}
|
| 27 |
-
|
| 28 |
-
# Static default generation settings.
default_inpaint_engine_version = "v2.6"

default_styles = ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
default_base_model_name = "juggernautXL_v8Rundiffusion.safetensors"
# The literal string "None" (not the None object) stands for "no refiner".
default_refiner_model_name = "None"
default_refiner_switch = 0.5
# Each entry is [enabled, model file name, weight].
default_loras = [[True, "sd_xl_offset_example-lora_1.0.safetensors", 0.1]]
default_cfg_scale = 7.0
default_prompt_negative = ""
# Written as "<number>*<number>", the same format as available_aspect_ratios.
default_aspect_ratio = "1152*896"
default_sampler = "dpmpp_2m_sde_gpu"
default_scheduler = "karras"
|
| 40 |
-
|
| 41 |
-
# Aspect ratio strings in "<number>*<number>" form.
# NOTE(review): presumably width*height in pixels -- confirm against callers.
available_aspect_ratios = [
    "704*1408",
    "704*1344",
    "768*1344",
    "768*1280",
    "832*1216",
    "832*1152",
    "896*1152",
    "896*1088",
    "960*1088",
    "960*1024",
    "1024*1024",
    "1024*960",
    "1088*960",
    "1088*896",
    "1152*896",
    "1152*832",
    "1216*832",
    "1280*768",
    "1344*768",
    "1344*704",
    "1408*704",
    "1472*704",
    "1536*640",
    "1600*640",
    "1664*576",
    "1728*576",
]
|
| 69 |
-
|
| 70 |
-
# Upscale-or-vary method labels; "Disabled" is the first entry.
uov_methods = [
    "Disabled",
    "Vary (Subtle)",
    "Vary (Strong)",
    "Upscale (1.5x)",
    "Upscale (2x)",
    "Upscale (Fast 2x)",
    "Upscale (Custom)",
]

# Direction labels for outpainting.
outpaint_expansions = ["Left", "Right", "Top", "Bottom"]
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def get_aspect_ratio_value(label: str) -> str:
    """
    Reduce an aspect ratio label to its bare value.

    Args:
        label: str, aspect ratio label; everything after the first space
            is ignored

    Returns:
        The text before the first space with every "×" replaced by "*"
    """
    value, _, _ = label.partition(" ")
    return value.replace("×", "*")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/common/base.py
DELETED
|
@@ -1,189 +0,0 @@
|
|
| 1 |
-
"""Common models"""
|
| 2 |
-
from typing import List, Tuple
|
| 3 |
-
from enum import Enum
|
| 4 |
-
from fastapi import UploadFile
|
| 5 |
-
from fastapi.exceptions import RequestValidationError
|
| 6 |
-
from pydantic import (
|
| 7 |
-
ValidationError,
|
| 8 |
-
ConfigDict,
|
| 9 |
-
BaseModel,
|
| 10 |
-
TypeAdapter,
|
| 11 |
-
Field
|
| 12 |
-
)
|
| 13 |
-
from pydantic_core import InitErrorDetails
|
| 14 |
-
|
| 15 |
-
from fooocusapi.configs.default import default_loras
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class PerformanceSelection(str, Enum):
    """Performance selection"""
    # The str mixin makes each member usable wherever its string value is.
    speed = 'Speed'
    quality = 'Quality'
    extreme_speed = 'Extreme Speed'
    lightning = 'Lightning'
    hyper_sd = 'Hyper-SD'
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class Lora(BaseModel):
    """Common params lora model"""
    # Whether this LoRA entry is switched on.
    enabled: bool
    # LoRA model name.
    model_name: str
    # Weight, validated to lie within [-2, 2]; defaults to 0.5.
    weight: float = Field(default=0.5, ge=-2, le=2)

    # Replaces pydantic's default protected namespaces.
    # NOTE(review): presumably done so the ``model_name`` field does not
    # collide with the default "model_" namespace -- confirm.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_')
    )
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# Validator/serializer for a JSON array of Lora objects.
LoraList = TypeAdapter(List[Lora])

# Entries of default_loras are [enabled, model_name, weight]. The placeholder
# name 'None' marks an unused slot, so the test must look at the model name
# (index 1); index 0 is the enabled flag and never equals 'None'.
default_loras_model = [
    Lora(enabled=lora[0], model_name=lora[1], weight=lora[2])
    for lora in default_loras
    if lora[1] != 'None'
]
# JSON form (bytes) of the default LoRA list.
default_loras_json = LoraList.dump_json(default_loras_model)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class UpscaleOrVaryMethod(str, Enum):
    """Upscale or Vary method"""
    # Values are the human-readable method labels.
    subtle_variation = 'Vary (Subtle)'
    strong_variation = 'Vary (Strong)'
    upscale_15 = 'Upscale (1.5x)'
    upscale_2 = 'Upscale (2x)'
    upscale_fast = 'Upscale (Fast 2x)'
    upscale_custom = 'Upscale (Custom)'
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
class OutpaintExpansion(str, Enum):
    """Outpaint expansion"""
    # One member per direction; values are capitalised direction names.
    left = 'Left'
    right = 'Right'
    top = 'Top'
    bottom = 'Bottom'
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class ControlNetType(str, Enum):
    """ControlNet Type"""
    # cn_ip is the default type used by ImagePrompt.cn_type.
    cn_ip = "ImagePrompt"
    cn_ip_face = "FaceSwap"
    cn_canny = "PyraCanny"
    cn_cpds = "CPDS"
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class ImagePrompt(BaseModel):
    """Common params object ImagePrompt"""
    # Uploaded image; optional.
    cn_img: UploadFile | None = Field(default=None)
    # Optional value validated to lie within [0, 1].
    cn_stop: float | None = Field(default=None, ge=0, le=1)
    # Optional value validated to lie within [0, 2]; None selects the default.
    cn_weight: float | None = Field(default=None, ge=0, le=2, description="None for default value")
    # ControlNet type; defaults to ControlNetType.cn_ip.
    cn_type: ControlNetType = Field(default=ControlNetType.cn_ip)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
class DescribeImageType(str, Enum):
    """Image type for image to prompt"""
    # Two supported image kinds.
    photo = 'Photo'
    anime = 'Anime'
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class ImageMetaScheme(str, Enum):
    """Scheme for save image meta
    Attributes:
        Fooocus: json format
        A111: string
    """
    # Member values are the lower-case scheme names.
    Fooocus = 'fooocus'
    A111 = 'a111'
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def style_selection_parser(style_selections: str | List[str]) -> List[str]:
|
| 102 |
-
"""
|
| 103 |
-
Parse style selections, Convert to list
|
| 104 |
-
Args:
|
| 105 |
-
style_selections: str, comma separated Fooocus style selections
|
| 106 |
-
e.g. Fooocus V2, Fooocus Enhance, Fooocus Sharp
|
| 107 |
-
Returns:
|
| 108 |
-
List[str]
|
| 109 |
-
"""
|
| 110 |
-
style_selection_arr: List[str] = []
|
| 111 |
-
if style_selections is None or len(style_selections) == 0:
|
| 112 |
-
return []
|
| 113 |
-
for part in style_selections:
|
| 114 |
-
if len(part) > 0:
|
| 115 |
-
for s in part.split(','):
|
| 116 |
-
style = s.strip()
|
| 117 |
-
style_selection_arr.append(style)
|
| 118 |
-
return style_selection_arr
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def lora_parser(loras: str) -> List[Lora]:
    """
    Parse lora config, Convert to list
    Args:
        loras: a json string for loras
    Returns:
        List[Lora]: the validated entries; empty list for None or empty input
    Raises:
        RequestValidationError: when the JSON does not validate as a list
            of Lora objects; carries the pydantic error details
    """
    if not loras:
        return []
    try:
        return LoraList.validate_json(loras)
    except ValidationError as ve:
        # The error list goes into the exception; the cause given to ``from``
        # must itself be an exception, not the list of error dicts.
        raise RequestValidationError(ve.errors()) from ve
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def outpaint_selections_parser(outpaint_selections: str | list[str]) -> List[OutpaintExpansion]:
    """
    Parse outpaint selections, Convert to list
    Args:
        outpaint_selections: a comma separated string of Left, Right, Top,
            Bottom, or a list of such strings
            e.g. Left, Right, Top, Bottom
    Returns:
        List[OutpaintExpansion]; empty list for None or empty input
    Raises:
        RequestValidationError: when an item is not a valid OutpaintExpansion
    """
    if not outpaint_selections:
        return []

    # A bare string is a single comma separated part; iterating over it
    # directly would walk it character by character.
    if isinstance(outpaint_selections, str):
        parts = [outpaint_selections]
    else:
        parts = outpaint_selections

    outpaint_selections_arr: List[OutpaintExpansion] = []
    for part in parts:
        if len(part) == 0:
            continue
        for s in part.split(','):
            try:
                # Strip so that "Left, Right" (as in the docstring) is accepted.
                outpaint_selections_arr.append(OutpaintExpansion(s.strip()))
            except ValueError as exc:
                errs = InitErrorDetails(
                    type='enum',
                    # One-element tuple; tuple('...') would split the name
                    # into single characters.
                    loc=('outpaint_selections',),
                    input=outpaint_selections,
                    ctx={
                        'expected': "str, comma separated Left, Right, Top, Bottom"
                    })
                # The details go into the exception; the cause given to
                # ``from`` must be an exception instance.
                raise RequestValidationError([errs]) from exc
    return outpaint_selections_arr
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def image_prompt_parser(image_prompts_config: List[Tuple]) -> List[ImagePrompt]:
    """
    Build ImagePrompt objects from 4-tuples.
    Args:
        image_prompts_config: List[Tuple], each tuple being
            (cn_img, cn_stop, cn_weight, cn_type)
            e.g. ('image1.jpg', 0.5, 1.0, 'normal'), ('image2.jpg', 0.5, 1.0, 'normal')
    returns:
        List[ImagePrompt], empty when the input is None or empty
    """
    if image_prompts_config is None or len(image_prompts_config) == 0:
        return []
    return [
        ImagePrompt(cn_img=img, cn_stop=stop, cn_weight=weight, cn_type=kind)
        for img, stop, weight, kind in image_prompts_config
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/common/image_meta.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Image meta schema
|
| 3 |
-
"""
|
| 4 |
-
from typing import List
|
| 5 |
-
|
| 6 |
-
from fooocus_version import version
|
| 7 |
-
from pydantic import BaseModel
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class ImageMeta(BaseModel):
    """
    Image meta data model
    """

    # Scheme name; image_parse only keeps "fooocus" or "a111".
    metadata_scheme: str = "fooocus"

    base_model: str
    base_model_hash: str

    # Prompt as given, the full positive prompt list, and the expansion text.
    prompt: str
    full_prompt: List[str]
    prompt_expansion: str

    negative_prompt: str
    full_negative_prompt: List[str]

    performance: str

    style: str

    # The literal string "None" is the default refiner model.
    refiner_model: str = "None"
    refiner_switch: float = 0.5

    # One inner list per LoRA.
    loras: List[list]

    resolution: str

    sampler: str = "dpmpp_2m_sde_gpu"
    scheduler: str = "karras"
    seed: str
    adm_guidance: str
    guidance_scale: float
    sharpness: float
    steps: int
    vae_name: str

    # Defaults to the ``version`` imported from fooocus_version.
    version: str = version

    def __repr__(self):
        # Always an empty string, so repr() of an instance prints nothing.
        return ""
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def loras_parser(loras: list) -> list:
    """
    Convert lora entries into metadata rows.

    Entries whose first element is None or the string 'None' are dropped.
    Each remaining entry becomes
    [first element without its last ".ext" suffix, second element,
    "hash_not_calculated"].
    """
    rows = []
    for entry in loras:
        name = entry[0]
        if name == 'None' or name is None:
            continue
        stem = name.rsplit('.', maxsplit=1)[0]
        rows.append([stem, entry[1], "hash_not_calculated"])
    return rows
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def image_parse(
        async_tak: object,
        task: dict
        ) -> dict | str:
    """
    Build the image meta data dict for one generated image.

    Args:
        async_tak: async task obj; only its ``req_param`` attribute is read
        task: task obj; the keys 'positive', 'expansion', 'negative' and
            'task_seed' are read

    Returns:
        dict: the dumped ImageMeta; with the "fooocus" scheme one extra
        ``lora_combined_<n>`` key is added per lora
    """
    params = async_tak.req_param
    meta = ImageMeta(
        metadata_scheme=params.meta_scheme,
        base_model=params.base_model_name.rsplit('.', maxsplit=1)[0],
        base_model_hash='',
        prompt=params.prompt,
        full_prompt=task['positive'],
        prompt_expansion=task['expansion'],
        negative_prompt=params.negative_prompt,
        full_negative_prompt=task['negative'],
        performance=params.performance_selection,
        style=str(params.style_selections),
        refiner_model=params.refiner_model_name,
        refiner_switch=params.refiner_switch,
        loras=loras_parser(params.loras),
        resolution=str(tuple(int(side) for side in params.aspect_ratios_selection.split('*'))),
        sampler=params.advanced_params.sampler_name,
        scheduler=params.advanced_params.scheduler_name,
        seed=str(task['task_seed']),
        adm_guidance=str((
            params.advanced_params.adm_scaler_positive,
            params.advanced_params.adm_scaler_negative,
            params.advanced_params.adm_scaler_end)),
        guidance_scale=params.guidance_scale,
        sharpness=params.sharpness,
        steps=-1,
        vae_name=params.advanced_params.vae_name,
        version=version
    )

    # Any unknown scheme falls back to "fooocus".
    if meta.metadata_scheme not in ("fooocus", "a111"):
        meta.metadata_scheme = "fooocus"

    meta_dict = meta.model_dump()
    if meta.metadata_scheme == "fooocus":
        # The fooocus scheme additionally carries one "name : weight" entry
        # per lora, numbered from 1.
        for position, lora in enumerate(meta.loras, start=1):
            as_text = [str(item) for item in lora]
            meta_dict[f"lora_combined_{position}"] = f"{as_text[0]} : {as_text[1]}"
    return meta_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/common/requests.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
"""Common model for requests"""
|
| 2 |
-
from typing import List
|
| 3 |
-
from pydantic import (
|
| 4 |
-
BaseModel,
|
| 5 |
-
Field,
|
| 6 |
-
ValidationError
|
| 7 |
-
)
|
| 8 |
-
|
| 9 |
-
from modules.config import (
|
| 10 |
-
default_sampler,
|
| 11 |
-
default_scheduler,
|
| 12 |
-
default_prompt,
|
| 13 |
-
default_prompt_negative,
|
| 14 |
-
default_aspect_ratio,
|
| 15 |
-
default_base_model_name,
|
| 16 |
-
default_refiner_model_name,
|
| 17 |
-
default_refiner_switch,
|
| 18 |
-
default_cfg_scale,
|
| 19 |
-
default_styles,
|
| 20 |
-
default_overwrite_step,
|
| 21 |
-
default_inpaint_engine_version,
|
| 22 |
-
default_overwrite_switch,
|
| 23 |
-
default_cfg_tsnr,
|
| 24 |
-
default_sample_sharpness,
|
| 25 |
-
default_vae,
|
| 26 |
-
default_clip_skip
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
from modules.flags import clip_skip_max
|
| 30 |
-
|
| 31 |
-
from fooocusapi.models.common.base import (
|
| 32 |
-
PerformanceSelection,
|
| 33 |
-
Lora,
|
| 34 |
-
default_loras_model
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
# Keep only the text before the first space of the imported default and
# replace "×" with "*".
default_aspect_ratio = default_aspect_ratio.split(" ")[0].replace("×", "*")
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class QueryJobRequest(BaseModel):
    """Query job request"""
    # Required: identifier of the job to look up.
    job_id: str = Field(description="Job ID to query")
    # Optional flag, off by default.
    require_step_preview: bool = Field(
        default=False,
        description="Set to true will return preview image of generation steps at current time")
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class AdvancedParams(BaseModel):
    """Common params object AdvancedParams

    Every field has a default, so ``AdvancedParams()`` is a valid instance.
    Numeric bounds are enforced by pydantic through ``ge`` / ``le``.
    """
    # Preview and seed handling.
    disable_preview: bool = Field(False, description="Disable preview during generation")
    disable_intermediate_results: bool = Field(False, description="Disable intermediate results")
    disable_seed_increment: bool = Field(False, description="Disable Seed Increment")
    # Guidance and sampling.
    adm_scaler_positive: float = Field(1.5, description="Positive ADM Guidance Scaler", ge=0.1, le=3.0)
    adm_scaler_negative: float = Field(0.8, description="Negative ADM Guidance Scaler", ge=0.1, le=3.0)
    adm_scaler_end: float = Field(0.3, description="ADM Guidance End At Step", ge=0.0, le=1.0)
    adaptive_cfg: float = Field(default_cfg_tsnr, description="CFG Mimicking from TSNR", ge=1.0, le=30.0)
    clip_skip: int = Field(default_clip_skip, description="Clip Skip", ge=1, le=clip_skip_max)
    sampler_name: str = Field(default_sampler, description="Sampler")
    scheduler_name: str = Field(default_scheduler, description="Scheduler")
    # Forced overwrites; every lower bound is -1.
    overwrite_step: int = Field(default_overwrite_step, description="Forced Overwrite of Sampling Step", ge=-1, le=200)
    overwrite_switch: float = Field(default_overwrite_switch, description="Forced Overwrite of Refiner Switch Step", ge=-1, le=1)
    overwrite_width: int = Field(-1, description="Forced Overwrite of Generating Width", ge=-1, le=2048)
    overwrite_height: int = Field(-1, description="Forced Overwrite of Generating Height", ge=-1, le=2048)
    overwrite_vary_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Vary"', ge=-1, le=1.0)
    overwrite_upscale_strength: float = Field(-1, description='Forced Overwrite of Denoising Strength of "Upscale"', ge=-1, le=1.0)
    # Image prompt mixing and ControlNet preprocessing.
    mixing_image_prompt_and_vary_upscale: bool = Field(False, description="Mixing Image Prompt and Vary/Upscale")
    mixing_image_prompt_and_inpaint: bool = Field(False, description="Mixing Image Prompt and Inpaint")
    debugging_cn_preprocessor: bool = Field(False, description="Debug Preprocessors")
    skipping_cn_preprocessor: bool = Field(False, description="Skip Preprocessors")
    canny_low_threshold: int = Field(64, description="Canny Low Threshold", ge=1, le=255)
    canny_high_threshold: int = Field(128, description="Canny High Threshold", ge=1, le=255)
    refiner_swap_method: str = Field('joint', description="Refiner swap method")
    controlnet_softness: float = Field(0.25, description="Softness of ControlNet", ge=0.0, le=1.0)
    # FreeU; the descriptions of s1/s2 now name the field they describe
    # (they previously read "FreeU B3" / "FreeU B4").
    freeu_enabled: bool = Field(False, description="FreeU enabled")
    freeu_b1: float = Field(1.01, description="FreeU B1")
    freeu_b2: float = Field(1.02, description="FreeU B2")
    freeu_s1: float = Field(0.99, description="FreeU S1")
    freeu_s2: float = Field(0.95, description="FreeU S2")
    # Inpainting.
    debugging_inpaint_preprocessor: bool = Field(False, description="Debug Inpaint Preprocessing")
    inpaint_disable_initial_latent: bool = Field(False, description="Disable initial latent in inpaint")
    inpaint_engine: str = Field(default_inpaint_engine_version, description="Inpaint Engine")
    inpaint_strength: float = Field(1.0, description="Inpaint Denoising Strength", ge=0.0, le=1.0)
    inpaint_respective_field: float = Field(1.0, description="Inpaint Respective Field", ge=0.0, le=1.0)
    inpaint_mask_upload_checkbox: bool = Field(False, description="Upload Mask")
    invert_mask_checkbox: bool = Field(False, description="Invert Mask")
    inpaint_erode_or_dilate: int = Field(0, description="Mask Erode or Dilate", ge=-64, le=64)
    # Miscellaneous.
    black_out_nsfw: bool = Field(False, description="Block out NSFW")
    vae_name: str = Field(default_vae, description="VAE name")
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class CommonRequest(BaseModel):
    """All generate request based on this model"""
    # Prompt, style and performance selection.
    prompt: str = default_prompt
    negative_prompt: str = default_prompt_negative
    style_selections: List[str] = default_styles
    performance_selection: PerformanceSelection = PerformanceSelection.speed
    # Uses the normalised module-level default_aspect_ratio.
    aspect_ratios_selection: str = default_aspect_ratio
    image_number: int = Field(default=1, description="Image number", ge=1, le=32)
    image_seed: int = Field(default=-1, description="Seed to generate image, -1 for random")
    sharpness: float = Field(default=default_sample_sharpness, ge=0.0, le=30.0)
    guidance_scale: float = Field(default=default_cfg_scale, ge=1.0, le=30.0)
    # Model selection.
    base_model_name: str = default_base_model_name
    refiner_model_name: str = default_refiner_model_name
    refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
    loras: List[Lora] = Field(default=default_loras_model)
    advanced_params: AdvancedParams = AdvancedParams()
    # Output options.
    save_meta: bool = Field(default=True, description="Save meta data")
    meta_scheme: str = Field(default='fooocus', description="Meta data scheme, one of [fooocus, a111]")
    save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
    save_name: str = Field(default='', description="Image name for output image, default is job id + seq")
    read_wildcards_in_order: bool = Field(default=False, description="Read wildcards in order")
    require_base64: bool = Field(default=False, description="Return base64 data of generated image")
    # Async execution and completion callback.
    async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generation result later")
    webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
                                                            " This allows for asynchronous notification of task status.")
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def advanced_params_parser(advanced_params: str | None) -> AdvancedParams:
    """
    Parse advanced params, Convert to AdvancedParams

    Args:
        advanced_params: str, json format. None or an empty string means
            "nothing supplied" and yields the default object.
    Returns:
        AdvancedParams object, if validate error return default value
    """
    if not advanced_params:
        return AdvancedParams()
    try:
        # Validating the JSON already produces a fully built AdvancedParams
        # instance. The previous code unpacked that instance with ** into the
        # constructor, which raises TypeError (a model is not a mapping) and
        # is not covered by the ValidationError handler below.
        return AdvancedParams.model_validate_json(advanced_params)
    except ValidationError:
        # Malformed or invalid input falls back to the defaults.
        return AdvancedParams()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/common/response.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
"""Fooocus API models for response"""
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
from pydantic import (
|
| 5 |
-
BaseModel,
|
| 6 |
-
ConfigDict,
|
| 7 |
-
Field
|
| 8 |
-
)
|
| 9 |
-
|
| 10 |
-
from fooocusapi.models.common.task import (
|
| 11 |
-
GeneratedImageResult,
|
| 12 |
-
AsyncJobStage
|
| 13 |
-
)
|
| 14 |
-
from fooocusapi.task_queue import TaskType
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class DescribeImageResponse(BaseModel):
    """Describe image response: carries a single required text field."""
    # Text describing the image.
    describe: str
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class AsyncJobResponse(BaseModel):
    """
    Status report for an async job.

    The first four fields are required; the remaining three default to None.

    Attributes:
        job_id: Job ID
        job_type: Job type
        job_stage: Job running stage
        job_progress: Job running progress, 100 means finished
        job_status: Job running status in text
        job_step_preview: Preview image of the current generation step, base64
        job_result: Job generation result
    """
    job_id: str = Field(description="Job ID")
    job_type: TaskType = Field(description="Job type")
    job_stage: AsyncJobStage = Field(description="Job running stage")
    job_progress: int = Field(description="Job running progress, 100 is for finished.")
    job_status: str | None = Field(
        default=None,
        description="Job running status in text",
    )
    job_step_preview: str | None = Field(
        default=None,
        description="Preview image of generation steps at current time, as base64 image",
    )
    job_result: List[GeneratedImageResult] | None = Field(
        default=None,
        description="Job generation result",
    )
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class JobQueueInfo(BaseModel):
    """
    Summary of the job queue.

    Attributes:
        running_size: int, count of jobs currently running or waiting
        finished_size: int, count of finished jobs (after auto clean)
        last_job_id: str or None, id of the most recently submitted
            generation job; the field has no default, so it must be supplied
    """
    running_size: int = Field(
        description="The current running and waiting job count",
    )
    finished_size: int = Field(
        description="Finished job count (after auto clean)",
    )
    last_job_id: str | None = Field(
        description="Last submit generation job id",
    )
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# TODO: May need more detailed fields; add them when someone needs them
|
| 59 |
-
class JobHistoryInfo(BaseModel):
    """Job history info: a job id plus a finished flag."""
    # Identifier of the job this entry refers to.
    job_id: str
    # Defaults to False when not supplied.
    is_finished: bool = False
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
# Response model for the historical tasks
|
| 68 |
-
class JobHistoryResponse(BaseModel):
    """Job history response: two lists of JobHistoryInfo, both empty by default."""
    # Entries reported under "queue".
    queue: List[JobHistoryInfo] = []
    # Entries reported under "history".
    history: List[JobHistoryInfo] = []
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
class AllModelNamesResponse(BaseModel):
    """All model list response: available model and lora filenames."""
    model_filenames: List[str] = Field(
        description="All available model filenames",
    )
    lora_filenames: List[str] = Field(
        description="All available lora filenames",
    )

    # NOTE(review): presumably overrides pydantic's default protected
    # "model_" namespace so that the `model_filenames` field name does not
    # trigger a namespace warning -- confirm against the pydantic version used.
    model_config = ConfigDict(
        protected_namespaces=('protect_me_', 'also_protect_'),
    )
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
class StopResponse(BaseModel):
    """Stop task response: carries a single message string."""
    # Message text returned to the caller.
    msg: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/common/task.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Task and job related models
|
| 3 |
-
"""
|
| 4 |
-
from enum import Enum
|
| 5 |
-
from pydantic import (
|
| 6 |
-
BaseModel,
|
| 7 |
-
Field
|
| 8 |
-
)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class TaskType(str, Enum):
    """
    Kinds of task, as a string-valued enum.

    Each member's value is its human-readable label; `not_found` is the
    member for a task that could not be identified.
    """
    text_2_img = 'Text to Image'
    img_uov = 'Image Upscale or Variation'
    img_inpaint_outpaint = 'Image Inpaint or Outpaint'
    img_prompt = 'Image Prompt'
    not_found = 'Not Found'
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class GenerationFinishReason(str, Enum):
    """
    Why a generation finished, as a string-valued enum.

    Members are compared and serialized by their upper-case string value.
    """
    success = 'SUCCESS'
    queue_is_full = 'QUEUE_IS_FULL'
    user_cancel = 'USER_CANCEL'
    error = 'ERROR'
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class ImageGenerationResult:
    """
    Plain value holder for the outcome of generating one image.

    Attributes:
        im: the image as a string, or None
        seed: the seed, as a string
        finish_reason: the GenerationFinishReason for this image
    """
    def __init__(self, im: str | None, seed: str, finish_reason: GenerationFinishReason):
        # Store the three values unchanged; no validation is performed.
        self.im, self.seed, self.finish_reason = im, seed, finish_reason
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class AsyncJobStage(str, Enum):
    """
    Stage of an async job, as a string-valued enum.

    Members are compared and serialized by their upper-case string value.
    """
    waiting = 'WAITING'
    running = 'RUNNING'
    success = 'SUCCESS'
    error = 'ERROR'
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class GeneratedImageResult(BaseModel):
    """
    Result entry for one generated image.

    `base64` and `url` accept None but have no default, so both must be
    supplied explicitly.
    """
    base64: str | None = Field(
        description="Image encoded in base64, or null if finishReason is not 'SUCCESS', only return when request require base64",
    )
    url: str | None = Field(
        description="Image file static serve url, or null if finishReason is not 'SUCCESS'",
    )
    seed: str = Field(
        description="The seed associated with this image",
    )
    finish_reason: GenerationFinishReason
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/requests_v1.py
DELETED
|
@@ -1,274 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
requests models for v1 endpoints
|
| 3 |
-
"""
|
| 4 |
-
from typing import List
|
| 5 |
-
from fastapi.params import File
|
| 6 |
-
from fastapi import (
|
| 7 |
-
UploadFile,
|
| 8 |
-
Form
|
| 9 |
-
)
|
| 10 |
-
from fooocusapi.models.common.requests import (
|
| 11 |
-
CommonRequest,
|
| 12 |
-
advanced_params_parser
|
| 13 |
-
)
|
| 14 |
-
from fooocusapi.models.common.base import (
|
| 15 |
-
ImagePrompt,
|
| 16 |
-
ControlNetType,
|
| 17 |
-
OutpaintExpansion,
|
| 18 |
-
UpscaleOrVaryMethod,
|
| 19 |
-
PerformanceSelection
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
from fooocusapi.models.common.base import (
|
| 23 |
-
style_selection_parser,
|
| 24 |
-
lora_parser,
|
| 25 |
-
outpaint_selections_parser,
|
| 26 |
-
image_prompt_parser,
|
| 27 |
-
default_loras_json
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
from fooocusapi.configs.default import (
|
| 31 |
-
default_prompt_negative,
|
| 32 |
-
default_aspect_ratio,
|
| 33 |
-
default_base_model_name,
|
| 34 |
-
default_refiner_model_name,
|
| 35 |
-
default_refiner_switch,
|
| 36 |
-
default_cfg_scale,
|
| 37 |
-
default_styles,
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class ImgUpscaleOrVaryRequest(CommonRequest):
    """
    Form-based request for image upscale or variation.

    Attributes:
        input_image: uploaded input image
        uov_method: upscale or variation method
        upscale_value: custom upscale value, may be None
    Functions:
        as_form: build the request from individual form fields
    """
    input_image: UploadFile
    uov_method: UpscaleOrVaryMethod
    upscale_value: float | None

    @classmethod
    def as_form(
        cls,
        input_image: UploadFile = Form(description="Init image for upscale or outpaint"),
        uov_method: UpscaleOrVaryMethod = Form(),
        upscale_value: float | None = Form(None, description="Upscale custom value, None for default value", ge=1.0, le=5.0),
        prompt: str = Form(''),
        negative_prompt: str = Form(default_prompt_negative),
        style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
        performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
        aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
        image_number: int = Form(default=1, description="Image number", ge=1, le=32),
        image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
        sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
        guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
        base_model_name: str = Form(default_base_model_name, description="checkpoint file name"),
        refiner_model_name: str = Form(default_refiner_model_name, description="refiner file name"),
        refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
        loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
        advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
        save_meta: bool = Form(default=False, description="Save metadata to image"),
        meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
        save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
        save_name: str = Form(default="", description="Save name, empty for auto generate"),
        require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
        read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
        async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
        webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build the request from form fields.

        `style_selections`, `loras` and `advanced_params` are passed through
        their parser helpers; every other field is forwarded unchanged.
        """
        return cls(
            input_image=input_image,
            uov_method=uov_method,
            upscale_value=upscale_value,
            prompt=prompt,
            negative_prompt=negative_prompt,
            style_selections=style_selection_parser(style_selections),
            performance_selection=performance_selection,
            aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number,
            image_seed=image_seed,
            sharpness=sharpness,
            guidance_scale=guidance_scale,
            base_model_name=base_model_name,
            refiner_model_name=refiner_model_name,
            refiner_switch=refiner_switch,
            loras=lora_parser(loras),
            advanced_params=advanced_params_parser(advanced_params),
            save_meta=save_meta,
            meta_scheme=meta_scheme,
            save_extension=save_extension,
            save_name=save_name,
            require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order,
            async_process=async_process,
            webhook_url=webhook_url,
        )
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
class ImgInpaintOrOutpaintRequest(CommonRequest):
    """
    Form-based request for image inpaint or outpaint.

    Attributes:
        input_image: uploaded init image, may be None
        input_mask: uploaded mask, may be None
        inpaint_additional_prompt: extra prompt text for inpainting
        outpaint_selections: parsed outpaint expansion directions
        outpaint_distance_left/right/top/bottom: outpaint distances
    Functions:
        as_form: build the request from individual form fields
    """
    input_image: UploadFile | None
    input_mask: UploadFile | None
    inpaint_additional_prompt: str | None
    outpaint_selections: List[OutpaintExpansion]
    outpaint_distance_left: int
    outpaint_distance_right: int
    outpaint_distance_top: int
    outpaint_distance_bottom: int

    @classmethod
    def as_form(
        cls,
        input_image: UploadFile = Form(description="Init image for inpaint or outpaint"),
        input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
        inpaint_additional_prompt: str | None = Form("", description="Describe what you want to inpaint"),
        outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
        outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, -1 for default"),
        outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, -1 for default"),
        outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, -1 for default"),
        outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, -1 for default"),
        prompt: str = Form(''),
        negative_prompt: str = Form(default_prompt_negative),
        style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
        performance_selection: PerformanceSelection = Form(PerformanceSelection.speed, description="Performance Selection, one of 'Speed','Quality','Extreme Speed'"),
        aspect_ratios_selection: str = Form(default_aspect_ratio, description="Aspect Ratios Selection, default 1152*896"),
        image_number: int = Form(default=1, description="Image number", ge=1, le=32),
        image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
        sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
        guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
        base_model_name: str = Form(default_base_model_name),
        refiner_model_name: str = Form(default_refiner_model_name),
        refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
        loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
        advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
        save_meta: bool = Form(default=False, description="Save metadata to image"),
        meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
        save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
        save_name: str = Form(default="", description="Save name, empty for auto generate"),
        require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
        read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
        async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
        webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build the request from form fields.

        A mask that is still a `File` marker object is replaced by None;
        `outpaint_selections`, `style_selections`, `loras` and
        `advanced_params` go through their parser helpers; everything else
        is forwarded unchanged.
        """
        # The declared default for input_mask is a File(None) object, so a
        # value that is a File instance is treated as "no mask supplied".
        mask = None if isinstance(input_mask, File) else input_mask

        return cls(
            input_image=input_image,
            input_mask=mask,
            inpaint_additional_prompt=inpaint_additional_prompt,
            outpaint_selections=outpaint_selections_parser(outpaint_selections),
            outpaint_distance_left=outpaint_distance_left,
            outpaint_distance_right=outpaint_distance_right,
            outpaint_distance_top=outpaint_distance_top,
            outpaint_distance_bottom=outpaint_distance_bottom,
            prompt=prompt,
            negative_prompt=negative_prompt,
            style_selections=style_selection_parser(style_selections),
            performance_selection=performance_selection,
            aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number,
            image_seed=image_seed,
            sharpness=sharpness,
            guidance_scale=guidance_scale,
            base_model_name=base_model_name,
            refiner_model_name=refiner_model_name,
            refiner_switch=refiner_switch,
            loras=lora_parser(loras),
            advanced_params=advanced_params_parser(advanced_params),
            save_meta=save_meta,
            meta_scheme=meta_scheme,
            save_extension=save_extension,
            save_name=save_name,
            require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order,
            async_process=async_process,
            webhook_url=webhook_url,
        )
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
    """
    Form-based image prompt request.

    Adds `image_prompts` to the inpaint/outpaint fields. The form accepts up
    to four image prompt slots (cn_img/cn_stop/cn_weight/cn_type 1-4).
    """
    image_prompts: List[ImagePrompt]

    @classmethod
    def as_form(
        cls,
        input_image: UploadFile = Form(File(None), description="Init image for inpaint or outpaint"),
        input_mask: UploadFile = Form(File(None), description="Inpaint or outpaint mask"),
        inpaint_additional_prompt: str | None = Form(None, description="Describe what you want to inpaint"),
        outpaint_selections: List[str] = Form([], description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
        outpaint_distance_left: int = Form(default=0, description="Set outpaint left distance, 0 for default"),
        outpaint_distance_right: int = Form(default=0, description="Set outpaint right distance, 0 for default"),
        outpaint_distance_top: int = Form(default=0, description="Set outpaint top distance, 0 for default"),
        outpaint_distance_bottom: int = Form(default=0, description="Set outpaint bottom distance, 0 for default"),
        cn_img1: UploadFile = Form(File(None), description="Input image for image prompt"),
        cn_stop1: float | None = Form(default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
        cn_weight1: float | None = Form(default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
        cn_type1: ControlNetType = Form(default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
        cn_img2: UploadFile = Form(File(None), description="Input image for image prompt"),
        cn_stop2: float | None = Form(default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
        cn_weight2: float | None = Form(default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
        cn_type2: ControlNetType = Form(default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
        cn_img3: UploadFile = Form(File(None), description="Input image for image prompt"),
        cn_stop3: float | None = Form(default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
        cn_weight3: float | None = Form(default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
        cn_type3: ControlNetType = Form(default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
        cn_img4: UploadFile = Form(File(None), description="Input image for image prompt"),
        cn_stop4: float | None = Form(default=None, ge=0, le=1, description="Stop at for image prompt, None for default value"),
        cn_weight4: float | None = Form(default=None, ge=0, le=2, description="Weight for image prompt, None for default value"),
        cn_type4: ControlNetType = Form(default=ControlNetType.cn_ip, description="ControlNet type for image prompt"),
        prompt: str = Form(''),
        negative_prompt: str = Form(default_prompt_negative),
        style_selections: List[str] = Form(default_styles, description="Fooocus style selections, separated by comma"),
        performance_selection: PerformanceSelection = Form(PerformanceSelection.speed),
        aspect_ratios_selection: str = Form(default_aspect_ratio),
        image_number: int = Form(default=1, description="Image number", ge=1, le=32),
        image_seed: int = Form(default=-1, description="Seed to generate image, -1 for random"),
        sharpness: float = Form(default=2.0, ge=0.0, le=30.0),
        guidance_scale: float = Form(default=default_cfg_scale, ge=1.0, le=30.0),
        base_model_name: str = Form(default_base_model_name),
        refiner_model_name: str = Form(default_refiner_model_name),
        refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
        loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
        advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
        save_meta: bool = Form(default=False, description="Save metadata to image"),
        meta_scheme: str = Form(default='fooocus', description="Metadata scheme, one of 'fooocus', 'a111'"),
        save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
        save_name: str = Form(default="", description="Save name, empty for auto generate"),
        require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
        read_wildcards_in_order: bool = Form(default=False, description="Read wildcards in order"),
        async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generation result later"),
        webhook_url: str = Form(default="", description="Webhook url for generation result"),
    ):
        """
        Build the request from form fields.

        Upload parameters that are still `File` marker objects become None,
        the four cn_* slots are handed to `image_prompt_parser` as a list of
        (image, stop, weight, type) tuples, and the list/JSON fields go
        through their parser helpers.
        """
        # Every upload here declares File(None) as its default, so a value
        # that is a File instance is treated as "nothing uploaded".
        input_image, input_mask, cn_img1, cn_img2, cn_img3, cn_img4 = (
            None if isinstance(upload, File) else upload
            for upload in (input_image, input_mask, cn_img1, cn_img2, cn_img3, cn_img4)
        )

        expansions = outpaint_selections_parser(outpaint_selections)
        parsed_image_prompts = image_prompt_parser([
            (cn_img1, cn_stop1, cn_weight1, cn_type1),
            (cn_img2, cn_stop2, cn_weight2, cn_type2),
            (cn_img3, cn_stop3, cn_weight3, cn_type3),
            (cn_img4, cn_stop4, cn_weight4, cn_type4),
        ])

        return cls(
            input_image=input_image,
            input_mask=input_mask,
            inpaint_additional_prompt=inpaint_additional_prompt,
            outpaint_selections=expansions,
            outpaint_distance_left=outpaint_distance_left,
            outpaint_distance_right=outpaint_distance_right,
            outpaint_distance_top=outpaint_distance_top,
            outpaint_distance_bottom=outpaint_distance_bottom,
            image_prompts=parsed_image_prompts,
            prompt=prompt,
            negative_prompt=negative_prompt,
            style_selections=style_selection_parser(style_selections),
            performance_selection=performance_selection,
            aspect_ratios_selection=aspect_ratios_selection,
            image_number=image_number,
            image_seed=image_seed,
            sharpness=sharpness,
            guidance_scale=guidance_scale,
            base_model_name=base_model_name,
            refiner_model_name=refiner_model_name,
            refiner_switch=refiner_switch,
            loras=lora_parser(loras),
            advanced_params=advanced_params_parser(advanced_params),
            save_meta=save_meta,
            meta_scheme=meta_scheme,
            save_extension=save_extension,
            save_name=save_name,
            require_base64=require_base64,
            read_wildcards_in_order=read_wildcards_in_order,
            async_process=async_process,
            webhook_url=webhook_url,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/models/requests_v2.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
"""V2 API models"""
|
| 2 |
-
from typing import List
|
| 3 |
-
from pydantic import BaseModel, Field
|
| 4 |
-
from fooocusapi.models.common.requests import CommonRequest
|
| 5 |
-
from fooocusapi.models.common.base import (
|
| 6 |
-
ControlNetType,
|
| 7 |
-
OutpaintExpansion,
|
| 8 |
-
ImagePrompt,
|
| 9 |
-
UpscaleOrVaryMethod
|
| 10 |
-
)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class ImagePromptJson(BaseModel):
    """One image prompt entry for the V2 API, with the image given as base64."""
    cn_img: str | None = Field(
        default=None,
        description="Input image for image prompt as base64",
    )
    cn_stop: float | None = Field(
        default=0, ge=0, le=1,
        description="Stop at for image prompt, 0 for default value",
    )
    cn_weight: float | None = Field(
        default=0, ge=0, le=2,
        description="Weight for image prompt, 0 for default value",
    )
    cn_type: ControlNetType = Field(
        default=ControlNetType.cn_ip,
        description="ControlNet type for image prompt",
    )
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class ImgInpaintOrOutpaintRequestJson(CommonRequest):
    """
    JSON request for image inpaint or outpaint.

    Images are passed as base64 strings; the four outpaint distances
    default to -1.
    """
    input_image: str = Field(
        default='',
        description="Init image for inpaint or outpaint as base64",
    )
    input_mask: str | None = Field(
        default='',
        description="Inpaint or outpaint mask as base64",
    )
    inpaint_additional_prompt: str | None = Field(
        default='',
        description="Describe what you want to inpaint",
    )
    outpaint_selections: List[OutpaintExpansion] = []
    outpaint_distance_left: int | None = Field(
        default=-1,
        description="Set outpaint left distance",
    )
    outpaint_distance_right: int | None = Field(
        default=-1,
        description="Set outpaint right distance",
    )
    outpaint_distance_top: int | None = Field(
        default=-1,
        description="Set outpaint top distance",
    )
    outpaint_distance_bottom: int | None = Field(
        default=-1,
        description="Set outpaint bottom distance",
    )
    image_prompts: List[ImagePromptJson | ImagePrompt] = []
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class ImgPromptRequestJson(ImgInpaintOrOutpaintRequestJson):
    """
    JSON request for image prompt.

    Narrows the parent model: `input_image` becomes optional (default None)
    and `image_prompts` loses its default, so it must be supplied.
    """
    input_image: str | None = Field(
        default=None,
        description="Init image for inpaint or outpaint as base64",
    )
    image_prompts: List[ImagePromptJson | ImagePrompt]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class Text2ImgRequestWithPrompt(CommonRequest):
    """Text to image request that may also carry image prompts."""
    # Empty list by default, i.e. no image prompts.
    image_prompts: List[ImagePromptJson] = []
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class ImgUpscaleOrVaryRequestJson(CommonRequest):
    """
    JSON request for image upscale or variation.

    `input_image` is required and has no default; `uov_method` defaults to
    `UpscaleOrVaryMethod.upscale_2`.
    """
    uov_method: UpscaleOrVaryMethod = UpscaleOrVaryMethod.upscale_2
    upscale_value: float | None = Field(
        default=1.0, ge=1.0, le=5.0,
        description="Upscale custom value, 1.0 for default value",
    )
    input_image: str = Field(
        description="Init image for upscale or outpaint as base64",
    )
    image_prompts: List[ImagePromptJson | ImagePrompt] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/parameters.py
DELETED
|
@@ -1,94 +0,0 @@
|
|
| 1 |
-
from typing import Dict, List, Tuple
|
| 2 |
-
import numpy as np
|
| 3 |
-
import copy
|
| 4 |
-
|
| 5 |
-
from fooocusapi.models.common.requests import AdvancedParams
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class ImageGenerationParams:
|
| 9 |
-
def __init__(
|
| 10 |
-
self,
|
| 11 |
-
prompt: str,
|
| 12 |
-
negative_prompt: str,
|
| 13 |
-
style_selections: List[str],
|
| 14 |
-
performance_selection: str,
|
| 15 |
-
aspect_ratios_selection: str,
|
| 16 |
-
image_number: int,
|
| 17 |
-
image_seed: int | None,
|
| 18 |
-
sharpness: float,
|
| 19 |
-
guidance_scale: float,
|
| 20 |
-
base_model_name: str,
|
| 21 |
-
refiner_model_name: str,
|
| 22 |
-
refiner_switch: float,
|
| 23 |
-
loras: List[Tuple[str, float]],
|
| 24 |
-
uov_input_image: np.ndarray | None,
|
| 25 |
-
uov_method: str,
|
| 26 |
-
upscale_value: float | None,
|
| 27 |
-
outpaint_selections: List[str],
|
| 28 |
-
outpaint_distance_left: int,
|
| 29 |
-
outpaint_distance_right: int,
|
| 30 |
-
outpaint_distance_top: int,
|
| 31 |
-
outpaint_distance_bottom: int,
|
| 32 |
-
inpaint_input_image: Dict[str, np.ndarray] | None,
|
| 33 |
-
inpaint_additional_prompt: str | None,
|
| 34 |
-
image_prompts: List[Tuple[np.ndarray, float, float, str]],
|
| 35 |
-
advanced_params: List[any] | None,
|
| 36 |
-
save_extension: str,
|
| 37 |
-
save_meta: bool,
|
| 38 |
-
meta_scheme: str,
|
| 39 |
-
save_name: str,
|
| 40 |
-
require_base64: bool,
|
| 41 |
-
):
|
| 42 |
-
self.prompt = prompt
|
| 43 |
-
self.negative_prompt = negative_prompt
|
| 44 |
-
self.style_selections = style_selections
|
| 45 |
-
self.performance_selection = performance_selection
|
| 46 |
-
self.aspect_ratios_selection = aspect_ratios_selection
|
| 47 |
-
self.image_number = image_number
|
| 48 |
-
self.image_seed = image_seed
|
| 49 |
-
self.sharpness = sharpness
|
| 50 |
-
self.guidance_scale = guidance_scale
|
| 51 |
-
self.base_model_name = base_model_name
|
| 52 |
-
self.refiner_model_name = refiner_model_name
|
| 53 |
-
self.refiner_switch = refiner_switch
|
| 54 |
-
self.loras = loras
|
| 55 |
-
self.uov_input_image = uov_input_image
|
| 56 |
-
self.uov_method = uov_method
|
| 57 |
-
self.upscale_value = upscale_value
|
| 58 |
-
self.outpaint_selections = outpaint_selections
|
| 59 |
-
self.outpaint_distance_left = outpaint_distance_left
|
| 60 |
-
self.outpaint_distance_right = outpaint_distance_right
|
| 61 |
-
self.outpaint_distance_top = outpaint_distance_top
|
| 62 |
-
self.outpaint_distance_bottom = outpaint_distance_bottom
|
| 63 |
-
self.inpaint_input_image = inpaint_input_image
|
| 64 |
-
self.inpaint_additional_prompt = inpaint_additional_prompt
|
| 65 |
-
self.image_prompts = image_prompts
|
| 66 |
-
self.save_extension = save_extension
|
| 67 |
-
self.save_meta = save_meta
|
| 68 |
-
self.meta_scheme = meta_scheme
|
| 69 |
-
self.save_name = save_name
|
| 70 |
-
self.require_base64 = require_base64
|
| 71 |
-
self.advanced_params = advanced_params
|
| 72 |
-
|
| 73 |
-
if self.advanced_params is None:
|
| 74 |
-
self.advanced_params = AdvancedParams()
|
| 75 |
-
|
| 76 |
-
# Auto set mixing_image_prompt_and_inpaint to True
|
| 77 |
-
if len(self.image_prompts) > 0 and self.inpaint_input_image is not None:
|
| 78 |
-
print("Mixing Image Prompts and Inpaint Enabled")
|
| 79 |
-
self.advanced_params.mixing_image_prompt_and_inpaint = True
|
| 80 |
-
if len(self.image_prompts) > 0 and self.uov_input_image is not None:
|
| 81 |
-
print("Mixing Image Prompts and Vary Upscale Enabled")
|
| 82 |
-
self.advanced_params.mixing_image_prompt_and_vary_upscale = True
|
| 83 |
-
|
| 84 |
-
def to_dict(self):
|
| 85 |
-
"""
|
| 86 |
-
Convert the ImageGenerationParams object to a dictionary.
|
| 87 |
-
Args:
|
| 88 |
-
self:
|
| 89 |
-
|
| 90 |
-
Returns:
|
| 91 |
-
self to dict
|
| 92 |
-
"""
|
| 93 |
-
obj_dict = copy.deepcopy(self)
|
| 94 |
-
return obj_dict.__dict__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/routes/__init__.py
DELETED
|
File without changes
|
fooocusapi/routes/generate_v1.py
DELETED
|
@@ -1,186 +0,0 @@
|
|
| 1 |
-
"""Generate API V1 routes
|
| 2 |
-
|
| 3 |
-
"""
|
| 4 |
-
from typing import List, Optional
|
| 5 |
-
from fastapi import APIRouter, Depends, Header, Query, UploadFile
|
| 6 |
-
from fastapi.params import File
|
| 7 |
-
|
| 8 |
-
from modules.util import HWC3
|
| 9 |
-
|
| 10 |
-
from fooocusapi.models.common.base import DescribeImageType
|
| 11 |
-
from fooocusapi.utils.api_utils import api_key_auth
|
| 12 |
-
|
| 13 |
-
from fooocusapi.models.common.requests import CommonRequest as Text2ImgRequest
|
| 14 |
-
from fooocusapi.models.requests_v1 import (
|
| 15 |
-
ImgUpscaleOrVaryRequest,
|
| 16 |
-
ImgPromptRequest,
|
| 17 |
-
ImgInpaintOrOutpaintRequest
|
| 18 |
-
)
|
| 19 |
-
from fooocusapi.models.common.response import (
|
| 20 |
-
AsyncJobResponse,
|
| 21 |
-
GeneratedImageResult,
|
| 22 |
-
DescribeImageResponse,
|
| 23 |
-
StopResponse
|
| 24 |
-
)
|
| 25 |
-
from fooocusapi.utils.call_worker import call_worker
|
| 26 |
-
from fooocusapi.utils.img_utils import read_input_image
|
| 27 |
-
from fooocusapi.configs.default import img_generate_responses
|
| 28 |
-
from fooocusapi.worker import process_stop
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# Router for the V1 generation endpoints; api_key_auth is attached as a
# dependency of every route registered on it.
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def stop_worker():
    """Ask the worker to stop by delegating to ``process_stop``."""
    process_stop()
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
@secure_router.post(
    path="/v1/generation/text-to-image",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"],
)
def text2img_generation(
        req: Text2ImgRequest,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nText to Image Generation\n
    A text to image generation endpoint
    Arguments:
        req {Text2ImgRequest} -- Text to image generation request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    return call_worker(req, accept)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
@secure_router.post(
    path="/v1/generation/image-upscale-vary",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"],
)
def img_upscale_or_vary(
        input_image: UploadFile,
        req: ImgUpscaleOrVaryRequest = Depends(ImgUpscaleOrVaryRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nImage upscale or vary\n
    Image upscale or vary
    Arguments:
        input_image {UploadFile} -- Input image file
        req {ImgUpscaleOrVaryRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.
    # `input_image` is declared in the signature but not read in this body;
    # presumably the form dependency also receives the file - TODO confirm.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    return call_worker(req, accept)
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
@secure_router.post(
    path="/v1/generation/image-inpaint-outpaint",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"],
)
def img_inpaint_or_outpaint(
        input_image: UploadFile,
        req: ImgInpaintOrOutpaintRequest = Depends(ImgInpaintOrOutpaintRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nInpaint or outpaint\n
    Inpaint or outpaint
    Arguments:
        input_image {UploadFile} -- Input image file
        req {ImgInpaintOrOutpaintRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.
    # `input_image` is declared in the signature but not read in this body;
    # presumably the form dependency also receives the file - TODO confirm.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    return call_worker(req, accept)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
@secure_router.post(
    path="/v1/generation/image-prompt",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV1"],
)
def img_prompt(
        cn_img1: Optional[UploadFile] = File(None),
        req: ImgPromptRequest = Depends(ImgPromptRequest.as_form),
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nImage Prompt\n
    Image Prompt
    A prompt-based image generation.
    Arguments:
        cn_img1 {UploadFile} -- Input image file
        req {ImgPromptRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.
    # `cn_img1` is declared in the signature but not read in this body.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    return call_worker(req, accept)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
@secure_router.post(
    path="/v1/tools/describe-image",
    response_model=DescribeImageResponse,
    tags=["GenerateV1"],
)
def describe_image(
        image: UploadFile,
        image_type: DescribeImageType = Query(
            DescribeImageType.photo,
            description="Image type, 'Photo' or 'Anime'",
        )):
    """\nDescribe image\n
    Describe image, Get tags from an image
    Arguments:
        image {UploadFile} -- Image to get tags
        image_type {DescribeImageType} -- Image type, 'Photo' or 'Anime'
    Returns:
        DescribeImageResponse -- Describe image response, a string
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # The interrogator module is imported only once the image type is known,
    # so only the selected backend gets loaded.
    if image_type == DescribeImageType.photo:
        from extras.interrogate import default_interrogator as interrogator
    else:
        from extras.wd14tagger import default_interrogator as interrogator

    pixels = HWC3(read_input_image(image))
    return DescribeImageResponse(describe=interrogator(pixels))
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
@secure_router.post(
    path="/v1/generation/stop",
    response_model=StopResponse,
    description="Job stopping",
    tags=["Default"],
)
def stop():
    """Call ``stop_worker`` and report success unconditionally."""
    stop_worker()
    return StopResponse(msg="success")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/routes/generate_v2.py
DELETED
|
@@ -1,199 +0,0 @@
|
|
| 1 |
-
"""Generate API V2 routes
|
| 2 |
-
|
| 3 |
-
"""
|
| 4 |
-
from typing import List
|
| 5 |
-
from fastapi import APIRouter, Depends, Header, Query
|
| 6 |
-
|
| 7 |
-
from fooocusapi.utils.api_utils import api_key_auth
|
| 8 |
-
from fooocusapi.models.requests_v1 import ImagePrompt
|
| 9 |
-
from fooocusapi.models.requests_v2 import (
|
| 10 |
-
ImgInpaintOrOutpaintRequestJson,
|
| 11 |
-
ImgPromptRequestJson,
|
| 12 |
-
Text2ImgRequestWithPrompt,
|
| 13 |
-
ImgUpscaleOrVaryRequestJson
|
| 14 |
-
)
|
| 15 |
-
from fooocusapi.models.common.response import (
|
| 16 |
-
AsyncJobResponse,
|
| 17 |
-
GeneratedImageResult
|
| 18 |
-
)
|
| 19 |
-
from fooocusapi.utils.call_worker import call_worker
|
| 20 |
-
from fooocusapi.utils.img_utils import base64_to_stream
|
| 21 |
-
from fooocusapi.configs.default import img_generate_responses
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
# Router for the V2 generation endpoints; api_key_auth is attached as a
# dependency of every route registered on it.
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
@secure_router.post(
    path="/v2/generation/text-to-image-with-ip",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"],
)
def text_to_img_with_ip(
        req: Text2ImgRequestWithPrompt,
        accept: str = Header(None),
        accept_query: str | None = Query(
            default=None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nText to image with prompt\n
    Text to image with prompt
    Arguments:
        req {Text2ImgRequestWithPrompt} -- Text to image generation request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    placeholder = ImagePrompt(cn_img=None)

    # Pass every prompt image through base64_to_stream (stored back on the
    # incoming item, as before) and rebuild it as an ImagePrompt.
    converted: List[ImagePrompt] = []
    for item in req.image_prompts:
        item.cn_img = base64_to_stream(item.cn_img)
        converted.append(ImagePrompt(
            cn_img=item.cn_img,
            cn_stop=item.cn_stop,
            cn_weight=item.cn_weight,
            cn_type=item.cn_type))

    # Pad with the single shared placeholder until five entries exist;
    # a non-positive multiplier adds nothing when there are already enough.
    converted.extend([placeholder] * (5 - len(converted)))
    req.image_prompts = converted

    return call_worker(req, accept)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@secure_router.post(
    path="/v2/generation/image-upscale-vary",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"],
)
def img_upscale_or_vary(
        req: ImgUpscaleOrVaryRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nImage upscale or vary\n
    Image upscale or vary
    Arguments:
        req {ImgUpscaleOrVaryRequestJson} -- Image upscale or vary request
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    req.input_image = base64_to_stream(req.input_image)

    placeholder = ImagePrompt(cn_img=None)

    # Pass every prompt image through base64_to_stream (stored back on the
    # incoming item, as before) and rebuild it as an ImagePrompt.
    converted: List[ImagePrompt] = []
    for item in req.image_prompts:
        item.cn_img = base64_to_stream(item.cn_img)
        converted.append(ImagePrompt(
            cn_img=item.cn_img,
            cn_stop=item.cn_stop,
            cn_weight=item.cn_weight,
            cn_type=item.cn_type))

    # Pad with the single shared placeholder until five entries exist.
    converted.extend([placeholder] * (5 - len(converted)))
    req.image_prompts = converted

    return call_worker(req, accept)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
@secure_router.post(
    path="/v2/generation/image-inpaint-outpaint",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"],
)
def img_inpaint_or_outpaint(
        req: ImgInpaintOrOutpaintRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nInpaint or outpaint\n
    Inpaint or outpaint
    Arguments:
        req {ImgInpaintOrOutpaintRequestJson} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    # The input image is always converted; the mask only when one was sent.
    req.input_image = base64_to_stream(req.input_image)
    if req.input_mask is not None:
        req.input_mask = base64_to_stream(req.input_mask)

    placeholder = ImagePrompt(cn_img=None)

    # Pass every prompt image through base64_to_stream (stored back on the
    # incoming item, as before) and rebuild it as an ImagePrompt.
    converted: List[ImagePrompt] = []
    for item in req.image_prompts:
        item.cn_img = base64_to_stream(item.cn_img)
        converted.append(ImagePrompt(
            cn_img=item.cn_img,
            cn_stop=item.cn_stop,
            cn_weight=item.cn_weight,
            cn_type=item.cn_type))

    # Pad with the single shared placeholder until five entries exist.
    converted.extend([placeholder] * (5 - len(converted)))
    req.image_prompts = converted

    return call_worker(req, accept)
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
@secure_router.post(
    path="/v2/generation/image-prompt",
    response_model=List[GeneratedImageResult] | AsyncJobResponse,
    responses=img_generate_responses,
    tags=["GenerateV2"],
)
def img_prompt(
        req: ImgPromptRequestJson,
        accept: str = Header(None),
        accept_query: str | None = Query(
            None,
            alias='accept',
            description="Parameter to override 'Accept' header, 'image/png' for output bytes",
        )):
    """\nImage prompt\n
    Image prompt generation
    Arguments:
        req {ImgPromptRequest} -- Request body
        accept {str} -- Accept header
        accept_query {str} -- Parameter to override 'Accept' header, 'image/png' for output bytes
    Returns:
        Response -- img_generate_responses
    """
    # NOTE: docstring wording is kept as-is; with no `description=` on the
    # route, FastAPI publishes it as the operation description.

    # A non-empty ?accept= query value takes precedence over the header.
    if accept_query:
        accept = accept_query

    # Both the input image and the mask are optional on this endpoint.
    if req.input_image is not None:
        req.input_image = base64_to_stream(req.input_image)
    if req.input_mask is not None:
        req.input_mask = base64_to_stream(req.input_mask)

    placeholder = ImagePrompt(cn_img=None)

    # Pass every prompt image through base64_to_stream (stored back on the
    # incoming item, as before) and rebuild it as an ImagePrompt.
    converted: List[ImagePrompt] = []
    for item in req.image_prompts:
        item.cn_img = base64_to_stream(item.cn_img)
        converted.append(ImagePrompt(
            cn_img=item.cn_img,
            cn_stop=item.cn_stop,
            cn_weight=item.cn_weight,
            cn_type=item.cn_type))

    # Pad with the single shared placeholder until five entries exist.
    converted.extend([placeholder] * (5 - len(converted)))
    req.image_prompts = converted

    return call_worker(req, accept)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/routes/query.py
DELETED
|
@@ -1,135 +0,0 @@
|
|
| 1 |
-
"""Query API"""
|
| 2 |
-
from typing import List
|
| 3 |
-
from fastapi import Depends, Response, APIRouter
|
| 4 |
-
|
| 5 |
-
from fooocusapi.args import args
|
| 6 |
-
|
| 7 |
-
from fooocusapi.models.common.requests import QueryJobRequest
|
| 8 |
-
from fooocusapi.models.common.response import (
|
| 9 |
-
AsyncJobResponse,
|
| 10 |
-
JobHistoryInfo,
|
| 11 |
-
JobQueueInfo,
|
| 12 |
-
JobHistoryResponse,
|
| 13 |
-
AllModelNamesResponse
|
| 14 |
-
)
|
| 15 |
-
from fooocusapi.models.common.task import AsyncJobStage
|
| 16 |
-
|
| 17 |
-
from fooocusapi.utils.api_utils import generate_async_output, api_key_auth
|
| 18 |
-
from fooocusapi.task_queue import TaskType
|
| 19 |
-
from fooocusapi.worker import worker_queue
|
| 20 |
-
|
| 21 |
-
# Router for the query endpoints; api_key_auth is attached as a dependency
# of every route registered on it.
secure_router = APIRouter(dependencies=[Depends(api_key_auth)])
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@secure_router.get(path="/", tags=['Query'])
def home():
    """Home page"""
    # NOTE: docstring kept as-is; with no `description=` on the route,
    # FastAPI publishes it as the operation description.
    landing_html = 'Swagger-UI to: <a href="/docs">/docs</a>'
    return Response(content=landing_html, media_type="text/html")
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@secure_router.get(
    path="/ping",
    description="Returns a simple 'pong'",
    tags=['Query'],
)
async def ping():
    """Liveness probe for the HTTP layer.

    An immediate answer only shows that FastAPI is serving requests; it
    does not mean the generation service itself is available.

    Returns:
        The literal string ``'pong'``.
    """
    return 'pong'
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
@secure_router.get(
    path="/v1/generation/query-job",
    response_model=AsyncJobResponse,
    description="Query async generation job",
    tags=['Query'],
)
def query_job(req: QueryJobRequest = Depends()):
    """Look up one job by id and return its current state.

    Returns:
        The output of ``generate_async_output`` for a known job, or a 404
        JSON response carrying a "Job not found" ``AsyncJobResponse``.
    """
    queue_task = worker_queue.get_task(req.job_id, True)
    if queue_task is not None:
        return generate_async_output(queue_task, req.require_step_preview)

    # Unknown id: serialize an error payload by hand so the 404 status can
    # be set on the response.
    not_found = AsyncJobResponse(
        job_id="",
        job_type=TaskType.not_found,
        job_stage=AsyncJobStage.error,
        job_progress=0,
        job_status="Job not found",
    )
    return Response(
        content=not_found.model_dump_json(),
        media_type='application/json',
        status_code=404,
    )
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
@secure_router.get(
    path="/v1/generation/job-queue",
    response_model=JobQueueInfo,
    description="Query job queue info",
    tags=['Query'],
)
def job_queue():
    """Report queue length, history length and the last job id."""
    return JobQueueInfo(
        running_size=len(worker_queue.queue),
        finished_size=len(worker_queue.history),
        last_job_id=worker_queue.last_job_id,
    )
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
@secure_router.get(
    path="/v1/generation/job-history",
    response_model=JobHistoryResponse | dict,
    description="Query historical job data",
    tags=["Query"],
)
def get_history(job_id: str = None, page: int = 0, page_size: int = 20):
    """Return the queued jobs together with the job history.

    Without ``args.persistent`` the history comes from the in-memory
    ``worker_queue.history``; otherwise it is read through
    ``sql_client.query_history`` using ``job_id``, ``page`` and
    ``page_size``. The three parameters are only used in the latter case.
    """
    def summarize(tasks):
        # Reduce each task to its id and finished flag.
        return [
            JobHistoryInfo(job_id=task.job_id, is_finished=task.is_finished)
            for task in tasks
        ]

    queue = summarize(worker_queue.queue)

    if not args.persistent:
        return JobHistoryResponse(
            history=summarize(worker_queue.history),
            queue=queue,
        )

    # Imported here, not at module top, so sql_client is only loaded when
    # persistent history is actually requested.
    from fooocusapi.sql_client import query_history
    stored = query_history(task_id=job_id, page=page, page_size=page_size)
    return {"history": stored, "queue": queue}
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
@secure_router.get(
    path="/v1/engines/all-models",
    response_model=AllModelNamesResponse,
    description="Get all filenames of base model and lora",
    tags=["Query"],
)
def all_models():
    """Call ``config.update_files()`` and return the model and LoRA filenames."""
    from modules import config

    config.update_files()
    return AllModelNamesResponse(
        model_filenames=config.model_filenames,
        lora_filenames=config.lora_filenames,
    )
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
@secure_router.get(
    path="/v1/engines/styles",
    response_model=List[str],
    description="Get all legal Fooocus styles",
    tags=['Query'],
)
def all_styles():
    """Return ``legal_style_names`` from ``modules.sdxl_styles``."""
    from modules.sdxl_styles import legal_style_names as style_names
    return style_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/sql_client.py
DELETED
|
@@ -1,269 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SQLite client for Fooocus API
|
| 3 |
-
"""
|
| 4 |
-
import os
|
| 5 |
-
import time
|
| 6 |
-
import platform
|
| 7 |
-
from datetime import datetime
|
| 8 |
-
from typing import Optional
|
| 9 |
-
import copy
|
| 10 |
-
|
| 11 |
-
from sqlalchemy import Integer, Float, VARCHAR, Boolean, JSON, Text, create_engine
|
| 12 |
-
from sqlalchemy.orm import declarative_base, Session, Mapped, mapped_column
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# Default database file: database.db one directory above this module.
default_sqlite_db_path = os.path.join(os.path.dirname(__file__), "../database.db")
if platform.system().lower() == "windows":
    # On Windows, swap backslashes for forward slashes before the path is
    # embedded in the sqlite URI below.
    default_sqlite_db_path = default_sqlite_db_path.replace("\\", "/")

# FOOOCUS_DB_CONF, when set, replaces the whole connection URI.
connection_uri = os.environ.get(
    "FOOOCUS_DB_CONF",
    f"sqlite:///{default_sqlite_db_path}",
)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class GenerateRecord(Base):
    """ORM model for one row of the ``generate_record`` table.

    ``task_id`` is the primary key. ``task_type`` and ``date_time`` are
    mandatory; every other column is nullable.
    """

    __tablename__ = "generate_record"

    task_id: Mapped[str] = mapped_column(VARCHAR(255), nullable=False, primary_key=True)
    task_type: Mapped[str] = mapped_column(Text, nullable=False)
    # Both columns are nullable, so the Python-side type is Optional too.
    result_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    finish_reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    date_time: Mapped[int] = mapped_column(Integer, nullable=False)

    prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    style_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    performance_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    aspect_ratios_selection: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    base_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    refiner_model_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    refiner_switch: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    loras: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    image_number: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    image_seed: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    sharpness: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    guidance_scale: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    advanced_params: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    input_image: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    input_mask: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    image_prompts: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    inpaint_additional_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    outpaint_selections: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    outpaint_distance_left: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    outpaint_distance_right: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    outpaint_distance_top: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    outpaint_distance_bottom: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    uov_method: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    upscale_value: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    webhook_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    require_base64: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
    async_process: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)

    def __repr__(self) -> str:
        """Return ``GenerateRecord(name=value, ...)`` covering every column.

        The text is assembled from a list of field names rather than one
        f-string split with backslash continuations, so the output no longer
        picks up whatever indentation the source lines happen to have.
        """
        field_names = (
            "task_id", "task_type", "result_url", "finish_reason", "date_time",
            "prompt", "negative_prompt", "style_selections",
            "performance_selection", "aspect_ratios_selection",
            "base_model_name", "refiner_model_name", "refiner_switch", "loras",
            "image_number", "image_seed", "sharpness", "guidance_scale",
            "advanced_params", "input_image", "input_mask", "image_prompts",
            "inpaint_additional_prompt", "outpaint_selections",
            "outpaint_distance_left", "outpaint_distance_right",
            "outpaint_distance_top", "outpaint_distance_bottom",
            "uov_method", "upscale_value", "webhook_url", "require_base64",
            "async_process",
        )
        rendered = ", ".join(
            f"{name}={getattr(self, name)!r}" for name in field_names
        )
        return f"GenerateRecord({rendered})"
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# Engine used once at import time to make sure the schema exists.
engine = create_engine(connection_uri)

# Create every table registered on Base.metadata that is still missing
# (checkfirst=True). The context manager closes the session even when
# create_all() raises, which the former explicit close() did not guarantee.
with Session(engine) as session:
    Base.metadata.create_all(engine, checkfirst=True)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def convert_to_dict_list(obj_list: list[object]) -> list[dict]:
    """
    Convert a list of record objects to a list of dictionaries.

    Every object is split into two parts:

    * ``task_info``: ``task_id``, ``task_type``, ``result_url``,
      ``finish_reason`` and ``date_time`` (formatted as
      ``YYYY-MM-DD HH:MM:SS`` in local time).
    * ``params``: every other public, non-callable instance attribute.

    Args:
        obj_list: objects exposing the five task-info attributes above;
            ``date_time`` must be a POSIX timestamp.

    Returns:
        dict_list: one ``{"params": ..., "task_info": ...}`` dict per object,
            in input order. An empty input gives an empty list.
    """
    task_info_keys = (
        "task_id",
        "task_type",
        "result_url",
        "finish_reason",
        "date_time",
    )
    dict_list = []
    for obj in obj_list:
        # Turn the object's public, non-callable instance attributes into
        # key/value pairs ("_" also covers the "__" prefix).
        dict_obj = {
            attr: value
            for attr, value in vars(obj).items()
            if not callable(value) and not attr.startswith("_")
        }
        task_info = {
            "task_id": obj.task_id,
            "task_type": obj.task_type,
            "result_url": obj.result_url,
            "finish_reason": obj.finish_reason,
            "date_time": datetime.fromtimestamp(obj.date_time).strftime(
                "%Y-%m-%d %H:%M:%S"
            ),
        }
        # pop() instead of del: an attribute that is readable on the object
        # but absent from its instance __dict__ must not raise KeyError.
        for key in task_info_keys:
            dict_obj.pop(key, None)
        dict_list.append({"params": dict_obj, "task_info": task_info})
    return dict_list
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
class MySQLAlchemy:
    """
    MySQLAlchemy, a toolkit for managing SQLAlchemy connections and sessions.

    :param uri: SQLAlchemy connection URI
    """

    def __init__(self, uri: str):
        # e.g. 'mysql+pymysql://{username}:{password}@{host}:{port}/{database}'
        self.engine = create_engine(uri)
        # One long-lived session shared by every call on this instance.
        self.session = Session(self.engine)

    def store_history(self, record: dict) -> None:
        """
        Store one history record in the database.

        :param record: keyword arguments used to build a ``GenerateRecord``
        :return: None
        :raises Exception: whatever building, adding or committing the record
            raised; the session is rolled back first
        """
        try:
            self.session.add(GenerateRecord(**record))
            self.session.commit()
        except Exception:
            # The session is reused for later calls; without a rollback a
            # failed commit would leave it unusable for them.
            self.session.rollback()
            raise

    def get_history(
        self,
        task_id: Optional[str] = None,
        page: int = 0,
        page_size: int = 20,
        order_by: str = "date_time",
    ) -> list:
        """
        Get history from the database.

        :param task_id: when given, return only records with this task id;
            paging and ordering are ignored in that case
        :param page: zero-based page index
        :param page_size: number of records per page
        :param order_by: name of the ``GenerateRecord`` attribute to sort by,
            descending
        :return: list of ``{"params": ..., "task_info": ...}`` dicts, empty
            when nothing matches
        """
        query = self.session.query(GenerateRecord)
        if task_id is not None:
            query = query.filter(GenerateRecord.task_id == task_id)
        else:
            # NOTE(review): order_by is resolved with getattr, so any attribute
            # name of GenerateRecord is accepted; callers should pass a column.
            query = (
                query.order_by(getattr(GenerateRecord, order_by).desc())
                .offset(page * page_size)
                .limit(page_size)
            )
        # convert_to_dict_list([]) is already [], no special case needed.
        return convert_to_dict_list(query.all())
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
# Module-level database client; add_history() and query_history() go through it.
db = MySQLAlchemy(uri=connection_uri)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def req_to_dict(req: dict) -> dict:
    """
    Convert a request dict into a storable form.

    The dict is modified in place and also returned:

    * ``loras`` entries become ``{"model_name": ..., "weight": ...}`` dicts.
    * ``image_prompts`` entries become dicts with an empty ``cn_img`` (the
      image itself is not kept) plus ``cn_stop``, ``cn_weight``, ``cn_type``.
    * ``inpaint_input_image`` and ``uov_input_image`` are removed.

    Args:
        req: request dict; must contain ``loras`` (sequences of at least two
            items) and ``image_prompts`` (sequences of at least four items).

    Returns:
        The same dict object, converted.
    """
    req["loras"] = [
        {"model_name": lora[0], "weight": lora[1]} for lora in req["loras"]
    ]
    req["image_prompts"] = [
        {"cn_img": "", "cn_stop": image[1], "cn_weight": image[2], "cn_type": image[3]}
        for image in req["image_prompts"]
    ]
    # pop() rather than del: a request that never carried the image keys
    # must not raise KeyError.
    req.pop("inpaint_input_image", None)
    req.pop("uov_input_image", None)
    return req
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def add_history(
    params: dict, task_type: str, task_id: str, result_url: str, finish_reason: str
) -> None:
    """
    Store one finished task in the history database.

    Args:
        params: generation parameters; ``advanced_params`` may be an object
            (its attribute dict is stored), a dict, or None. The dict passed
            in is not modified.
        task_type: task type value stored with the record
        task_id: task id stored with the record
        result_url: result URL(s) stored with the record
        finish_reason: finish reason stored with the record

    Returns:
        None
    """
    # Work on a shallow copy so the caller's dict keeps all of its keys.
    record = dict(params)

    adv = record.get("advanced_params")
    if adv is not None and not isinstance(adv, dict):
        # Store the attribute dict of a deep copy, detached from the live
        # object. None is kept as None instead of failing on ``.__dict__``.
        adv = copy.deepcopy(adv).__dict__
    record["advanced_params"] = adv

    record["date_time"] = int(time.time())
    record["task_type"] = task_type
    record["task_id"] = task_id
    record["result_url"] = result_url
    record["finish_reason"] = finish_reason

    # Keys that are not persisted; pop() tolerates keys that are absent.
    for key in (
        "inpaint_input_image",
        "uov_input_image",
        "save_extension",
        "save_meta",
        "save_name",
        "meta_scheme",
    ):
        record.pop(key, None)

    db.store_history(record)
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def query_history(
    task_id: str = None,
    page: int = 0,
    page_size: int = 20,
    order_by: str = "date_time"
) -> list:
    """
    Query history from the database through the module-level ``db`` client.

    Args:
        task_id: forwarded to ``db.get_history``
        page: forwarded to ``db.get_history``
        page_size: forwarded to ``db.get_history``
        order_by: forwarded to ``db.get_history``

    Returns:
        Whatever ``db.get_history`` returns for these arguments.
    """
    lookup = {
        "task_id": task_id,
        "page": page,
        "page_size": page_size,
        "order_by": order_by,
    }
    history = db.get_history(**lookup)
    return history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/task_queue.py
DELETED
|
@@ -1,323 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Task queue management
|
| 3 |
-
|
| 4 |
-
This module provides classes and functions for managing the task queue.
|
| 5 |
-
|
| 6 |
-
Classes:
|
| 7 |
-
QueueTask: A class representing a task in the queue.
|
| 8 |
-
TaskQueue: A class for managing the task queue.
|
| 9 |
-
"""
|
| 10 |
-
import uuid
|
| 11 |
-
import time
|
| 12 |
-
from typing import List, Tuple
|
| 13 |
-
import numpy as np
|
| 14 |
-
import requests
|
| 15 |
-
|
| 16 |
-
from fooocusapi.utils.file_utils import delete_output_file, get_file_serve_url
|
| 17 |
-
from fooocusapi.utils.img_utils import narray_to_base64img
|
| 18 |
-
from fooocusapi.utils.logger import logger
|
| 19 |
-
|
| 20 |
-
from fooocusapi.models.common.task import ImageGenerationResult, GenerationFinishReason
|
| 21 |
-
from fooocusapi.parameters import ImageGenerationParams
|
| 22 |
-
from fooocusapi.models.common.task import TaskType
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class QueueTask:
    """
    A class representing a task in the queue.

    Attributes:
        job_id (str): The unique identifier for the task.
        task_type (TaskType): The type of task.
        req_param (ImageGenerationParams): The request parameters of the task.
        is_finished (bool): Indicates whether the task has been completed.
        finish_progress (int): The progress of the task, at most 100.
        in_queue_mills (int): The time the task was created, in milliseconds.
        start_mills (int): The time the task started, in milliseconds (0 = not started).
        finish_mills (int): The time the task finished, in milliseconds.
        finish_with_error (bool): Indicates whether the task finished with an error.
        task_status (str): The status of the task.
        task_step_preview (str): The latest step preview of the task.
        task_result (List[ImageGenerationResult]): The result of the task.
        error_message (str): The error message, if any.
        webhook_url (str): The webhook URL of this individual task, if any.
    """

    job_id: str
    task_type: TaskType
    req_param: ImageGenerationParams
    is_finished: bool = False
    finish_progress: int = 0
    in_queue_mills: int
    start_mills: int = 0
    finish_mills: int = 0
    finish_with_error: bool = False
    task_status: str | None = None
    task_step_preview: str | None = None
    task_result: List[ImageGenerationResult] | None = None
    error_message: str | None = None
    webhook_url: str | None = None  # attribute for individual webhook_url

    def __init__(
        self,
        job_id: str,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ):
        self.job_id = job_id
        self.task_type = task_type
        self.req_param = req_param
        # Creation time doubles as the enqueue time.
        self.in_queue_mills = int(round(time.time() * 1000))
        self.webhook_url = webhook_url

    def set_progress(self, progress: int, status: str | None):
        """
        Set progress and status

        Arguments:
            progress {int} -- progress, clamped to the range 0..100
            status {str} -- status
        """
        # Clamp on both sides so a negative value cannot be reported.
        self.finish_progress = max(0, min(progress, 100))
        self.task_status = status

    def set_step_preview(self, task_step_preview: str | None):
        """
        Set step preview

        Arguments:
            task_step_preview {str} -- step preview
        """
        self.task_step_preview = task_step_preview

    def set_result(
        self,
        task_result: List[ImageGenerationResult],
        finish_with_error: bool,
        error_message: str | None = None,
    ):
        """
        Set task result

        Progress and status are only forced to 100 / "Finished" when the task
        did not finish with an error.

        Arguments:
            task_result {List[ImageGenerationResult]} -- task result
            finish_with_error {bool} -- finish with error
            error_message {str} -- error message
        """
        if not finish_with_error:
            self.finish_progress = 100
            self.task_status = "Finished"
        self.task_result = task_result
        self.finish_with_error = finish_with_error
        self.error_message = error_message

    def __str__(self) -> str:
        # Adjacent f-strings instead of backslash continuations inside one
        # literal, so source indentation can never leak into the text.
        return (
            f"QueueTask(job_id={self.job_id}, task_type={self.task_type}, "
            f"is_finished={self.is_finished}, finished_progress={self.finish_progress}, "
            f"in_queue_mills={self.in_queue_mills}, start_mills={self.start_mills}, "
            f"finish_mills={self.finish_mills}, finish_with_error={self.finish_with_error}, "
            f"error_message={self.error_message}, task_status={self.task_status}, "
            f"task_step_preview={self.task_step_preview}, webhook_url={self.webhook_url})"
        )
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
class TaskQueue:
    """
    TaskQueue is a queue of tasks that are waiting to be processed.

    Attributes:
        queue: List[QueueTask], tasks that are not finished yet
        history: List[QueueTask], finished tasks, oldest first
        last_job_id: str, job id of the most recently added task
        webhook_url: str, default webhook used when a task has none
        persistent: bool, whether finished tasks are saved to the database
    """

    queue: List[QueueTask]
    history: List[QueueTask]
    last_job_id: str | None = None
    webhook_url: str | None = None
    persistent: bool = False

    def __init__(
        self,
        queue_size: int,
        history_size: int,
        webhook_url: str | None = None,
        persistent: bool | None = False,
    ):
        self.queue_size = queue_size
        self.history_size = history_size
        self.webhook_url = webhook_url
        self.persistent = False if persistent is None else persistent
        # Per-instance lists: as class-level defaults they were one shared
        # list object for every TaskQueue instance.
        self.queue = []
        self.history = []

    def add_task(
        self,
        task_type: TaskType,
        req_param: ImageGenerationParams,
        webhook_url: str | None = None,
    ) -> QueueTask | None:
        """
        Create and add task to queue
        :param task_type: task type
        :param req_param: request parameters, or a dict of them
        :param webhook_url: webhook url
        :returns: The created task, or None if the queue size limit is reached
        """
        if len(self.queue) >= self.queue_size:
            return None

        if isinstance(req_param, dict):
            req_param = ImageGenerationParams(**req_param)

        job_id = str(uuid.uuid4())
        task = QueueTask(
            job_id=job_id,
            task_type=task_type,
            req_param=req_param,
            webhook_url=webhook_url,
        )
        self.queue.append(task)
        self.last_job_id = job_id
        return task

    def get_task(self, job_id: str, include_history: bool = False) -> QueueTask | None:
        """
        Get task by job_id
        :param job_id: job id
        :param include_history: whether to include history tasks
        :returns: The task with the given job_id, or None if not found
        """
        for task in self.queue:
            if task.job_id == job_id:
                return task

        if include_history:
            for task in self.history:
                if task.job_id == job_id:
                    return task

        return None

    def is_task_ready_to_start(self, job_id: str) -> bool:
        """
        Check if the task is ready to start
        :param job_id: job id
        :returns: True if the task is at the head of the queue, False otherwise
        """
        task = self.get_task(job_id)
        if task is None:
            return False

        return self.queue[0].job_id == job_id

    def is_task_finished(self, job_id: str) -> bool:
        """
        Check if the task is finished
        :param job_id: job id
        :returns: True if the task is finished, False otherwise (also when unknown)
        """
        task = self.get_task(job_id, True)
        if task is None:
            return False

        return task.is_finished

    def start_task(self, job_id: str):
        """
        Start task by job_id
        :param job_id: job id
        """
        task = self.get_task(job_id)
        if task is not None:
            task.start_mills = int(round(time.time() * 1000))

    def finish_task(self, job_id: str):
        """
        Finish task by job_id

        Marks the task finished, calls the webhook (best effort), moves the
        task from the queue to the history, optionally saves it to the
        database, and drops the oldest history entry (deleting its output
        files) when the history is over its size limit.

        :param job_id: job id; unknown or already finished ids are ignored
        """
        task = self.get_task(job_id)
        if task is None:
            return

        task.is_finished = True
        task.finish_mills = int(round(time.time() * 1000))

        # Use the task's webhook_url if available, else use the default
        webhook_url = task.webhook_url or self.webhook_url

        data = {"job_id": task.job_id, "job_result": []}

        if isinstance(task.task_result, list):
            for item in task.task_result:
                data["job_result"].append(
                    {
                        "url": get_file_serve_url(item.im) if item.im else None,
                        "seed": item.seed if item.seed else "-1",
                    }
                )

        # Send webhook; a failing webhook must not stop the bookkeeping below.
        if webhook_url:
            try:
                res = requests.post(webhook_url, json=data, timeout=15)
                print(f"Call webhook response status: {res.status_code}")
            except Exception as e:
                print("Call webhook error:", e)

        # Move task to history
        self.queue.remove(task)
        self.history.append(task)

        # save history to database
        if self.persistent:
            from fooocusapi.sql_client import add_history

            add_history(
                params=task.req_param.to_dict(),
                task_type=task.task_type.value,
                task_id=task.job_id,
                # Results without an image have url None; str.join would
                # raise TypeError on them, so only real URLs are joined.
                result_url=",".join(
                    job["url"] for job in data["job_result"] if job["url"]
                ),
                # NOTE(review): still assumes task_result holds at least one
                # result when persistence is on -- confirm with the worker.
                finish_reason=task.task_result[0].finish_reason.value,
            )

        # Clean history (a history_size of 0 disables the limit)
        if self.history_size != 0 and len(self.history) > self.history_size:
            removed_task = self.history.pop(0)
            if isinstance(removed_task.task_result, list):
                for item in removed_task.task_result:
                    if (
                        isinstance(item, ImageGenerationResult)
                        and item.finish_reason == GenerationFinishReason.success
                        and item.im is not None
                    ):
                        delete_output_file(item.im)
            logger.std_info(
                f"[TaskQueue] Clean task history, remove task: {removed_task.job_id}"
            )
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
class TaskOutputs:
    """
    TaskOutputs is a container for task outputs

    Every appended output is kept in ``outputs``; "preview" outputs also
    update the progress, status and step preview of the wrapped task.
    """

    outputs: list

    def __init__(self, task: QueueTask):
        self.task = task
        # Per-instance list: as a class-level default it was one shared list
        # that collected the outputs of every task.
        self.outputs = []

    def append(self, args: List[object]):
        """
        Append output to task outputs list
        :param args: output arguments; ``["preview", (progress, status[, image])]``
            additionally updates the task
        """
        self.outputs.append(args)
        if len(args) < 2:
            return

        payload = args[1]
        if args[0] == "preview" and isinstance(payload, tuple) and len(payload) >= 2:
            number = payload[0]
            text = payload[1]
            self.task.set_progress(number, text)
            # A third element that is an ndarray is the preview image.
            if len(payload) >= 3 and isinstance(payload[2], np.ndarray):
                base64_preview_img = narray_to_base64img(payload[2])
                self.task.set_step_preview(base64_preview_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/api_utils.py
DELETED
|
@@ -1,291 +0,0 @@
|
|
| 1 |
-
"""some utils for api"""
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
from fastapi import Response
|
| 5 |
-
from fastapi.security import APIKeyHeader
|
| 6 |
-
from fastapi import HTTPException, Security
|
| 7 |
-
|
| 8 |
-
from modules import flags
|
| 9 |
-
from modules import config
|
| 10 |
-
from modules.sdxl_styles import legal_style_names
|
| 11 |
-
|
| 12 |
-
from fooocusapi.args import args
|
| 13 |
-
from fooocusapi.utils.img_utils import read_input_image
|
| 14 |
-
from fooocusapi.utils.file_utils import (
|
| 15 |
-
get_file_serve_url,
|
| 16 |
-
output_file_to_base64img,
|
| 17 |
-
output_file_to_bytesimg
|
| 18 |
-
)
|
| 19 |
-
from fooocusapi.utils.logger import logger
|
| 20 |
-
from fooocusapi.models.common.requests import (
|
| 21 |
-
CommonRequest as Text2ImgRequest
|
| 22 |
-
)
|
| 23 |
-
from fooocusapi.models.common.response import (
|
| 24 |
-
AsyncJobResponse,
|
| 25 |
-
AsyncJobStage,
|
| 26 |
-
GeneratedImageResult
|
| 27 |
-
)
|
| 28 |
-
from fooocusapi.models.requests_v1 import (
|
| 29 |
-
ImgInpaintOrOutpaintRequest,
|
| 30 |
-
ImgPromptRequest,
|
| 31 |
-
ImgUpscaleOrVaryRequest
|
| 32 |
-
)
|
| 33 |
-
from fooocusapi.models.requests_v2 import (
|
| 34 |
-
Text2ImgRequestWithPrompt,
|
| 35 |
-
ImgInpaintOrOutpaintRequestJson,
|
| 36 |
-
ImgUpscaleOrVaryRequestJson,
|
| 37 |
-
ImgPromptRequestJson
|
| 38 |
-
)
|
| 39 |
-
from fooocusapi.models.common.task import (
|
| 40 |
-
ImageGenerationResult,
|
| 41 |
-
GenerationFinishReason
|
| 42 |
-
)
|
| 43 |
-
from fooocusapi.configs.default import (
|
| 44 |
-
default_inpaint_engine_version,
|
| 45 |
-
default_sampler,
|
| 46 |
-
default_scheduler,
|
| 47 |
-
default_base_model_name,
|
| 48 |
-
default_refiner_model_name
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
from fooocusapi.parameters import ImageGenerationParams
|
| 52 |
-
from fooocusapi.task_queue import QueueTask
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
# Security scheme that reads the API key from the "X-API-KEY" request header.
# NOTE(review): auto_error=False presumably lets a request without the header
# through with a None key, so api_key_auth() makes the decision -- confirm
# against the FastAPI documentation.
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def api_key_auth(apikey: str = Security(api_key_header)):
    """
    Check if the API key is valid, API key is not required if no API key is set
    Args:
        apikey: API key taken from the request header, None when it is missing
    returns:
        None when no API key is configured or the supplied key matches
    raises:
        HTTPException: status 403 when a key is configured and the supplied
            key is missing or different
    """
    import secrets  # local import: only this check needs it

    if args.apikey is None:
        return  # Skip API key check if no API key is set

    # Compare as bytes with a constant-time function so the response time
    # does not reveal how many leading characters of the key were right.
    # bytes also avoids compare_digest's TypeError on non-ASCII str input.
    if apikey is None or not secrets.compare_digest(
        apikey.encode("utf-8"), args.apikey.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Forbidden")
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
    """
    Convert Request to ImageGenerationParams

    Invalid model, lora, sampler, scheduler, refiner swap method and inpaint
    engine names are replaced (with a warning) by defaults; the replacement
    is written back to ``req``. Input images are decoded with
    ``read_input_image``.

    Args:
        req: Request, Text2ImgRequest and classes inherited from Text2ImgRequest
    returns:
        ImageGenerationParams
    """
    config.update_files()
    if req.base_model_name is not None:
        if req.base_model_name not in config.model_filenames:
            logger.std_warn(f"[Warning] Wrong base_model_name input: {req.base_model_name}, using default")
            req.base_model_name = default_base_model_name

    if req.refiner_model_name is not None and req.refiner_model_name != 'None':
        if req.refiner_model_name not in config.model_filenames:
            logger.std_warn(f"[Warning] Wrong refiner_model_name input: {req.refiner_model_name}, using default")
            req.refiner_model_name = default_refiner_model_name

    for lora in req.loras:
        if lora.model_name != 'None' and lora.model_name not in config.lora_filenames:
            logger.std_warn(f"[Warning] Wrong lora model_name input: {lora.model_name}, using 'None'")
            lora.model_name = 'None'

    # Request-kind tests, evaluated once instead of once per field.
    is_uov = isinstance(req, (ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson))
    is_inpaint = isinstance(req, (ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson))
    is_text_with_prompt = isinstance(req, Text2ImgRequestWithPrompt)

    prompt = req.prompt
    negative_prompt = req.negative_prompt
    style_selections = [
        s for s in req.style_selections if s in legal_style_names]
    performance_selection = req.performance_selection.value
    aspect_ratios_selection = req.aspect_ratios_selection
    image_number = req.image_number
    # -1 is the request's way of asking for no fixed seed.
    image_seed = None if req.image_seed == -1 else req.image_seed
    sharpness = req.sharpness
    guidance_scale = req.guidance_scale
    base_model_name = req.base_model_name
    refiner_model_name = req.refiner_model_name
    refiner_switch = req.refiner_switch
    loras = [(lora.model_name, lora.weight) for lora in req.loras]

    uov_input_image = None
    if not is_text_with_prompt and is_uov:
        uov_input_image = read_input_image(req.input_image)
    uov_method = req.uov_method.value if is_uov else flags.disabled
    upscale_value = req.upscale_value if is_uov else None

    outpaint_selections = [s.value for s in req.outpaint_selections] if is_inpaint else []
    outpaint_distance_left = req.outpaint_distance_left if is_inpaint else None
    outpaint_distance_right = req.outpaint_distance_right if is_inpaint else None
    outpaint_distance_top = req.outpaint_distance_top if is_inpaint else None
    outpaint_distance_bottom = req.outpaint_distance_bottom if is_inpaint else None

    if refiner_model_name == '':
        refiner_model_name = 'None'

    inpaint_input_image = None
    inpaint_additional_prompt = None
    if is_inpaint and req.input_image is not None:
        inpaint_additional_prompt = req.inpaint_additional_prompt
        input_image = read_input_image(req.input_image)
        input_mask = None
        if req.input_mask is not None:
            input_mask = read_input_image(req.input_mask)
        inpaint_input_image = {
            'image': input_image,
            'mask': input_mask
        }

    image_prompts = []
    if isinstance(req, (ImgInpaintOrOutpaintRequestJson, ImgPromptRequest, ImgPromptRequestJson, ImgUpscaleOrVaryRequestJson, Text2ImgRequestWithPrompt)):
        # Auto set the matching mixing flag to True. advanced_params is
        # treated as optional further down, so it is guarded here as well
        # instead of failing with AttributeError on None.
        if len(req.image_prompts) > 0 and req.advanced_params is not None:
            if uov_input_image is not None:
                print("[INFO] Mixing image prompt and vary upscale is set to True")
                req.advanced_params.mixing_image_prompt_and_vary_upscale = True
            elif not is_text_with_prompt and req.input_image is not None:
                print("[INFO] Mixing image prompt and inpaint is set to True")
                req.advanced_params.mixing_image_prompt_and_inpaint = True

        for img_prompt in req.image_prompts:
            # Prompts without an image are dropped.
            if img_prompt.cn_img is not None:
                cn_img = read_input_image(img_prompt.cn_img)
                # Missing or zero stop/weight fall back to the per-type defaults.
                if img_prompt.cn_stop is None or img_prompt.cn_stop == 0:
                    img_prompt.cn_stop = flags.default_parameters[img_prompt.cn_type.value][0]
                if img_prompt.cn_weight is None or img_prompt.cn_weight == 0:
                    img_prompt.cn_weight = flags.default_parameters[img_prompt.cn_type.value][1]
                image_prompts.append(
                    (cn_img, img_prompt.cn_stop, img_prompt.cn_weight, img_prompt.cn_type.value))

    advanced_params = None
    if req.advanced_params is not None:
        adp = req.advanced_params

        if adp.refiner_swap_method not in ['joint', 'separate', 'vae']:
            print(f"[Warning] Wrong refiner_swap_method input: {adp.refiner_swap_method}, using default")
            adp.refiner_swap_method = 'joint'

        if adp.sampler_name not in flags.sampler_list:
            print(f"[Warning] Wrong sampler_name input: {adp.sampler_name}, using default")
            adp.sampler_name = default_sampler

        if adp.scheduler_name not in flags.scheduler_list:
            print(f"[Warning] Wrong scheduler_name input: {adp.scheduler_name}, using default")
            adp.scheduler_name = default_scheduler

        if adp.inpaint_engine not in flags.inpaint_engine_versions:
            print(f"[Warning] Wrong inpaint_engine input: {adp.inpaint_engine}, using default")
            adp.inpaint_engine = default_inpaint_engine_version

        advanced_params = adp

    return ImageGenerationParams(
        prompt=prompt,
        negative_prompt=negative_prompt,
        style_selections=style_selections,
        performance_selection=performance_selection,
        aspect_ratios_selection=aspect_ratios_selection,
        image_number=image_number,
        image_seed=image_seed,
        sharpness=sharpness,
        guidance_scale=guidance_scale,
        base_model_name=base_model_name,
        refiner_model_name=refiner_model_name,
        refiner_switch=refiner_switch,
        loras=loras,
        uov_input_image=uov_input_image,
        uov_method=uov_method,
        upscale_value=upscale_value,
        outpaint_selections=outpaint_selections,
        outpaint_distance_left=outpaint_distance_left,
        outpaint_distance_right=outpaint_distance_right,
        outpaint_distance_top=outpaint_distance_top,
        outpaint_distance_bottom=outpaint_distance_bottom,
        inpaint_input_image=inpaint_input_image,
        inpaint_additional_prompt=inpaint_additional_prompt,
        image_prompts=image_prompts,
        advanced_params=advanced_params,
        save_meta=req.save_meta,
        meta_scheme=req.meta_scheme,
        save_name=req.save_name,
        save_extension=req.save_extension,
        require_base64=req.require_base64,
    )
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
def generate_async_output(
        task: QueueTask,
        require_step_preview: bool = False) -> AsyncJobResponse:
    """
    Build the response describing the current state of an async job.

    Arguments:
        task: QueueTask whose state is reported
        require_step_preview: when True the task's step preview is included
    Returns:
        AsyncJobResponse
    """
    payload = None

    # Finished tasks take priority over the waiting/running distinction.
    if task.is_finished and task.finish_with_error:
        stage = AsyncJobStage.error
    elif task.is_finished and task.task_result is not None:
        stage = AsyncJobStage.success
        payload = generate_image_result_output(task.task_result, task.req_param.require_base64)
    elif task.start_mills == 0:
        # A start time of 0 means the task has not been started yet.
        stage = AsyncJobStage.waiting
    else:
        stage = AsyncJobStage.running

    preview = task.task_step_preview if require_step_preview else None
    return AsyncJobResponse(
        job_id=task.job_id,
        job_type=task.task_type,
        job_stage=stage,
        job_progress=task.finish_progress,
        job_status=task.task_status,
        job_step_preview=preview,
        job_result=payload)
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def generate_streaming_output(results: List[ImageGenerationResult]) -> Response:
    """
    Generate streaming output for image generation results.

    Only the first result is streamed. Failure finish reasons are mapped to
    HTTP status codes, and a result whose image file cannot be read is
    reported as a server error instead of an empty "successful" PNG body.

    Args:
        results (List[ImageGenerationResult]): List of image generation results.
    Returns:
        Response: Streaming response object, bytes image.
    """
    if len(results) == 0:
        return Response(status_code=500)

    result = results[0]
    failure_statuses = (
        (GenerationFinishReason.queue_is_full, 409),
        (GenerationFinishReason.user_cancel, 400),
        (GenerationFinishReason.error, 500),
    )
    for reason, status_code in failure_statuses:
        if result.finish_reason == reason:
            return Response(status_code=status_code, content=result.finish_reason.value)

    img_bytes = output_file_to_bytesimg(result.im)
    if img_bytes is None:
        # The output file is missing or unreadable: do not answer 200 with
        # an empty image/png body.
        return Response(status_code=500)
    return Response(img_bytes, media_type='image/png')
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
def generate_image_result_output(
        results: List[ImageGenerationResult],
        require_base64: bool) -> List[GeneratedImageResult]:
    """
    Convert worker results into API response items.

    Arguments:
        results: List[ImageGenerationResult]
        require_base64: when True each item also carries the image as base64
    Returns:
        List[GeneratedImageResult]
    """
    converted = []
    for item in results:
        encoded = output_file_to_base64img(item.im) if require_base64 else None
        converted.append(
            GeneratedImageResult(
                base64=encoded,
                url=get_file_serve_url(item.im),
                seed=str(item.seed),
                finish_reason=item.finish_reason
            )
        )
    return converted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/call_worker.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
"""function for call generate worker"""
|
| 2 |
-
from typing import List
|
| 3 |
-
from fastapi import Response
|
| 4 |
-
|
| 5 |
-
from fooocusapi.models.common.requests import (
|
| 6 |
-
CommonRequest as Text2ImgRequest
|
| 7 |
-
)
|
| 8 |
-
from fooocusapi.models.common.response import (
|
| 9 |
-
AsyncJobResponse,
|
| 10 |
-
GeneratedImageResult
|
| 11 |
-
)
|
| 12 |
-
from fooocusapi.models.common.task import (
|
| 13 |
-
GenerationFinishReason,
|
| 14 |
-
ImageGenerationResult,
|
| 15 |
-
AsyncJobStage,
|
| 16 |
-
TaskType
|
| 17 |
-
)
|
| 18 |
-
from fooocusapi.utils.api_utils import (
|
| 19 |
-
req_to_params,
|
| 20 |
-
generate_async_output,
|
| 21 |
-
generate_streaming_output,
|
| 22 |
-
generate_image_result_output
|
| 23 |
-
)
|
| 24 |
-
from fooocusapi.models.requests_v1 import (
|
| 25 |
-
ImgUpscaleOrVaryRequest,
|
| 26 |
-
ImgPromptRequest,
|
| 27 |
-
ImgInpaintOrOutpaintRequest
|
| 28 |
-
)
|
| 29 |
-
from fooocusapi.models.requests_v2 import (
|
| 30 |
-
ImgInpaintOrOutpaintRequestJson,
|
| 31 |
-
ImgPromptRequestJson,
|
| 32 |
-
ImgUpscaleOrVaryRequestJson
|
| 33 |
-
)
|
| 34 |
-
from fooocusapi.worker import worker_queue, blocking_get_task_result
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def get_task_type(req: Text2ImgRequest) -> TaskType:
    """Map a request object to its task type; plain requests are text-to-image."""
    # Checked in order; the first matching request family wins.
    request_families = (
        ((ImgUpscaleOrVaryRequest, ImgUpscaleOrVaryRequestJson), TaskType.img_uov),
        ((ImgPromptRequest, ImgPromptRequestJson), TaskType.img_prompt),
        ((ImgInpaintOrOutpaintRequest, ImgInpaintOrOutpaintRequestJson), TaskType.img_inpaint_outpaint),
    )
    for request_classes, task_type in request_families:
        if isinstance(req, request_classes):
            return task_type
    return TaskType.text_2_img
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def call_worker(req: Text2ImgRequest, accept: str) -> Response | AsyncJobResponse | List[GeneratedImageResult]:
    """
    Queue a generation request and return the response matching the call mode.

    An `accept` of 'image/png' selects streaming mode (a single image returned
    as a raw Response); `req.async_process` selects an immediate
    AsyncJobResponse; otherwise the call blocks until results are available.
    """
    streaming_output = accept == 'image/png'
    if streaming_output:
        # image_number auto set to 1 in streaming mode
        req.image_number = 1

    task_type = get_task_type(req)
    params = req_to_params(req)
    async_task = worker_queue.add_task(task_type, params, req.webhook_url)

    if async_task is not None:
        if req.async_process:
            # return async response directly
            return generate_async_output(async_task)

        # blocking get generation result
        results = blocking_get_task_result(async_task.job_id)
        if streaming_output:
            return generate_streaming_output(results)
        return generate_image_result_output(results, req.require_base64)

    # add to worker queue failed
    failure_results = [
        ImageGenerationResult(
            im=None,
            seed='',
            finish_reason=GenerationFinishReason.queue_is_full
        )]

    if streaming_output:
        return generate_streaming_output(failure_results)
    if not req.async_process:
        return generate_image_result_output(failure_results, False)

    rejected = GeneratedImageResult(
        base64=None,
        url=None,
        seed='',
        finish_reason=GenerationFinishReason.queue_is_full
    )
    return AsyncJobResponse(
        job_id='',
        job_type=task_type,
        job_stage=AsyncJobStage.error,
        job_progress=0,
        job_status=None,
        job_step_preview=None,
        job_result=[rejected])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/file_utils.py
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
""" File utils
|
| 4 |
-
|
| 5 |
-
Use for managing generated files
|
| 6 |
-
|
| 7 |
-
@file: file_utils.py
|
| 8 |
-
@author: Konie
|
| 9 |
-
@update: 2024-03-22
|
| 10 |
-
"""
|
| 11 |
-
import base64
|
| 12 |
-
import datetime
|
| 13 |
-
from io import BytesIO
|
| 14 |
-
import os
|
| 15 |
-
import json
|
| 16 |
-
from pathlib import Path
|
| 17 |
-
import numpy as np
|
| 18 |
-
from PIL import Image
|
| 19 |
-
from PIL.PngImagePlugin import PngInfo
|
| 20 |
-
|
| 21 |
-
from fooocusapi.utils.logger import logger
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
output_dir = os.path.abspath(os.path.join(
|
| 25 |
-
os.path.dirname(__file__), '../..', 'outputs', 'files'))
|
| 26 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 27 |
-
|
| 28 |
-
STATIC_SERVER_BASE = 'http://127.0.0.1:8888/files/'
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def save_output_file(
        img: np.ndarray,
        image_meta: dict | None = None,
        image_name: str = '',
        extension: str = 'png') -> str:
    """
    Save np image to file
    Args:
        img: np.ndarray image to save
        image_meta: dict of image metadata (embedded only in PNG files)
        image_name: str of image name
        extension: str of image extension; anything other than
            'png', 'jpg' or 'webp' falls back to 'png'
    Returns:
        str of file name, relative to the output directory, in POSIX form
    """
    # Validate the extension BEFORE building the file name, so the suffix on
    # disk always matches the format the image is actually encoded in.
    if extension not in ('png', 'jpg', 'webp'):
        extension = 'png'
    image_format = Image.registered_extensions()['.' + extension]

    # Files are grouped into one sub-directory per day.
    date_string = datetime.datetime.now().strftime("%Y-%m-%d")
    filename = os.path.join(date_string, image_name + '.' + extension)
    file_path = os.path.join(output_dir, filename)

    meta = None
    if extension == 'png' and image_meta:
        meta = PngInfo()
        meta.add_text("parameters", json.dumps(image_meta))
        meta.add_text("fooocus_scheme", image_meta['metadata_scheme'])

    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    Image.fromarray(img).save(
        file_path,
        format=image_format,
        pnginfo=meta,
        optimize=True)
    return Path(filename).as_posix()
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def delete_output_file(filename: str):
    """
    Delete files specified in the output directory
    Args:
        filename: str of file name, relative to the output directory
    """
    file_path = os.path.join(output_dir, filename)
    if not os.path.isfile(file_path):
        logger.std_warn(f'[Fooocus API] {filename} not exists or is not a file')
        # Nothing to delete: stop here instead of attempting the removal and
        # logging a second, misleading "delete failed" error.
        return
    try:
        os.remove(file_path)
        logger.std_info(f'[Fooocus API] Delete output file: {filename}')
    except OSError:
        logger.std_error(f'[Fooocus API] Delete output file failed: {filename}')
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def output_file_to_base64img(filename: str | None) -> str | None:
    """
    Convert an image file to a base64 data URL.
    Args:
        filename: str of file name, relative to the output directory
    return: str of the form "data:image/<type>;base64,<data>", or None when
        filename is None or does not name an existing file
    """
    if filename is None:
        return None
    file_path = os.path.join(output_dir, filename)
    if not os.path.isfile(file_path):
        return None

    # Map the file suffix to the format name Pillow expects. Upper-casing the
    # suffix is not enough: '.jpg' must become 'JPEG', since 'JPG' is not a
    # registered Pillow format. Unknown suffixes are re-encoded as PNG.
    suffix = os.path.splitext(filename)[1].lower()
    image_format = {
        '.png': 'PNG',
        '.jpg': 'JPEG',
        '.jpeg': 'JPEG',
        '.webp': 'WEBP',
    }.get(suffix, 'PNG')

    output_buffer = BytesIO()
    with Image.open(file_path) as img:
        img.save(output_buffer, format=image_format)
    base64_str = base64.b64encode(output_buffer.getvalue()).decode('utf-8')
    # The MIME subtype is the lower-cased format name (png / jpeg / webp).
    return f"data:image/{image_format.lower()};base64," + base64_str
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def output_file_to_bytesimg(filename: str | None) -> bytes | None:
|
| 115 |
-
"""
|
| 116 |
-
Convert an image file to a bytes string.
|
| 117 |
-
Args:
|
| 118 |
-
filename: str of file name
|
| 119 |
-
return: bytes of image data
|
| 120 |
-
"""
|
| 121 |
-
if filename is None:
|
| 122 |
-
return None
|
| 123 |
-
file_path = os.path.join(output_dir, filename)
|
| 124 |
-
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
| 125 |
-
return None
|
| 126 |
-
|
| 127 |
-
img = Image.open(file_path)
|
| 128 |
-
output_buffer = BytesIO()
|
| 129 |
-
img.save(output_buffer, format='PNG')
|
| 130 |
-
byte_data = output_buffer.getvalue()
|
| 131 |
-
return byte_data
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def get_file_serve_url(filename: str | None) -> str | None:
|
| 135 |
-
"""
|
| 136 |
-
Get the static serve url of an image file.
|
| 137 |
-
Args:
|
| 138 |
-
filename: str of file name
|
| 139 |
-
return: str of static serve url
|
| 140 |
-
"""
|
| 141 |
-
if filename is None:
|
| 142 |
-
return None
|
| 143 |
-
return STATIC_SERVER_BASE + filename.replace('\\', '/')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/img_utils.py
DELETED
|
@@ -1,198 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Image process utils. Used to verify, convert and store Images.
|
| 3 |
-
|
| 4 |
-
@file: img_utils.py
|
| 5 |
-
@author: Konie
|
| 6 |
-
@update: 2024-03-23
|
| 7 |
-
"""
|
| 8 |
-
import base64
|
| 9 |
-
from io import BytesIO
|
| 10 |
-
from fastapi import UploadFile
|
| 11 |
-
from PIL import Image
|
| 12 |
-
|
| 13 |
-
import requests
|
| 14 |
-
import numpy as np
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def upload2base64(image: UploadFile) -> str | None:
    """
    Read an uploaded file and encode its content as base64 text.
    Args:
        image (UploadFile): UploadFile obj
    Returns:
        str: base64 string, None for None
    """
    if image is None:
        return None
    raw = image.file.read()
    return base64.b64encode(raw).decode("utf-8")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def narray_to_base64img(narray: np.ndarray) -> str | None:
|
| 33 |
-
"""
|
| 34 |
-
Convert numpy array to base64 image string.
|
| 35 |
-
Args:
|
| 36 |
-
narray: numpy array
|
| 37 |
-
Returns:
|
| 38 |
-
base64 image string
|
| 39 |
-
"""
|
| 40 |
-
if narray is None:
|
| 41 |
-
return None
|
| 42 |
-
|
| 43 |
-
img = Image.fromarray(narray)
|
| 44 |
-
output_buffer = BytesIO()
|
| 45 |
-
img.save(output_buffer, format='PNG')
|
| 46 |
-
byte_data = output_buffer.getvalue()
|
| 47 |
-
base64_str = base64.b64encode(byte_data).decode('utf-8')
|
| 48 |
-
return base64_str
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def narray_to_bytesimg(narray) -> bytes | None:
|
| 52 |
-
"""
|
| 53 |
-
Convert numpy array to bytes image.
|
| 54 |
-
Args:
|
| 55 |
-
narray: numpy array
|
| 56 |
-
Returns:
|
| 57 |
-
bytes image
|
| 58 |
-
"""
|
| 59 |
-
if narray is None:
|
| 60 |
-
return None
|
| 61 |
-
|
| 62 |
-
img = Image.fromarray(narray)
|
| 63 |
-
output_buffer = BytesIO()
|
| 64 |
-
img.save(output_buffer, format='PNG')
|
| 65 |
-
byte_data = output_buffer.getvalue()
|
| 66 |
-
return byte_data
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def read_input_image(input_image: UploadFile | str | None) -> np.ndarray | None:
    """
    Decode an input image given as an upload or as base64 text.
    Args:
        input_image: UploadFile, or base64 image string, or None
    Returns:
        numpy array of image, None for None or an empty string
    """
    if input_image is None or input_image == '':
        return None

    # Strings are treated as base64 text; anything else as an upload object.
    raw = base64.b64decode(input_image) if isinstance(input_image, str) else input_image.file.read()
    return np.array(Image.open(BytesIO(raw)))
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def base64_to_stream(image: str) -> UploadFile | None:
    """
    Turn an image reference into an UploadFile.

    The reference may be an http(s) URL (downloaded via get_check_image),
    a "data:image" URL, or bare base64 text.
    Args:
        image: base64 image string
    Returns:
        UploadFile or None
    """
    # Placeholder values are treated as "no image".
    if image in ('', None, 'None', 'none', 'string', 'null'):
        return None
    if image.startswith('http'):
        return get_check_image(url=image)

    payload = image
    if payload.startswith('data:image'):
        # Drop the "data:image/...;base64," header, keep only the data part.
        payload = payload.split(sep=',', maxsplit=1)[1]
    return UploadFile(file=BytesIO(base64.b64decode(payload)))
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def get_check_image(url: str) -> UploadFile | None:
    """
    Get image from url and check if it's valid.
    Args:
        url: image url
    Returns:
        UploadFile or None (empty url, failed/unsuccessful request, or a
        response body that Pillow cannot open as an image)
    """
    if url == '':
        return None
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    # NOTE(review): the URL is fetched as-is. If it can come from untrusted
    # clients, consider restricting the reachable hosts (SSRF) — confirm.
    try:
        response = requests.get(url, headers=headers, timeout=10)
        # Reject 4xx/5xx answers: an error page that happens to carry an
        # image body must not be accepted as the requested image.
        response.raise_for_status()
        binary_image = response.content
    except requests.RequestException:
        return None
    try:
        Image.open(BytesIO(binary_image))  # This validates the image
    except Exception:
        return None
    return UploadFile(file=BytesIO(binary_image))
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def bytes_image_to_io(binary_image: bytes) -> BytesIO | None:
|
| 139 |
-
"""
|
| 140 |
-
Convert bytes image to BytesIO.
|
| 141 |
-
Args:
|
| 142 |
-
binary_image: bytes image
|
| 143 |
-
Returns:
|
| 144 |
-
BytesIO or None
|
| 145 |
-
"""
|
| 146 |
-
try:
|
| 147 |
-
buffer = BytesIO(binary_image)
|
| 148 |
-
Image.open(buffer)
|
| 149 |
-
except Exception:
|
| 150 |
-
return None
|
| 151 |
-
byte_stream = BytesIO()
|
| 152 |
-
byte_stream.write(binary_image)
|
| 153 |
-
byte_stream.seek(0)
|
| 154 |
-
return byte_stream
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def bytes_to_base64img(byte_data: bytes) -> str | None:
|
| 158 |
-
"""
|
| 159 |
-
Convert bytes image to base64 image string.
|
| 160 |
-
Args:
|
| 161 |
-
byte_data: bytes image
|
| 162 |
-
Returns:
|
| 163 |
-
base64 image string or None
|
| 164 |
-
"""
|
| 165 |
-
if byte_data is None:
|
| 166 |
-
return None
|
| 167 |
-
|
| 168 |
-
base64_str = base64.b64encode(byte_data).decode('utf-8')
|
| 169 |
-
return base64_str
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def base64_to_bytesimg(base64_str: str) -> bytes | None:
|
| 173 |
-
"""
|
| 174 |
-
Convert base64 image string to bytes image.
|
| 175 |
-
Args:
|
| 176 |
-
base64_str: base64 image string
|
| 177 |
-
Returns:
|
| 178 |
-
bytes image or None
|
| 179 |
-
"""
|
| 180 |
-
if base64_str == '':
|
| 181 |
-
return None
|
| 182 |
-
bytes_image = base64.b64decode(base64_str)
|
| 183 |
-
return bytes_image
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def base64_to_narray(base64_str: str) -> np.ndarray | None:
|
| 187 |
-
"""
|
| 188 |
-
Convert base64 image string to numpy array.
|
| 189 |
-
Args:
|
| 190 |
-
base64_str: base64 image string
|
| 191 |
-
Returns:
|
| 192 |
-
numpy array or None
|
| 193 |
-
"""
|
| 194 |
-
if base64_str == '':
|
| 195 |
-
return None
|
| 196 |
-
bytes_image = base64.b64decode(base64_str)
|
| 197 |
-
image = np.frombuffer(bytes_image, np.uint8)
|
| 198 |
-
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/logger.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
""" A simply logger.
|
| 4 |
-
|
| 5 |
-
This module is used to log the program.
|
| 6 |
-
|
| 7 |
-
@file: logger.py
|
| 8 |
-
@author: mrhan1993
|
| 9 |
-
@update: 2024-03-22
|
| 10 |
-
"""
|
| 11 |
-
import logging
|
| 12 |
-
import os
|
| 13 |
-
import sys
|
| 14 |
-
|
| 15 |
-
try:
|
| 16 |
-
from colorlog import ColoredFormatter
|
| 17 |
-
except ImportError:
|
| 18 |
-
from fooocusapi.utils.tools import run_pip
|
| 19 |
-
run_pip(
|
| 20 |
-
command="install colorlog",
|
| 21 |
-
desc="Install colorlog for logger.",
|
| 22 |
-
live=True
|
| 23 |
-
)
|
| 24 |
-
finally:
|
| 25 |
-
from colorlog import ColoredFormatter
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
own_path = os.path.dirname(os.path.abspath(__file__))
|
| 29 |
-
log_dir = "logs"
|
| 30 |
-
default_log_path = os.path.join(own_path, '../../', log_dir)
|
| 31 |
-
|
| 32 |
-
std_formatter = ColoredFormatter(
|
| 33 |
-
fmt="%(log_color)s[%(asctime)s] %(levelname)-8s%(reset)s %(blue)s%(message)s",
|
| 34 |
-
datefmt='%Y-%m-%d %H:%M:%S',
|
| 35 |
-
reset=True,
|
| 36 |
-
log_colors={
|
| 37 |
-
'DEBUG': 'cyan',
|
| 38 |
-
'INFO': 'green',
|
| 39 |
-
'WARNING': 'yellow',
|
| 40 |
-
'ERROR': 'red',
|
| 41 |
-
'CRITICAL': 'red,bg_white',
|
| 42 |
-
},
|
| 43 |
-
secondary_log_colors={},
|
| 44 |
-
style='%'
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
file_formatter = ColoredFormatter(
|
| 48 |
-
fmt="[%(asctime)s] %(levelname)-8s%(reset)s %(message)s",
|
| 49 |
-
datefmt='%Y-%m-%d %H:%M:%S',
|
| 50 |
-
reset=True,
|
| 51 |
-
no_color=True,
|
| 52 |
-
style='%'
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class ConfigLogger:
    """
    Settings holder for the logger.
    :param log_path: directory for the log files, preferably an absolute path
    :param std_format: formatter used for stdout output
    :param file_format: formatter used for log files
    """
    def __init__(self,
                 log_path: str = default_log_path,
                 std_format: ColoredFormatter = std_formatter,
                 file_format: ColoredFormatter = file_formatter) -> None:
        # Plain attribute storage; the values are read by Logger.
        self.log_path, self.std_format, self.file_format = (
            log_path, std_format, file_format)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class Logger:
    """
    A simple logger.

    Messages sent through the ``file_*`` methods go to the named logger,
    which writes ``<log_name>_error.log`` (ERROR and above) and
    ``<log_name>_info.log`` (INFO and above) inside the configured
    directory. Messages sent through the ``std_*`` methods go to the root
    logger, which prints to stdout.

    Creating several instances with the same name/config does not attach
    duplicate handlers, so each message is emitted only once per target.

    :param log_name: log name
    :param config: config logger; a default ConfigLogger is created when omitted
    """
    def __init__(self, log_name, config: "ConfigLogger | None" = None):
        if config is None:
            # Built per call rather than once at definition time.
            config = ConfigLogger()

        log_path = str(config.log_path)
        os.makedirs(log_path, exist_ok=True)

        self._file_logger = logging.getLogger(log_name)
        self._file_logger.setLevel("INFO")

        self._std_logger = logging.getLogger()
        self._std_logger.setLevel("INFO")

        # One file handler per level: errors go to *_error.log, everything
        # from INFO upwards goes to *_info.log.
        file_targets = (
            (os.path.join(log_path, f"{log_name}_error.log"), logging.ERROR),
            (os.path.join(log_path, f"{log_name}_info.log"), logging.INFO),
        )
        for target_path, level in file_targets:
            if self._has_file_handler(self._file_logger, target_path):
                continue
            file_handler = logging.FileHandler(target_path, encoding='utf-8')
            file_handler.setLevel(level)
            file_handler.setFormatter(config.file_format)
            self._file_logger.addHandler(file_handler)

        # Stream handler printing to stdout, attached to the root logger.
        if not self._has_stdout_handler(self._std_logger, config.std_format):
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setFormatter(config.std_format)
            self._std_logger.addHandler(stream_handler)

    @staticmethod
    def _has_file_handler(target, path):
        """Return True when `target` already has a FileHandler writing to `path`."""
        wanted = os.path.abspath(path)
        return any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == wanted
            for handler in target.handlers
        )

    @staticmethod
    def _has_stdout_handler(target, formatter):
        """Return True when `target` already prints to stdout with `formatter`."""
        return any(
            isinstance(handler, logging.StreamHandler)
            and getattr(handler, 'stream', None) is sys.stdout
            and handler.formatter is formatter
            for handler in target.handlers
        )

    def file_error(self, message):
        """file error log"""
        self._file_logger.error(message)

    def file_info(self, message):
        """file info log"""
        self._file_logger.info(message)

    def std_info(self, message):
        """std info log"""
        self._std_logger.info(message)

    def std_warn(self, message):
        """std warn log"""
        self._std_logger.warning(message)

    def std_error(self, message):
        """std error log"""
        self._std_logger.error(message)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
logger = Logger(log_name="fooocus_api")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/lora_manager.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
import hashlib
|
| 2 |
-
import os
|
| 3 |
-
import requests
|
| 4 |
-
import tarfile
|
| 5 |
-
|
| 6 |
-
def _hash_url(url):
    """Return the hexadecimal MD5 digest of the UTF-8 encoded URL."""
    digest = hashlib.md5(url.encode('utf-8'))
    return digest.hexdigest()
|
| 9 |
-
|
| 10 |
-
class LoraManager:
    """
    Manager loras from url

    Each URL is downloaded once into `cache_dir` under a name derived from
    the hash of the URL. `.tar` downloads are unpacked into `cache_dir` and
    resolved to the `.safetensors` file contained in that archive.
    """
    def __init__(self):
        self.cache_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '../../',
            'repositories/Fooocus/models/loras')

    @staticmethod
    def _url_extension(url):
        """
        Return the file extension of the URL's path (without the dot).

        Query string and fragment are ignored. URLs whose path has no
        extension are assumed to point at a .safetensors file.
        """
        # Local import keeps this class usable without touching the
        # module-level import block.
        from urllib.parse import urlparse

        extension = os.path.splitext(urlparse(url).path)[1].lstrip('.')
        return extension or 'safetensors'

    def _fetch(self, url, filepath):
        """Stream `url` into `filepath`, leaving no partial file behind on failure."""
        os.makedirs(self.cache_dir, exist_ok=True)
        # Download next to the final location and rename at the end, so an
        # interrupted download is never mistaken for a cached file.
        partial_path = filepath + '.part'
        try:
            with requests.get(url, timeout=10, stream=True) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, filepath)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    @staticmethod
    def _safe_extract(archive, destination):
        """Extract `archive` into `destination`, refusing entries that escape it."""
        root = os.path.realpath(destination)
        supports_filter = hasattr(tarfile, 'data_filter')
        for member in archive.getmembers():
            resolved = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, resolved]) != root:
                raise ValueError(f"Unsafe path in archive: {member.name}")
            # Without extraction filters there is no link validation, so
            # links are rejected outright on such interpreters.
            if not supports_filter and (member.issym() or member.islnk()):
                raise ValueError(f"Links are not allowed in archive: {member.name}")
        if supports_filter:
            archive.extractall(path=destination, filter='data')
        else:
            archive.extractall(path=destination)

    def _extract_safetensors(self, archive_path):
        """
        Return the path of the .safetensors file shipped in `archive_path`,
        unpacking the archive into the cache directory when needed.
        """
        with tarfile.open(archive_path, 'r:*') as tar:
            # Pick the file from the archive's own member list; walking the
            # cache directory could return a file from another download.
            wanted = next(
                (member for member in tar.getmembers()
                 if member.isfile() and member.name.endswith('.safetensors')),
                None)
            if wanted is None:
                raise FileNotFoundError("No .safetensors file found in the extracted files.")
            target = os.path.join(self.cache_dir, wanted.name)
            if not os.path.isfile(target):
                print("Extracting the tar file...")
                self._safe_extract(tar, self.cache_dir)
                print("Extraction completed.")
        return target

    def _download_lora(self, url):
        """
        Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file.

        Raises:
            Exception: when the download or the extraction fails.
        """
        url_hash = _hash_url(url)
        file_ext = self._url_extension(url)
        filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")

        try:
            if os.path.exists(filepath):
                print(f"LoRa already downloaded {url}")
            else:
                print(f"Start download for: {url}")
                self._fetch(url, filepath)
                print(f"Download successfully, saved as {filepath}")

            if file_ext == "tar":
                # Same result for fresh downloads and cache hits.
                return self._extract_safetensors(filepath)
        except Exception as e:
            raise Exception(f"Error downloading {url}: {e}") from e

        return filepath

    def _find_safetensors_file(self, directory):
        """
        Finds the first .safetensors file in the specified directory.

        Raises:
            FileNotFoundError: when the directory tree holds no such file.
        """
        print("Searching for .safetensors file.")
        for root, _dirs, files in os.walk(directory):
            for file in files:
                if file.endswith('.safetensors'):
                    return os.path.join(root, file)
        raise FileNotFoundError("No .safetensors file found in the extracted files.")

    def check(self, urls):
        """Manages the specified LoRAs: downloads missing ones and returns their file names."""
        return [self._download_lora(url) for url in urls]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/model_loader.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
"""
|
| 4 |
-
Download models from url
|
| 5 |
-
|
| 6 |
-
@file: model_loader.py
|
| 7 |
-
@author: Konie
|
| 8 |
-
@update: 2024-03-22
|
| 9 |
-
"""
|
| 10 |
-
from modules.model_loader import load_file_from_url
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def download_models():
    """
    Fetch every model file listed in the Fooocus config, plus the VAE
    approximation files and the prompt-expansion weights.
    """
    from modules.config import (
        paths_checkpoints as modelfile_path,
        paths_loras as lorafile_path,
        path_vae_approx as vae_approx_path,
        path_fooocus_expansion as fooocus_expansion_path,
        path_embeddings as embeddings_path,
        checkpoint_downloads,
        embeddings_downloads,
        lora_downloads)

    def fetch(url, model_dir, file_name):
        # Thin wrapper so each download below reads as one line.
        load_file_from_url(url=url, model_dir=model_dir, file_name=file_name)

    for checkpoint_name, checkpoint_url in checkpoint_downloads.items():
        fetch(checkpoint_url, modelfile_path[0], checkpoint_name)
    for embedding_name, embedding_url in embeddings_downloads.items():
        fetch(embedding_url, embeddings_path, embedding_name)
    for lora_name, lora_url in lora_downloads.items():
        fetch(lora_url, lorafile_path[0], lora_name)

    # (local file name, download URL) pairs
    vae_approx_filenames = [
        ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
        ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
        ('xl-to-v1_interposer-v3.1.safetensors', 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
    ]
    for vae_name, vae_url in vae_approx_filenames:
        fetch(vae_url, vae_approx_path, vae_name)

    fetch(
        'https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
        fooocus_expansion_path,
        'pytorch_model.bin'
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/utils/tools.py
DELETED
|
@@ -1,159 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
-
""" Some tools
|
| 4 |
-
|
| 5 |
-
@file: tools.py
|
| 6 |
-
@author: Konie
|
| 7 |
-
@update: 2024-03-22
|
| 8 |
-
"""
|
| 9 |
-
# pylint: disable=line-too-long
|
| 10 |
-
# pylint: disable=broad-exception-caught
|
| 11 |
-
import os
|
| 12 |
-
import sys
|
| 13 |
-
import re
|
| 14 |
-
import subprocess
|
| 15 |
-
from importlib.util import find_spec
|
| 16 |
-
from importlib import metadata
|
| 17 |
-
from packaging import version
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
PYTHON_EXEC = sys.executable
|
| 21 |
-
INDEX_URL = os.environ.get('INDEX_URL', "")
|
| 22 |
-
PATTERN = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
|
| 26 |
-
def run_command(command: str,
                desc: str = None,
                error_desc: str = None,
                custom_env: str = None,
                live: bool = True) -> str:
    """
    Execute a shell command and hand back what it wrote to stdout.

    Args:
        command: Command line, executed through the shell
        desc: Optional message printed before the command starts
        error_desc: Headline used in the error raised on failure
        custom_env: Environment for the child process; the current
            process environment is used when omitted
        live: When True the child's output is not captured (it goes to the
            parent's streams); when False stdout and stderr are captured
    Returns:
        Captured stdout, or an empty string when nothing was captured
    Raises:
        RuntimeError: the command exited with a non-zero status
    """
    if desc is not None:
        print(desc)

    child_env = custom_env if custom_env is not None else os.environ

    # Only redirect the streams when the caller does not want live output.
    stream_options = {}
    if not live:
        stream_options = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE}

    completed = subprocess.run(
        args=command,
        shell=True,
        env=child_env,
        encoding='utf8',
        errors='ignore',
        check=False,
        **stream_options,
    )

    # Success path first: captured stdout, or "" when output was live.
    if completed.returncode == 0:
        return completed.stdout or ""

    report = [
        f"{error_desc or 'Error running command'}.",
        f"Command: {command}",
        f"Error code: {completed.returncode}",
    ]
    if completed.stdout:
        report.append(f"stdout: {completed.stdout}")
    if completed.stderr:
        report.append(f"stderr: {completed.stderr}")
    raise RuntimeError("\n".join(report))
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# This function was copied from [Fooocus](https://github.com/lllyasviel/Fooocus) repository.
|
| 74 |
-
def run_pip(command, desc=None, live=True):
    """
    Invoke pip through the current interpreter.

    Args:
        command: Arguments placed after ``pip`` on the command line
        desc: Name used in the "Installing ..." / "Couldn't install ..." messages
        live: Whether to print the output
    Returns:
        Whatever ``run_command`` returns, or None when any exception was
        raised (the failure is printed instead of propagated)
    """
    try:
        # Only add --index-url when the INDEX_URL setting is non-empty.
        index_option = '' if INDEX_URL == '' else f' --index-url {INDEX_URL}'
        pip_invocation = f'"{PYTHON_EXEC}" -m pip {command} --prefer-binary{index_option}'
        return run_command(
            command=pip_invocation,
            desc=f"Installing {desc}",
            error_desc=f"Couldn't install {desc}",
            live=live,
        )
    except Exception as err:
        print(f'CMD Failed {command}: {err}')
        return None
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def is_installed(package: str) -> bool:
    """
    Report whether an importable module spec exists for a package.

    Args:
        package: Package name
    Returns:
        True when ``find_spec`` locates the package; False when it returns
        None or raises ModuleNotFoundError
    """
    try:
        return find_spec(package) is not None
    except ModuleNotFoundError:
        return False
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def check_torch_cuda() -> bool:
    """
    Tell whether torch can be imported and reports CUDA as available.

    Returns:
        The result of ``torch.cuda.is_available()``, or False when an
        ImportError is raised
    """
    try:
        import torch as torch_module
        cuda_ready = torch_module.cuda.is_available()
    except ImportError:
        cuda_ready = False
    return cuda_ready
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def requirements_check(requirements_file: str = 'requirements.txt',
                       pattern: re.Pattern = PATTERN) -> bool:
    """
    Verify that the versions required by a requirements file are installed.

    Blank lines are ignored. A line for which ``pattern`` yields no version
    (second group empty or missing) is accepted without looking the package up.

    Args:
        requirements_file: Path to the requirements file
        pattern: Pattern whose group 1 is the package name and group 2 the
            required version
    Returns:
        False as soon as a line does not match ``pattern``, the installed
        version cannot be read, or it differs from the required one;
        True otherwise
    """
    def line_satisfied(raw_line: str) -> bool:
        # Decide a single requirement line; True means "nothing wrong here".
        if raw_line.strip() == "":
            return True

        found = re.match(pattern, raw_line)
        if found is None:
            return False

        name = found.group(1).strip()
        wanted = (found.group(2) or "").strip()
        if wanted == "":
            return True

        try:
            installed = metadata.version(name)
        except Exception:
            return False

        if version.parse(wanted) != version.parse(installed):
            return False
        return True

    with open(requirements_file, "r", encoding="utf8") as handle:
        # all() stops at the first failing line, like an early return would.
        return all(line_satisfied(raw_line) for raw_line in handle)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fooocusapi/worker.py
DELETED
|
@@ -1,1044 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Worker, modified from https://github.com/lllyasviel/Fooocus/blob/main/modules/async_worker.py
|
| 3 |
-
"""
|
| 4 |
-
import copy
|
| 5 |
-
import os
|
| 6 |
-
import random
|
| 7 |
-
import time
|
| 8 |
-
from typing import List
|
| 9 |
-
import logging
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from fooocusapi.models.common.image_meta import image_parse
|
| 14 |
-
from modules.patch import PatchSettings, patch_settings, patch_all
|
| 15 |
-
from modules.flags import Performance
|
| 16 |
-
|
| 17 |
-
from fooocusapi.utils.file_utils import save_output_file
|
| 18 |
-
from fooocusapi.models.common.task import (
|
| 19 |
-
GenerationFinishReason,
|
| 20 |
-
ImageGenerationResult
|
| 21 |
-
)
|
| 22 |
-
from fooocusapi.utils.logger import logger
|
| 23 |
-
from fooocusapi.task_queue import (
|
| 24 |
-
QueueTask,
|
| 25 |
-
TaskQueue,
|
| 26 |
-
TaskOutputs
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
patch_all()
|
| 30 |
-
|
| 31 |
-
worker_queue: TaskQueue | None = None
|
| 32 |
-
last_model_name = None
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def process_stop():
    """Ask the model-management layer to interrupt the current processing."""
    import ldm_patched.modules.model_management as model_management

    model_management.interrupt_current_processing()
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
@torch.no_grad()
@torch.inference_mode()
def task_schedule_loop():
    """
    Run forever, feeding the head of the worker queue to ``process_generate``.

    The loop polls ``worker_queue.queue``. When the queue is empty it sleeps
    briefly. When the first task has not been started yet
    (``start_mills == 0``) it is generated synchronously in this thread.
    When the first task is already marked as started, the loop now sleeps
    for the same short interval instead of re-checking immediately, so it no
    longer busy-spins while waiting for that task to leave the queue.
    """
    poll_interval = 0.05

    while True:
        if len(worker_queue.queue) == 0:
            time.sleep(poll_interval)
            continue

        current_task = worker_queue.queue[0]
        if current_task.start_mills == 0:
            process_generate(current_task)
        else:
            # Head task already started: back off rather than spin.
            time.sleep(poll_interval)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
@torch.no_grad()
@torch.inference_mode()
def blocking_get_task_result(job_id: str) -> List[ImageGenerationResult]:
    """
    Block until the given job is finished, then return its result.

    Polls ``worker_queue.is_task_finished`` every 0.05 seconds, logging once
    when waiting starts and again after every ``int(10 / 0.05)`` polls.

    :param job_id: identifier of the queued task to wait for
    :return: the ``task_result`` of the task fetched from the worker queue
    """
    poll_interval = 0.05
    # Number of polls between two "already waiting" log lines.
    polls_per_report = int(10 / poll_interval)

    polls_done: int = 0
    wait_started_at = time.perf_counter()

    while not worker_queue.is_task_finished(job_id):
        if polls_done == 0:
            logger.std_info(f"[Task Queue] Waiting for task finished, job_id={job_id}")

        time.sleep(poll_interval)
        polls_done += 1

        if polls_done % polls_per_report == 0:
            elapsed = time.perf_counter() - wait_started_at
            logger.std_info(f"[Task Queue] Already waiting for {round(elapsed, 1)} seconds, job_id={job_id}")

    finished_task = worker_queue.get_task(job_id, True)
    return finished_task.task_result
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
@torch.no_grad()
|
| 80 |
-
@torch.inference_mode()
|
| 81 |
-
def process_generate(async_task: QueueTask):
|
| 82 |
-
"""Generate image"""
|
| 83 |
-
try:
|
| 84 |
-
import modules.default_pipeline as pipeline
|
| 85 |
-
except Exception as e:
|
| 86 |
-
logger.std_error(f'[Task Queue] Import default pipeline error: {e}')
|
| 87 |
-
if not async_task.is_finished:
|
| 88 |
-
worker_queue.finish_task(async_task.job_id)
|
| 89 |
-
async_task.set_result([], True, str(e))
|
| 90 |
-
logger.std_error(f"[Task Queue] Finish task with error, seq={async_task.job_id}")
|
| 91 |
-
return []
|
| 92 |
-
|
| 93 |
-
import modules.flags as flags
|
| 94 |
-
import modules.core as core
|
| 95 |
-
import modules.inpaint_worker as inpaint_worker
|
| 96 |
-
import modules.config as config
|
| 97 |
-
import modules.constants as constants
|
| 98 |
-
import extras.preprocessors as preprocessors
|
| 99 |
-
import extras.ip_adapter as ip_adapter
|
| 100 |
-
import extras.face_crop as face_crop
|
| 101 |
-
import ldm_patched.modules.model_management as model_management
|
| 102 |
-
from modules.util import (
|
| 103 |
-
remove_empty_str, HWC3, resize_image,
|
| 104 |
-
get_image_shape_ceil, set_image_shape_ceil,
|
| 105 |
-
get_shape_ceil, resample_image, erode_or_dilate,
|
| 106 |
-
get_enabled_loras, parse_lora_references_from_prompt, apply_wildcards,
|
| 107 |
-
remove_performance_lora
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
from modules.upscaler import perform_upscale
|
| 111 |
-
from extras.expansion import safe_str
|
| 112 |
-
from extras.censor import default_censor
|
| 113 |
-
from modules.sdxl_styles import (
|
| 114 |
-
apply_style, get_random_style,
|
| 115 |
-
fooocus_expansion, apply_arrays, random_style_name
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
pid = os.getpid()
|
| 119 |
-
|
| 120 |
-
outputs = TaskOutputs(async_task)
|
| 121 |
-
results = []
|
| 122 |
-
|
| 123 |
-
def refresh_seed(seed_string: int | str | None) -> int:
|
| 124 |
-
"""
|
| 125 |
-
Refresh and check seed number.
|
| 126 |
-
:params seed_string: seed, str or int. None means random
|
| 127 |
-
:return: seed number
|
| 128 |
-
"""
|
| 129 |
-
if seed_string is None or seed_string == -1:
|
| 130 |
-
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
| 131 |
-
|
| 132 |
-
try:
|
| 133 |
-
seed_value = int(seed_string)
|
| 134 |
-
if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
|
| 135 |
-
return seed_value
|
| 136 |
-
except ValueError:
|
| 137 |
-
pass
|
| 138 |
-
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
| 139 |
-
|
| 140 |
-
def progressbar(_, number, text):
|
| 141 |
-
"""progress bar"""
|
| 142 |
-
logger.std_info(f'[Fooocus] {text}')
|
| 143 |
-
outputs.append(['preview', (number, text, None)])
|
| 144 |
-
|
| 145 |
-
def yield_result(_, images, tasks, extension='png',
|
| 146 |
-
blockout_nsfw=False, censor=True):
|
| 147 |
-
"""
|
| 148 |
-
Yield result
|
| 149 |
-
:param _: async task object
|
| 150 |
-
:param images: list for generated image
|
| 151 |
-
:param tasks: the image was generated one by one, when image number is not one, it will be a task list
|
| 152 |
-
:param extension: extension for saved image
|
| 153 |
-
:param blockout_nsfw: blockout nsfw image
|
| 154 |
-
:param censor: censor image
|
| 155 |
-
:return:
|
| 156 |
-
"""
|
| 157 |
-
if not isinstance(images, list):
|
| 158 |
-
images = [images]
|
| 159 |
-
|
| 160 |
-
if censor and (config.default_black_out_nsfw or black_out_nsfw):
|
| 161 |
-
images = default_censor(images)
|
| 162 |
-
|
| 163 |
-
results = []
|
| 164 |
-
for index, im in enumerate(images):
|
| 165 |
-
if async_task.req_param.save_name == '':
|
| 166 |
-
image_name = f"{async_task.job_id}-{str(index)}"
|
| 167 |
-
else:
|
| 168 |
-
image_name = f"{async_task.req_param.save_name}-{str(index)}"
|
| 169 |
-
if len(tasks) == 0:
|
| 170 |
-
img_seed = -1
|
| 171 |
-
img_meta = {}
|
| 172 |
-
else:
|
| 173 |
-
img_seed = tasks[index]['task_seed']
|
| 174 |
-
img_meta = image_parse(
|
| 175 |
-
async_tak=async_task,
|
| 176 |
-
task=tasks[index])
|
| 177 |
-
img_filename = save_output_file(
|
| 178 |
-
img=im,
|
| 179 |
-
image_name=image_name,
|
| 180 |
-
image_meta=img_meta,
|
| 181 |
-
extension=extension)
|
| 182 |
-
results.append(ImageGenerationResult(
|
| 183 |
-
im=img_filename,
|
| 184 |
-
seed=str(img_seed),
|
| 185 |
-
finish_reason=GenerationFinishReason.success))
|
| 186 |
-
async_task.set_result(results, False)
|
| 187 |
-
worker_queue.finish_task(async_task.job_id)
|
| 188 |
-
logger.std_info(f"[Task Queue] Finish task, job_id={async_task.job_id}")
|
| 189 |
-
|
| 190 |
-
outputs.append(['results', images])
|
| 191 |
-
pipeline.prepare_text_encoder(async_call=True)
|
| 192 |
-
|
| 193 |
-
try:
|
| 194 |
-
logger.std_info(f"[Task Queue] Task queue start task, job_id={async_task.job_id}")
|
| 195 |
-
# clear memory
|
| 196 |
-
global last_model_name
|
| 197 |
-
|
| 198 |
-
if last_model_name is None:
|
| 199 |
-
last_model_name = async_task.req_param.base_model_name
|
| 200 |
-
if last_model_name != async_task.req_param.base_model_name:
|
| 201 |
-
model_management.cleanup_models() # key1
|
| 202 |
-
model_management.unload_all_models()
|
| 203 |
-
model_management.soft_empty_cache() # key2
|
| 204 |
-
last_model_name = async_task.req_param.base_model_name
|
| 205 |
-
|
| 206 |
-
worker_queue.start_task(async_task.job_id)
|
| 207 |
-
|
| 208 |
-
execution_start_time = time.perf_counter()
|
| 209 |
-
|
| 210 |
-
# Transform parameters
|
| 211 |
-
params = async_task.req_param
|
| 212 |
-
prompt = params.prompt
|
| 213 |
-
negative_prompt = params.negative_prompt
|
| 214 |
-
style_selections = params.style_selections
|
| 215 |
-
performance_selection = Performance(params.performance_selection)
|
| 216 |
-
aspect_ratios_selection = params.aspect_ratios_selection
|
| 217 |
-
image_number = params.image_number
|
| 218 |
-
save_metadata_to_images = params.save_meta
|
| 219 |
-
metadata_scheme = params.meta_scheme
|
| 220 |
-
save_extension = params.save_extension
|
| 221 |
-
save_name = params.save_name
|
| 222 |
-
image_seed = refresh_seed(params.image_seed)
|
| 223 |
-
read_wildcards_in_order = False
|
| 224 |
-
sharpness = params.sharpness
|
| 225 |
-
guidance_scale = params.guidance_scale
|
| 226 |
-
base_model_name = params.base_model_name
|
| 227 |
-
refiner_model_name = params.refiner_model_name
|
| 228 |
-
refiner_switch = params.refiner_switch
|
| 229 |
-
loras = params.loras
|
| 230 |
-
input_image_checkbox = params.uov_input_image is not None or params.inpaint_input_image is not None or len(params.image_prompts) > 0
|
| 231 |
-
current_tab = 'uov' if params.uov_method != flags.disabled else 'ip' if len(params.image_prompts) > 0 else 'inpaint' if params.inpaint_input_image is not None else None
|
| 232 |
-
uov_method = params.uov_method
|
| 233 |
-
upscale_value = params.upscale_value
|
| 234 |
-
uov_input_image = params.uov_input_image
|
| 235 |
-
outpaint_selections = params.outpaint_selections
|
| 236 |
-
outpaint_distance_left = params.outpaint_distance_left
|
| 237 |
-
outpaint_distance_top = params.outpaint_distance_top
|
| 238 |
-
outpaint_distance_right = params.outpaint_distance_right
|
| 239 |
-
outpaint_distance_bottom = params.outpaint_distance_bottom
|
| 240 |
-
inpaint_input_image = params.inpaint_input_image
|
| 241 |
-
inpaint_additional_prompt = '' if params.inpaint_additional_prompt is None else params.inpaint_additional_prompt
|
| 242 |
-
inpaint_mask_image_upload = None
|
| 243 |
-
|
| 244 |
-
adp = params.advanced_params
|
| 245 |
-
disable_preview = adp.disable_preview
|
| 246 |
-
disable_intermediate_results = adp.disable_intermediate_results
|
| 247 |
-
disable_seed_increment = adp.disable_seed_increment
|
| 248 |
-
adm_scaler_positive = adp.adm_scaler_positive
|
| 249 |
-
adm_scaler_negative = adp.adm_scaler_negative
|
| 250 |
-
adm_scaler_end = adp.adm_scaler_end
|
| 251 |
-
adaptive_cfg = adp.adaptive_cfg
|
| 252 |
-
sampler_name = adp.sampler_name
|
| 253 |
-
scheduler_name = adp.scheduler_name
|
| 254 |
-
overwrite_step = adp.overwrite_step
|
| 255 |
-
overwrite_switch = adp.overwrite_switch
|
| 256 |
-
overwrite_width = adp.overwrite_width
|
| 257 |
-
overwrite_height = adp.overwrite_height
|
| 258 |
-
overwrite_vary_strength = adp.overwrite_vary_strength
|
| 259 |
-
overwrite_upscale_strength = adp.overwrite_upscale_strength
|
| 260 |
-
mixing_image_prompt_and_vary_upscale = adp.mixing_image_prompt_and_vary_upscale
|
| 261 |
-
mixing_image_prompt_and_inpaint = adp.mixing_image_prompt_and_inpaint
|
| 262 |
-
debugging_cn_preprocessor = adp.debugging_cn_preprocessor
|
| 263 |
-
skipping_cn_preprocessor = adp.skipping_cn_preprocessor
|
| 264 |
-
canny_low_threshold = adp.canny_low_threshold
|
| 265 |
-
canny_high_threshold = adp.canny_high_threshold
|
| 266 |
-
refiner_swap_method = adp.refiner_swap_method
|
| 267 |
-
controlnet_softness = adp.controlnet_softness
|
| 268 |
-
freeu_enabled = adp.freeu_enabled
|
| 269 |
-
freeu_b1 = adp.freeu_b1
|
| 270 |
-
freeu_b2 = adp.freeu_b2
|
| 271 |
-
freeu_s1 = adp.freeu_s1
|
| 272 |
-
freeu_s2 = adp.freeu_s2
|
| 273 |
-
debugging_inpaint_preprocessor = adp.debugging_inpaint_preprocessor
|
| 274 |
-
inpaint_disable_initial_latent = adp.inpaint_disable_initial_latent
|
| 275 |
-
inpaint_engine = adp.inpaint_engine
|
| 276 |
-
inpaint_strength = adp.inpaint_strength
|
| 277 |
-
inpaint_respective_field = adp.inpaint_respective_field
|
| 278 |
-
inpaint_mask_upload_checkbox = adp.inpaint_mask_upload_checkbox
|
| 279 |
-
invert_mask_checkbox = adp.invert_mask_checkbox
|
| 280 |
-
inpaint_erode_or_dilate = adp.inpaint_erode_or_dilate
|
| 281 |
-
black_out_nsfw = adp.black_out_nsfw
|
| 282 |
-
vae_name = adp.vae_name
|
| 283 |
-
clip_skip = adp.clip_skip
|
| 284 |
-
|
| 285 |
-
cn_tasks = {x: [] for x in flags.ip_list}
|
| 286 |
-
for img_prompt in params.image_prompts:
|
| 287 |
-
cn_img, cn_stop, cn_weight, cn_type = img_prompt
|
| 288 |
-
cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
|
| 289 |
-
|
| 290 |
-
if inpaint_input_image is not None and inpaint_input_image['image'] is not None:
|
| 291 |
-
inpaint_image_size = inpaint_input_image['image'].shape[:2]
|
| 292 |
-
if inpaint_input_image['mask'] is None:
|
| 293 |
-
inpaint_input_image['mask'] = np.zeros(inpaint_image_size, dtype=np.uint8)
|
| 294 |
-
else:
|
| 295 |
-
inpaint_mask_upload_checkbox = True
|
| 296 |
-
|
| 297 |
-
inpaint_input_image['mask'] = HWC3(inpaint_input_image['mask'])
|
| 298 |
-
inpaint_mask_image_upload = inpaint_input_image['mask']
|
| 299 |
-
|
| 300 |
-
# Fooocus async_worker.py code start
|
| 301 |
-
|
| 302 |
-
outpaint_selections = [o.lower() for o in outpaint_selections]
|
| 303 |
-
base_model_additional_loras = []
|
| 304 |
-
raw_style_selections = copy.deepcopy(style_selections)
|
| 305 |
-
uov_method = uov_method.lower()
|
| 306 |
-
|
| 307 |
-
if fooocus_expansion in style_selections:
|
| 308 |
-
use_expansion = True
|
| 309 |
-
style_selections.remove(fooocus_expansion)
|
| 310 |
-
else:
|
| 311 |
-
use_expansion = False
|
| 312 |
-
|
| 313 |
-
use_style = len(style_selections) > 0
|
| 314 |
-
|
| 315 |
-
if base_model_name == refiner_model_name:
|
| 316 |
-
logger.std_warn('[Fooocus] Refiner disabled because base model and refiner are same.')
|
| 317 |
-
refiner_model_name = 'None'
|
| 318 |
-
|
| 319 |
-
steps = performance_selection.steps()
|
| 320 |
-
|
| 321 |
-
performance_loras = []
|
| 322 |
-
|
| 323 |
-
if performance_selection == Performance.EXTREME_SPEED:
|
| 324 |
-
logger.std_warn('[Fooocus] Enter LCM mode.')
|
| 325 |
-
progressbar(async_task, 1, 'Downloading LCM components ...')
|
| 326 |
-
performance_loras += [(config.downloading_sdxl_lcm_lora(), 1.0)]
|
| 327 |
-
|
| 328 |
-
if refiner_model_name != 'None':
|
| 329 |
-
logger.std_info('[Fooocus] Refiner disabled in LCM mode.')
|
| 330 |
-
|
| 331 |
-
refiner_model_name = 'None'
|
| 332 |
-
sampler_name = 'lcm'
|
| 333 |
-
scheduler_name = 'lcm'
|
| 334 |
-
sharpness = 0.0
|
| 335 |
-
guidance_scale = 1.0
|
| 336 |
-
adaptive_cfg = 1.0
|
| 337 |
-
refiner_switch = 1.0
|
| 338 |
-
adm_scaler_positive = 1.0
|
| 339 |
-
adm_scaler_negative = 1.0
|
| 340 |
-
adm_scaler_end = 0.0
|
| 341 |
-
|
| 342 |
-
elif performance_selection == Performance.LIGHTNING:
|
| 343 |
-
logger.std_info('[Fooocus] Enter Lightning mode.')
|
| 344 |
-
progressbar(async_task, 1, 'Downloading Lightning components ...')
|
| 345 |
-
performance_loras += [(config.downloading_sdxl_lightning_lora(), 1.0)]
|
| 346 |
-
|
| 347 |
-
if refiner_model_name != 'None':
|
| 348 |
-
logger.std_info('[Fooocus] Refiner disabled in Lightning mode.')
|
| 349 |
-
|
| 350 |
-
refiner_model_name = 'None'
|
| 351 |
-
sampler_name = 'euler'
|
| 352 |
-
scheduler_name = 'sgm_uniform'
|
| 353 |
-
sharpness = 0.0
|
| 354 |
-
guidance_scale = 1.0
|
| 355 |
-
adaptive_cfg = 1.0
|
| 356 |
-
refiner_switch = 1.0
|
| 357 |
-
adm_scaler_positive = 1.0
|
| 358 |
-
adm_scaler_negative = 1.0
|
| 359 |
-
adm_scaler_end = 0.0
|
| 360 |
-
|
| 361 |
-
elif performance_selection == Performance.HYPER_SD:
|
| 362 |
-
print('Enter Hyper-SD mode.')
|
| 363 |
-
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
|
| 364 |
-
performance_loras += [(config.downloading_sdxl_hyper_sd_lora(), 0.8)]
|
| 365 |
-
|
| 366 |
-
if refiner_model_name != 'None':
|
| 367 |
-
logger.std_info('[Fooocus] Refiner disabled in Hyper-SD mode.')
|
| 368 |
-
|
| 369 |
-
refiner_model_name = 'None'
|
| 370 |
-
sampler_name = 'dpmpp_sde_gpu'
|
| 371 |
-
scheduler_name = 'karras'
|
| 372 |
-
sharpness = 0.0
|
| 373 |
-
guidance_scale = 1.0
|
| 374 |
-
adaptive_cfg = 1.0
|
| 375 |
-
refiner_switch = 1.0
|
| 376 |
-
adm_scaler_positive = 1.0
|
| 377 |
-
adm_scaler_negative = 1.0
|
| 378 |
-
adm_scaler_end = 0.0
|
| 379 |
-
|
| 380 |
-
logger.std_info(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
|
| 381 |
-
logger.std_info(f'[Parameters] CLIP Skip = {clip_skip}')
|
| 382 |
-
logger.std_info(f'[Parameters] Sharpness = {sharpness}')
|
| 383 |
-
logger.std_info(f'[Parameters] ControlNet Softness = {controlnet_softness}')
|
| 384 |
-
logger.std_info(f'[Parameters] ADM Scale = '
|
| 385 |
-
f'{adm_scaler_positive} : '
|
| 386 |
-
f'{adm_scaler_negative} : '
|
| 387 |
-
f'{adm_scaler_end}')
|
| 388 |
-
|
| 389 |
-
patch_settings[pid] = PatchSettings(
|
| 390 |
-
sharpness,
|
| 391 |
-
adm_scaler_end,
|
| 392 |
-
adm_scaler_positive,
|
| 393 |
-
adm_scaler_negative,
|
| 394 |
-
controlnet_softness,
|
| 395 |
-
adaptive_cfg
|
| 396 |
-
)
|
| 397 |
-
|
| 398 |
-
cfg_scale = float(guidance_scale)
|
| 399 |
-
logger.std_info(f'[Parameters] CFG = {cfg_scale}')
|
| 400 |
-
|
| 401 |
-
initial_latent = None
|
| 402 |
-
denoising_strength = 1.0
|
| 403 |
-
tiled = False
|
| 404 |
-
|
| 405 |
-
width, height = aspect_ratios_selection.replace('×', ' ').replace('*', ' ').split(' ')[:2]
|
| 406 |
-
width, height = int(width), int(height)
|
| 407 |
-
|
| 408 |
-
skip_prompt_processing = False
|
| 409 |
-
|
| 410 |
-
inpaint_worker.current_task = None
|
| 411 |
-
inpaint_parameterized = inpaint_engine != 'None'
|
| 412 |
-
inpaint_image = None
|
| 413 |
-
inpaint_mask = None
|
| 414 |
-
inpaint_head_model_path = None
|
| 415 |
-
|
| 416 |
-
use_synthetic_refiner = False
|
| 417 |
-
|
| 418 |
-
controlnet_canny_path = None
|
| 419 |
-
controlnet_cpds_path = None
|
| 420 |
-
clip_vision_path, ip_negative_path, ip_adapter_path, ip_adapter_face_path = None, None, None, None
|
| 421 |
-
|
| 422 |
-
seed = int(image_seed)
|
| 423 |
-
logger.std_info(f'[Parameters] Seed = {seed}')
|
| 424 |
-
|
| 425 |
-
goals = []
|
| 426 |
-
tasks = []
|
| 427 |
-
|
| 428 |
-
if input_image_checkbox:
|
| 429 |
-
if (current_tab == 'uov' or (
|
| 430 |
-
current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
|
| 431 |
-
and uov_method != flags.disabled and uov_input_image is not None:
|
| 432 |
-
uov_input_image = HWC3(uov_input_image)
|
| 433 |
-
if 'vary' in uov_method:
|
| 434 |
-
goals.append('vary')
|
| 435 |
-
elif 'upscale' in uov_method:
|
| 436 |
-
goals.append('upscale')
|
| 437 |
-
if 'fast' in uov_method:
|
| 438 |
-
skip_prompt_processing = True
|
| 439 |
-
else:
|
| 440 |
-
steps = performance_selection.steps_uov()
|
| 441 |
-
|
| 442 |
-
progressbar(async_task, 1, 'Downloading upscale models ...')
|
| 443 |
-
config.downloading_upscale_model()
|
| 444 |
-
if (current_tab == 'inpaint' or (
|
| 445 |
-
current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
|
| 446 |
-
and isinstance(inpaint_input_image, dict):
|
| 447 |
-
inpaint_image = inpaint_input_image['image']
|
| 448 |
-
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
|
| 449 |
-
|
| 450 |
-
if inpaint_mask_upload_checkbox:
|
| 451 |
-
if isinstance(inpaint_mask_image_upload, np.ndarray):
|
| 452 |
-
if inpaint_mask_image_upload.ndim == 3:
|
| 453 |
-
H, W, C = inpaint_image.shape
|
| 454 |
-
inpaint_mask_image_upload = resample_image(inpaint_mask_image_upload, width=W, height=H)
|
| 455 |
-
inpaint_mask_image_upload = np.mean(inpaint_mask_image_upload, axis=2)
|
| 456 |
-
inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
|
| 457 |
-
inpaint_mask = np.maximum(np.zeros(shape=(H, W), dtype=np.uint8), inpaint_mask_image_upload)
|
| 458 |
-
|
| 459 |
-
if int(inpaint_erode_or_dilate) != 0:
|
| 460 |
-
inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
|
| 461 |
-
|
| 462 |
-
if invert_mask_checkbox:
|
| 463 |
-
inpaint_mask = 255 - inpaint_mask
|
| 464 |
-
|
| 465 |
-
inpaint_image = HWC3(inpaint_image)
|
| 466 |
-
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
| 467 |
-
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
|
| 468 |
-
progressbar(async_task, 1, 'Downloading upscale models ...')
|
| 469 |
-
config.downloading_upscale_model()
|
| 470 |
-
if inpaint_parameterized:
|
| 471 |
-
progressbar(async_task, 1, 'Downloading inpainter ...')
|
| 472 |
-
inpaint_head_model_path, inpaint_patch_model_path = config.downloading_inpaint_models(
|
| 473 |
-
inpaint_engine)
|
| 474 |
-
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
| 475 |
-
logger.std_info(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
| 476 |
-
if refiner_model_name == 'None':
|
| 477 |
-
use_synthetic_refiner = True
|
| 478 |
-
refiner_switch = 0.8
|
| 479 |
-
else:
|
| 480 |
-
inpaint_head_model_path, inpaint_patch_model_path = None, None
|
| 481 |
-
logger.std_info('[Inpaint] Parameterized inpaint is disabled.')
|
| 482 |
-
if inpaint_additional_prompt != '':
|
| 483 |
-
if prompt == '':
|
| 484 |
-
prompt = inpaint_additional_prompt
|
| 485 |
-
else:
|
| 486 |
-
prompt = inpaint_additional_prompt + '\n' + prompt
|
| 487 |
-
goals.append('inpaint')
|
| 488 |
-
if current_tab == 'ip' or \
|
| 489 |
-
mixing_image_prompt_and_vary_upscale or \
|
| 490 |
-
mixing_image_prompt_and_inpaint:
|
| 491 |
-
goals.append('cn')
|
| 492 |
-
progressbar(async_task, 1, 'Downloading control models ...')
|
| 493 |
-
if len(cn_tasks[flags.cn_canny]) > 0:
|
| 494 |
-
controlnet_canny_path = config.downloading_controlnet_canny()
|
| 495 |
-
if len(cn_tasks[flags.cn_cpds]) > 0:
|
| 496 |
-
controlnet_cpds_path = config.downloading_controlnet_cpds()
|
| 497 |
-
if len(cn_tasks[flags.cn_ip]) > 0:
|
| 498 |
-
clip_vision_path, ip_negative_path, ip_adapter_path = config.downloading_ip_adapters('ip')
|
| 499 |
-
if len(cn_tasks[flags.cn_ip_face]) > 0:
|
| 500 |
-
clip_vision_path, ip_negative_path, ip_adapter_face_path = config.downloading_ip_adapters(
|
| 501 |
-
'face')
|
| 502 |
-
progressbar(async_task, 1, 'Loading control models ...')
|
| 503 |
-
|
| 504 |
-
# Load or unload CNs
|
| 505 |
-
pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
|
| 506 |
-
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
|
| 507 |
-
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
|
| 508 |
-
|
| 509 |
-
if overwrite_step > 0:
|
| 510 |
-
steps = overwrite_step
|
| 511 |
-
|
| 512 |
-
switch = int(round(steps * refiner_switch))
|
| 513 |
-
|
| 514 |
-
if overwrite_switch > 0:
|
| 515 |
-
switch = overwrite_switch
|
| 516 |
-
|
| 517 |
-
if overwrite_width > 0:
|
| 518 |
-
width = overwrite_width
|
| 519 |
-
|
| 520 |
-
if overwrite_height > 0:
|
| 521 |
-
height = overwrite_height
|
| 522 |
-
|
| 523 |
-
logger.std_info(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
|
| 524 |
-
logger.std_info(f'[Parameters] Steps = {steps} - {switch}')
|
| 525 |
-
|
| 526 |
-
progressbar(async_task, 1, 'Initializing ...')
|
| 527 |
-
|
| 528 |
-
if not skip_prompt_processing:
|
| 529 |
-
|
| 530 |
-
prompts = remove_empty_str([safe_str(p) for p in prompt.splitlines()], default='')
|
| 531 |
-
negative_prompts = remove_empty_str([safe_str(p) for p in negative_prompt.splitlines()], default='')
|
| 532 |
-
|
| 533 |
-
prompt = prompts[0]
|
| 534 |
-
negative_prompt = negative_prompts[0]
|
| 535 |
-
|
| 536 |
-
if prompt == '':
|
| 537 |
-
# disable expansion when empty since it is not meaningful and influences image prompt
|
| 538 |
-
use_expansion = False
|
| 539 |
-
|
| 540 |
-
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
|
| 541 |
-
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
|
| 542 |
-
|
| 543 |
-
progressbar(async_task, 3, 'Loading models ...')
|
| 544 |
-
lora_filenames = remove_performance_lora(config.lora_filenames, performance_selection)
|
| 545 |
-
loras, prompt = parse_lora_references_from_prompt(prompt, loras, config.default_max_lora_number, lora_filenames=lora_filenames)
|
| 546 |
-
loras += performance_loras
|
| 547 |
-
|
| 548 |
-
pipeline.refresh_everything(
|
| 549 |
-
refiner_model_name=refiner_model_name,
|
| 550 |
-
base_model_name=base_model_name,
|
| 551 |
-
loras=loras,
|
| 552 |
-
base_model_additional_loras=base_model_additional_loras,
|
| 553 |
-
use_synthetic_refiner=use_synthetic_refiner)
|
| 554 |
-
|
| 555 |
-
pipeline.set_clip_skip(clip_skip)
|
| 556 |
-
|
| 557 |
-
progressbar(async_task, 3, 'Processing prompts ...')
|
| 558 |
-
tasks = []
|
| 559 |
-
|
| 560 |
-
for i in range(image_number):
|
| 561 |
-
if disable_seed_increment:
|
| 562 |
-
task_seed = seed % (constants.MAX_SEED + 1)
|
| 563 |
-
else:
|
| 564 |
-
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
|
| 565 |
-
|
| 566 |
-
task_rng = random.Random(task_seed) # may bind to inpaint noise in the future
|
| 567 |
-
task_prompt = apply_wildcards(prompt, task_rng, i, read_wildcards_in_order)
|
| 568 |
-
task_prompt = apply_arrays(task_prompt, i)
|
| 569 |
-
task_negative_prompt = apply_wildcards(negative_prompt, task_rng, i, read_wildcards_in_order)
|
| 570 |
-
task_extra_positive_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
| 571 |
-
extra_positive_prompts]
|
| 572 |
-
task_extra_negative_prompts = [apply_wildcards(pmt, task_rng, i, read_wildcards_in_order) for pmt in
|
| 573 |
-
extra_negative_prompts]
|
| 574 |
-
|
| 575 |
-
positive_basic_workloads = []
|
| 576 |
-
negative_basic_workloads = []
|
| 577 |
-
|
| 578 |
-
task_styles = style_selections.copy()
|
| 579 |
-
if use_style:
|
| 580 |
-
for index, style in enumerate(task_styles):
|
| 581 |
-
if style == random_style_name:
|
| 582 |
-
style = get_random_style(task_rng)
|
| 583 |
-
task_styles[index] = style
|
| 584 |
-
p, n = apply_style(style, positive=task_prompt)
|
| 585 |
-
positive_basic_workloads = positive_basic_workloads + p
|
| 586 |
-
negative_basic_workloads = negative_basic_workloads + n
|
| 587 |
-
else:
|
| 588 |
-
positive_basic_workloads.append(task_prompt)
|
| 589 |
-
|
| 590 |
-
negative_basic_workloads.append(task_negative_prompt) # Always use independent workload for negative.
|
| 591 |
-
|
| 592 |
-
positive_basic_workloads = positive_basic_workloads + task_extra_positive_prompts
|
| 593 |
-
negative_basic_workloads = negative_basic_workloads + task_extra_negative_prompts
|
| 594 |
-
|
| 595 |
-
positive_basic_workloads = remove_empty_str(positive_basic_workloads, default=task_prompt)
|
| 596 |
-
negative_basic_workloads = remove_empty_str(negative_basic_workloads, default=task_negative_prompt)
|
| 597 |
-
|
| 598 |
-
tasks.append(dict(
|
| 599 |
-
task_seed=task_seed,
|
| 600 |
-
task_prompt=task_prompt,
|
| 601 |
-
task_negative_prompt=task_negative_prompt,
|
| 602 |
-
positive=positive_basic_workloads,
|
| 603 |
-
negative=negative_basic_workloads,
|
| 604 |
-
expansion='',
|
| 605 |
-
c=None,
|
| 606 |
-
uc=None,
|
| 607 |
-
positive_top_k=len(positive_basic_workloads),
|
| 608 |
-
negative_top_k=len(negative_basic_workloads),
|
| 609 |
-
log_positive_prompt='\n'.join([task_prompt] + task_extra_positive_prompts),
|
| 610 |
-
log_negative_prompt='\n'.join([task_negative_prompt] + task_extra_negative_prompts),
|
| 611 |
-
styles=task_styles
|
| 612 |
-
))
|
| 613 |
-
|
| 614 |
-
if use_expansion:
|
| 615 |
-
for i, t in enumerate(tasks):
|
| 616 |
-
progressbar(async_task, 4, f'Preparing Fooocus text #{i + 1} ...')
|
| 617 |
-
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
|
| 618 |
-
logger.std_info(f'[Prompt Expansion] {expansion}')
|
| 619 |
-
t['expansion'] = expansion
|
| 620 |
-
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
|
| 621 |
-
|
| 622 |
-
for i, t in enumerate(tasks):
|
| 623 |
-
progressbar(async_task, 5, f'Encoding positive #{i + 1} ...')
|
| 624 |
-
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
|
| 625 |
-
|
| 626 |
-
for i, t in enumerate(tasks):
|
| 627 |
-
if abs(float(cfg_scale) - 1.0) < 1e-4:
|
| 628 |
-
t['uc'] = pipeline.clone_cond(t['c'])
|
| 629 |
-
else:
|
| 630 |
-
progressbar(async_task, 6, f'Encoding negative #{i + 1} ...')
|
| 631 |
-
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
|
| 632 |
-
|
| 633 |
-
if len(goals) > 0:
|
| 634 |
-
progressbar(async_task, 7, 'Image processing ...')
|
| 635 |
-
|
| 636 |
-
if 'vary' in goals:
|
| 637 |
-
if 'subtle' in uov_method:
|
| 638 |
-
denoising_strength = 0.5
|
| 639 |
-
if 'strong' in uov_method:
|
| 640 |
-
denoising_strength = 0.85
|
| 641 |
-
if overwrite_vary_strength > 0:
|
| 642 |
-
denoising_strength = overwrite_vary_strength
|
| 643 |
-
|
| 644 |
-
shape_ceil = get_image_shape_ceil(uov_input_image)
|
| 645 |
-
if shape_ceil < 1024:
|
| 646 |
-
logger.std_warn('[Vary] Image is resized because it is too small.')
|
| 647 |
-
shape_ceil = 1024
|
| 648 |
-
elif shape_ceil > 2048:
|
| 649 |
-
logger.std_warn('[Vary] Image is resized because it is too big.')
|
| 650 |
-
shape_ceil = 2048
|
| 651 |
-
|
| 652 |
-
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
|
| 653 |
-
|
| 654 |
-
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
| 655 |
-
progressbar(async_task, 8, 'VAE encoding ...')
|
| 656 |
-
|
| 657 |
-
candidate_vae, _ = pipeline.get_candidate_vae(
|
| 658 |
-
steps=steps,
|
| 659 |
-
switch=switch,
|
| 660 |
-
denoise=denoising_strength,
|
| 661 |
-
refiner_swap_method=refiner_swap_method
|
| 662 |
-
)
|
| 663 |
-
|
| 664 |
-
initial_latent = core.encode_vae(vae=candidate_vae, pixels=initial_pixels)
|
| 665 |
-
B, C, H, W = initial_latent['samples'].shape
|
| 666 |
-
width = W * 8
|
| 667 |
-
height = H * 8
|
| 668 |
-
logger.std_info(f'[Vary] Final resolution is {str((height, width))}.')
|
| 669 |
-
|
| 670 |
-
if 'upscale' in goals:
|
| 671 |
-
H, W, C = uov_input_image.shape
|
| 672 |
-
progressbar(async_task, 9, f'Upscaling image from {str((H, W))} ...')
|
| 673 |
-
uov_input_image = perform_upscale(uov_input_image)
|
| 674 |
-
logger.std_info('[Upscale] Image upscale.')
|
| 675 |
-
|
| 676 |
-
if upscale_value is not None and upscale_value > 1.0:
|
| 677 |
-
f = upscale_value
|
| 678 |
-
else:
|
| 679 |
-
if '1.5x' in uov_method:
|
| 680 |
-
f = 1.5
|
| 681 |
-
elif '2x' in uov_method:
|
| 682 |
-
f = 2.0
|
| 683 |
-
else:
|
| 684 |
-
f = 1.0
|
| 685 |
-
|
| 686 |
-
shape_ceil = get_shape_ceil(H * f, W * f)
|
| 687 |
-
|
| 688 |
-
if shape_ceil < 1024:
|
| 689 |
-
logger.std_info('[Upscale] Image is resized because it is too small.')
|
| 690 |
-
uov_input_image = set_image_shape_ceil(uov_input_image, 1024)
|
| 691 |
-
shape_ceil = 1024
|
| 692 |
-
else:
|
| 693 |
-
uov_input_image = resample_image(uov_input_image, width=W * f, height=H * f)
|
| 694 |
-
|
| 695 |
-
image_is_super_large = shape_ceil > 2800
|
| 696 |
-
|
| 697 |
-
if 'fast' in uov_method:
|
| 698 |
-
direct_return = True
|
| 699 |
-
elif image_is_super_large:
|
| 700 |
-
logger.std_info('[Upscale] Image is too large. Directly returned the SR image. '
|
| 701 |
-
'Usually directly return SR image at 4K resolution '
|
| 702 |
-
'yields better results than SDXL diffusion.')
|
| 703 |
-
direct_return = True
|
| 704 |
-
else:
|
| 705 |
-
direct_return = False
|
| 706 |
-
|
| 707 |
-
if direct_return:
|
| 708 |
-
# d = [('Upscale (Fast)', '2x')]
|
| 709 |
-
# log(uov_input_image, d, output_format=save_extension)
|
| 710 |
-
if config.default_black_out_nsfw or black_out_nsfw:
|
| 711 |
-
uov_input_image = default_censor(uov_input_image)
|
| 712 |
-
yield_result(async_task, uov_input_image, tasks, save_extension, False, False)
|
| 713 |
-
return
|
| 714 |
-
|
| 715 |
-
tiled = True
|
| 716 |
-
denoising_strength = 0.382
|
| 717 |
-
|
| 718 |
-
if overwrite_upscale_strength > 0:
|
| 719 |
-
denoising_strength = overwrite_upscale_strength
|
| 720 |
-
|
| 721 |
-
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
| 722 |
-
progressbar(async_task, 10, 'VAE encoding ...')
|
| 723 |
-
|
| 724 |
-
candidate_vae, _ = pipeline.get_candidate_vae(
|
| 725 |
-
steps=steps,
|
| 726 |
-
switch=switch,
|
| 727 |
-
denoise=denoising_strength,
|
| 728 |
-
refiner_swap_method=refiner_swap_method
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
initial_latent = core.encode_vae(
|
| 732 |
-
vae=candidate_vae,
|
| 733 |
-
pixels=initial_pixels, tiled=True)
|
| 734 |
-
B, C, H, W = initial_latent['samples'].shape
|
| 735 |
-
width = W * 8
|
| 736 |
-
height = H * 8
|
| 737 |
-
logger.std_info(f'[Upscale] Final resolution is {str((height, width))}.')
|
| 738 |
-
|
| 739 |
-
if 'inpaint' in goals:
|
| 740 |
-
if len(outpaint_selections) > 0:
|
| 741 |
-
H, W, C = inpaint_image.shape
|
| 742 |
-
if 'top' in outpaint_selections:
|
| 743 |
-
distance_top = int(H * 0.3)
|
| 744 |
-
if outpaint_distance_top > 0:
|
| 745 |
-
distance_top = outpaint_distance_top
|
| 746 |
-
|
| 747 |
-
inpaint_image = np.pad(inpaint_image, [[distance_top, 0], [0, 0], [0, 0]], mode='edge')
|
| 748 |
-
inpaint_mask = np.pad(inpaint_mask, [[distance_top, 0], [0, 0]], mode='constant',
|
| 749 |
-
constant_values=255)
|
| 750 |
-
|
| 751 |
-
if 'bottom' in outpaint_selections:
|
| 752 |
-
distance_bottom = int(H * 0.3)
|
| 753 |
-
if outpaint_distance_bottom > 0:
|
| 754 |
-
distance_bottom = outpaint_distance_bottom
|
| 755 |
-
|
| 756 |
-
inpaint_image = np.pad(inpaint_image, [[0, distance_bottom], [0, 0], [0, 0]], mode='edge')
|
| 757 |
-
inpaint_mask = np.pad(inpaint_mask, [[0, distance_bottom], [0, 0]], mode='constant',
|
| 758 |
-
constant_values=255)
|
| 759 |
-
|
| 760 |
-
H, W, C = inpaint_image.shape
|
| 761 |
-
if 'left' in outpaint_selections:
|
| 762 |
-
distance_left = int(W * 0.3)
|
| 763 |
-
if outpaint_distance_left > 0:
|
| 764 |
-
distance_left = outpaint_distance_left
|
| 765 |
-
|
| 766 |
-
inpaint_image = np.pad(inpaint_image, [[0, 0], [distance_left, 0], [0, 0]], mode='edge')
|
| 767 |
-
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [distance_left, 0]], mode='constant',
|
| 768 |
-
constant_values=255)
|
| 769 |
-
|
| 770 |
-
if 'right' in outpaint_selections:
|
| 771 |
-
distance_right = int(W * 0.3)
|
| 772 |
-
if outpaint_distance_right > 0:
|
| 773 |
-
distance_right = outpaint_distance_right
|
| 774 |
-
|
| 775 |
-
inpaint_image = np.pad(inpaint_image, [[0, 0], [0, distance_right], [0, 0]], mode='edge')
|
| 776 |
-
inpaint_mask = np.pad(inpaint_mask, [[0, 0], [0, distance_right]], mode='constant',
|
| 777 |
-
constant_values=255)
|
| 778 |
-
|
| 779 |
-
inpaint_image = np.ascontiguousarray(inpaint_image.copy())
|
| 780 |
-
inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
|
| 781 |
-
inpaint_strength = 1.0
|
| 782 |
-
inpaint_respective_field = 1.0
|
| 783 |
-
|
| 784 |
-
denoising_strength = inpaint_strength
|
| 785 |
-
|
| 786 |
-
inpaint_worker.current_task = inpaint_worker.InpaintWorker(
|
| 787 |
-
image=inpaint_image,
|
| 788 |
-
mask=inpaint_mask,
|
| 789 |
-
use_fill=denoising_strength > 0.99,
|
| 790 |
-
k=inpaint_respective_field
|
| 791 |
-
)
|
| 792 |
-
|
| 793 |
-
if debugging_inpaint_preprocessor:
|
| 794 |
-
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(), tasks,
|
| 795 |
-
black_out_nsfw)
|
| 796 |
-
return
|
| 797 |
-
|
| 798 |
-
progressbar(async_task, 11, 'VAE Inpaint encoding ...')
|
| 799 |
-
|
| 800 |
-
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
|
| 801 |
-
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
|
| 802 |
-
inpaint_pixel_mask = core.numpy_to_pytorch(inpaint_worker.current_task.interested_mask)
|
| 803 |
-
|
| 804 |
-
candidate_vae, candidate_vae_swap = pipeline.get_candidate_vae(
|
| 805 |
-
steps=steps,
|
| 806 |
-
switch=switch,
|
| 807 |
-
denoise=denoising_strength,
|
| 808 |
-
refiner_swap_method=refiner_swap_method
|
| 809 |
-
)
|
| 810 |
-
|
| 811 |
-
latent_inpaint, latent_mask = core.encode_vae_inpaint(
|
| 812 |
-
mask=inpaint_pixel_mask,
|
| 813 |
-
vae=candidate_vae,
|
| 814 |
-
pixels=inpaint_pixel_image)
|
| 815 |
-
|
| 816 |
-
latent_swap = None
|
| 817 |
-
if candidate_vae_swap is not None:
|
| 818 |
-
progressbar(async_task, 12, 'VAE SD15 encoding ...')
|
| 819 |
-
latent_swap = core.encode_vae(
|
| 820 |
-
vae=candidate_vae_swap,
|
| 821 |
-
pixels=inpaint_pixel_fill)['samples']
|
| 822 |
-
|
| 823 |
-
progressbar(async_task, 13, 'VAE encoding ...')
|
| 824 |
-
latent_fill = core.encode_vae(
|
| 825 |
-
vae=candidate_vae,
|
| 826 |
-
pixels=inpaint_pixel_fill)['samples']
|
| 827 |
-
|
| 828 |
-
inpaint_worker.current_task.load_latent(
|
| 829 |
-
latent_fill=latent_fill, latent_mask=latent_mask, latent_swap=latent_swap)
|
| 830 |
-
|
| 831 |
-
if inpaint_parameterized:
|
| 832 |
-
pipeline.final_unet = inpaint_worker.current_task.patch(
|
| 833 |
-
inpaint_head_model_path=inpaint_head_model_path,
|
| 834 |
-
inpaint_latent=latent_inpaint,
|
| 835 |
-
inpaint_latent_mask=latent_mask,
|
| 836 |
-
model=pipeline.final_unet
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
if not inpaint_disable_initial_latent:
|
| 840 |
-
initial_latent = {'samples': latent_fill}
|
| 841 |
-
|
| 842 |
-
B, C, H, W = latent_fill.shape
|
| 843 |
-
height, width = H * 8, W * 8
|
| 844 |
-
final_height, final_width = inpaint_worker.current_task.image.shape[:2]
|
| 845 |
-
logger.std_info(f'[Inpaint] Final resolution is {str((final_height, final_width))}, latent is {str((height, width))}.')
|
| 846 |
-
|
| 847 |
-
if 'cn' in goals:
|
| 848 |
-
for task in cn_tasks[flags.cn_canny]:
|
| 849 |
-
cn_img, cn_stop, cn_weight = task
|
| 850 |
-
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
| 851 |
-
|
| 852 |
-
if not skipping_cn_preprocessor:
|
| 853 |
-
cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
|
| 854 |
-
|
| 855 |
-
cn_img = HWC3(cn_img)
|
| 856 |
-
task[0] = core.numpy_to_pytorch(cn_img)
|
| 857 |
-
if debugging_cn_preprocessor:
|
| 858 |
-
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
| 859 |
-
return
|
| 860 |
-
for task in cn_tasks[flags.cn_cpds]:
|
| 861 |
-
cn_img, cn_stop, cn_weight = task
|
| 862 |
-
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
| 863 |
-
|
| 864 |
-
if not skipping_cn_preprocessor:
|
| 865 |
-
cn_img = preprocessors.cpds(cn_img)
|
| 866 |
-
|
| 867 |
-
cn_img = HWC3(cn_img)
|
| 868 |
-
task[0] = core.numpy_to_pytorch(cn_img)
|
| 869 |
-
if debugging_cn_preprocessor:
|
| 870 |
-
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
| 871 |
-
return
|
| 872 |
-
for task in cn_tasks[flags.cn_ip]:
|
| 873 |
-
cn_img, cn_stop, cn_weight = task
|
| 874 |
-
cn_img = HWC3(cn_img)
|
| 875 |
-
|
| 876 |
-
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
| 877 |
-
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
| 878 |
-
|
| 879 |
-
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
|
| 880 |
-
if debugging_cn_preprocessor:
|
| 881 |
-
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
| 882 |
-
return
|
| 883 |
-
for task in cn_tasks[flags.cn_ip_face]:
|
| 884 |
-
cn_img, cn_stop, cn_weight = task
|
| 885 |
-
cn_img = HWC3(cn_img)
|
| 886 |
-
|
| 887 |
-
if not skipping_cn_preprocessor:
|
| 888 |
-
cn_img = face_crop.crop_image(cn_img)
|
| 889 |
-
|
| 890 |
-
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
| 891 |
-
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
| 892 |
-
|
| 893 |
-
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
|
| 894 |
-
if debugging_cn_preprocessor:
|
| 895 |
-
yield_result(async_task, cn_img, tasks, save_extension, black_out_nsfw)
|
| 896 |
-
return
|
| 897 |
-
|
| 898 |
-
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
|
| 899 |
-
|
| 900 |
-
if len(all_ip_tasks) > 0:
|
| 901 |
-
pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
|
| 902 |
-
|
| 903 |
-
if freeu_enabled:
|
| 904 |
-
logger.std_info('[Fooocus] FreeU is enabled!')
|
| 905 |
-
pipeline.final_unet = core.apply_freeu(
|
| 906 |
-
pipeline.final_unet,
|
| 907 |
-
freeu_b1,
|
| 908 |
-
freeu_b2,
|
| 909 |
-
freeu_s1,
|
| 910 |
-
freeu_s2
|
| 911 |
-
)
|
| 912 |
-
|
| 913 |
-
all_steps = steps * image_number
|
| 914 |
-
|
| 915 |
-
logger.std_info(f'[Parameters] Denoising Strength = {denoising_strength}')
|
| 916 |
-
|
| 917 |
-
if isinstance(initial_latent, dict) and 'samples' in initial_latent:
|
| 918 |
-
log_shape = initial_latent['samples'].shape
|
| 919 |
-
else:
|
| 920 |
-
log_shape = f'Image Space {(height, width)}'
|
| 921 |
-
|
| 922 |
-
logger.std_info(f'[Parameters] Initial Latent shape: {log_shape}')
|
| 923 |
-
|
| 924 |
-
preparation_time = time.perf_counter() - execution_start_time
|
| 925 |
-
logger.std_info(f'[Fooocus] Preparation time: {preparation_time:.2f} seconds')
|
| 926 |
-
|
| 927 |
-
final_sampler_name = sampler_name
|
| 928 |
-
final_scheduler_name = scheduler_name
|
| 929 |
-
|
| 930 |
-
if scheduler_name in ['lcm', 'tcd']:
|
| 931 |
-
final_scheduler_name = 'sgm_uniform'
|
| 932 |
-
|
| 933 |
-
def patch_discrete(unet):
|
| 934 |
-
return core.opModelSamplingDiscrete.patch(
|
| 935 |
-
pipeline.final_unet,
|
| 936 |
-
sampling=scheduler_name,
|
| 937 |
-
zsnr=False)[0]
|
| 938 |
-
|
| 939 |
-
if pipeline.final_unet is not None:
|
| 940 |
-
pipeline.final_unet = patch_discrete(pipeline.final_unet)
|
| 941 |
-
if pipeline.final_refiner_unet is not None:
|
| 942 |
-
pipeline.final_refiner_unet = patch_discrete(pipeline.final_refiner_unet)
|
| 943 |
-
logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
|
| 944 |
-
elif scheduler_name == 'edm_playground_v2.5':
|
| 945 |
-
final_scheduler_name = 'karras'
|
| 946 |
-
|
| 947 |
-
def patch_edm(unet):
|
| 948 |
-
return core.opModelSamplingContinuousEDM.patch(
|
| 949 |
-
unet,
|
| 950 |
-
sampling=scheduler_name,
|
| 951 |
-
sigma_max=120.0,
|
| 952 |
-
sigma_min=0.002)[0]
|
| 953 |
-
|
| 954 |
-
if pipeline.final_unet is not None:
|
| 955 |
-
pipeline.final_unet = patch_edm(pipeline.final_unet)
|
| 956 |
-
if pipeline.final_refiner_unet is not None:
|
| 957 |
-
pipeline.final_refiner_unet = patch_edm(pipeline.final_refiner_unet)
|
| 958 |
-
|
| 959 |
-
logger.std_info(f'[Fooocus] Using {scheduler_name} scheduler.')
|
| 960 |
-
|
| 961 |
-
outputs.append(['preview', (13, 'Moving model to GPU ...', None)])
|
| 962 |
-
|
| 963 |
-
def callback(step, x0, x, total_steps, y):
|
| 964 |
-
"""callback, used for progress and preview"""
|
| 965 |
-
done_steps = current_task_id * steps + step
|
| 966 |
-
outputs.append(['preview', (
|
| 967 |
-
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
|
| 968 |
-
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
|
| 969 |
-
y)])
|
| 970 |
-
|
| 971 |
-
for current_task_id, task in enumerate(tasks):
|
| 972 |
-
execution_start_time = time.perf_counter()
|
| 973 |
-
|
| 974 |
-
try:
|
| 975 |
-
positive_cond, negative_cond = task['c'], task['uc']
|
| 976 |
-
|
| 977 |
-
if 'cn' in goals:
|
| 978 |
-
for cn_flag, cn_path in [
|
| 979 |
-
(flags.cn_canny, controlnet_canny_path),
|
| 980 |
-
(flags.cn_cpds, controlnet_cpds_path)
|
| 981 |
-
]:
|
| 982 |
-
for cn_img, cn_stop, cn_weight in cn_tasks[cn_flag]:
|
| 983 |
-
positive_cond, negative_cond = core.apply_controlnet(
|
| 984 |
-
positive_cond, negative_cond,
|
| 985 |
-
pipeline.loaded_ControlNets[cn_path], cn_img, cn_weight, 0, cn_stop)
|
| 986 |
-
|
| 987 |
-
imgs = pipeline.process_diffusion(
|
| 988 |
-
positive_cond=positive_cond,
|
| 989 |
-
negative_cond=negative_cond,
|
| 990 |
-
steps=steps,
|
| 991 |
-
switch=switch,
|
| 992 |
-
width=width,
|
| 993 |
-
height=height,
|
| 994 |
-
image_seed=task['task_seed'],
|
| 995 |
-
callback=callback,
|
| 996 |
-
sampler_name=final_sampler_name,
|
| 997 |
-
scheduler_name=final_scheduler_name,
|
| 998 |
-
latent=initial_latent,
|
| 999 |
-
denoise=denoising_strength,
|
| 1000 |
-
tiled=tiled,
|
| 1001 |
-
cfg_scale=cfg_scale,
|
| 1002 |
-
refiner_swap_method=refiner_swap_method,
|
| 1003 |
-
disable_preview=disable_preview
|
| 1004 |
-
)
|
| 1005 |
-
|
| 1006 |
-
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
| 1007 |
-
|
| 1008 |
-
if inpaint_worker.current_task is not None:
|
| 1009 |
-
imgs = [inpaint_worker.current_task.post_process(x) for x in imgs]
|
| 1010 |
-
|
| 1011 |
-
# Fooocus async_worker.py code end
|
| 1012 |
-
|
| 1013 |
-
results += imgs
|
| 1014 |
-
except model_management.InterruptProcessingException as e:
|
| 1015 |
-
logger.std_warn("[Fooocus] User stopped")
|
| 1016 |
-
results = []
|
| 1017 |
-
results.append(ImageGenerationResult(
|
| 1018 |
-
im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.user_cancel))
|
| 1019 |
-
async_task.set_result(results, True, str(e))
|
| 1020 |
-
break
|
| 1021 |
-
except Exception as e:
|
| 1022 |
-
logger.std_error(f'[Fooocus] Process error: {e}')
|
| 1023 |
-
logging.exception(e)
|
| 1024 |
-
results = []
|
| 1025 |
-
results.append(ImageGenerationResult(
|
| 1026 |
-
im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.error))
|
| 1027 |
-
async_task.set_result(results, True, str(e))
|
| 1028 |
-
break
|
| 1029 |
-
|
| 1030 |
-
execution_time = time.perf_counter() - execution_start_time
|
| 1031 |
-
logger.std_info(f'[Fooocus] Generating and saving time: {execution_time:.2f} seconds')
|
| 1032 |
-
|
| 1033 |
-
if async_task.finish_with_error:
|
| 1034 |
-
worker_queue.finish_task(async_task.job_id)
|
| 1035 |
-
return async_task.task_result
|
| 1036 |
-
yield_result(None, results, tasks, save_extension, black_out_nsfw)
|
| 1037 |
-
return
|
| 1038 |
-
except Exception as e:
|
| 1039 |
-
logger.std_error(f'[Fooocus] Worker error: {e}')
|
| 1040 |
-
|
| 1041 |
-
if not async_task.is_finished:
|
| 1042 |
-
async_task.set_result([], True, str(e))
|
| 1043 |
-
worker_queue.finish_task(async_task.job_id)
|
| 1044 |
-
logger.std_info(f"[Task Queue] Finish task with error, job_id={async_task.job_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
predict.py
DELETED
|
@@ -1,316 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Prediction interface for Cog ⚙️
|
| 3 |
-
https://github.com/replicate/cog/blob/main/docs/python.md
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import copy
|
| 7 |
-
import os
|
| 8 |
-
from typing import List
|
| 9 |
-
import numpy as np
|
| 10 |
-
|
| 11 |
-
from PIL import Image
|
| 12 |
-
from cog import BasePredictor, BaseModel, Input, Path
|
| 13 |
-
from fooocusapi.utils.lora_manager import LoraManager
|
| 14 |
-
from fooocusapi.utils.file_utils import output_dir
|
| 15 |
-
from fooocusapi.models.common.task import GenerationFinishReason
|
| 16 |
-
from fooocusapi.configs.default import (
|
| 17 |
-
available_aspect_ratios,
|
| 18 |
-
uov_methods,
|
| 19 |
-
outpaint_expansions,
|
| 20 |
-
default_styles,
|
| 21 |
-
default_base_model_name,
|
| 22 |
-
default_refiner_model_name,
|
| 23 |
-
default_loras,
|
| 24 |
-
default_refiner_switch,
|
| 25 |
-
default_cfg_scale,
|
| 26 |
-
default_prompt_negative
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
from fooocusapi.parameters import ImageGenerationParams
|
| 30 |
-
from fooocusapi.task_queue import TaskType
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class Output(BaseModel):
|
| 34 |
-
"""
|
| 35 |
-
Output model
|
| 36 |
-
"""
|
| 37 |
-
seeds: List[str]
|
| 38 |
-
paths: List[Path]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class Predictor(BasePredictor):
|
| 42 |
-
"""Predictor"""
|
| 43 |
-
def setup(self) -> None:
|
| 44 |
-
"""
|
| 45 |
-
Load the model into memory to make running multiple predictions efficient
|
| 46 |
-
"""
|
| 47 |
-
from main import pre_setup
|
| 48 |
-
pre_setup()
|
| 49 |
-
|
| 50 |
-
def predict(
|
| 51 |
-
self,
|
| 52 |
-
prompt: str = Input(
|
| 53 |
-
default='',
|
| 54 |
-
description="Prompt for image generation"),
|
| 55 |
-
negative_prompt: str = Input(
|
| 56 |
-
default=default_prompt_negative,
|
| 57 |
-
description="Negative prompt for image generation"),
|
| 58 |
-
style_selections: str = Input(
|
| 59 |
-
default=','.join(default_styles),
|
| 60 |
-
description="Fooocus styles applied for image generation, separated by comma"),
|
| 61 |
-
performance_selection: str = Input(
|
| 62 |
-
default='Speed',
|
| 63 |
-
choices=['Speed', 'Quality', 'Extreme Speed', 'Lightning'],
|
| 64 |
-
description="Performance selection"),
|
| 65 |
-
aspect_ratios_selection: str = Input(
|
| 66 |
-
default='1152*896',
|
| 67 |
-
choices=available_aspect_ratios,
|
| 68 |
-
description="The generated image's size"),
|
| 69 |
-
image_number: int = Input(
|
| 70 |
-
default=1,
|
| 71 |
-
ge=1, le=8,
|
| 72 |
-
description="How many image to generate"),
|
| 73 |
-
image_seed: int = Input(
|
| 74 |
-
default=-1,
|
| 75 |
-
description="Seed to generate image, -1 for random"),
|
| 76 |
-
use_default_loras: bool = Input(
|
| 77 |
-
default=True,
|
| 78 |
-
description="Use default LoRAs"),
|
| 79 |
-
loras_custom_urls: str = Input(
|
| 80 |
-
default="",
|
| 81 |
-
description="Custom LoRAs URLs in the format 'url,weight' provide multiple separated by ; (example 'url1,0.3;url2,0.1')"),
|
| 82 |
-
sharpness: float = Input(
|
| 83 |
-
default=2.0,
|
| 84 |
-
ge=0.0, le=30.0),
|
| 85 |
-
guidance_scale: float = Input(
|
| 86 |
-
default=default_cfg_scale,
|
| 87 |
-
ge=1.0, le=30.0),
|
| 88 |
-
refiner_switch: float = Input(
|
| 89 |
-
default=default_refiner_switch,
|
| 90 |
-
ge=0.1, le=1.0),
|
| 91 |
-
uov_input_image: Path = Input(
|
| 92 |
-
default=None,
|
| 93 |
-
description="Input image for upscale or variation, keep None for not upscale or variation"),
|
| 94 |
-
uov_method: str = Input(
|
| 95 |
-
default='Disabled',
|
| 96 |
-
choices=uov_methods),
|
| 97 |
-
uov_upscale_value: float = Input(
|
| 98 |
-
default=0,
|
| 99 |
-
description="Only when Upscale (Custom)"),
|
| 100 |
-
inpaint_additional_prompt: str = Input(
|
| 101 |
-
default='',
|
| 102 |
-
description="Prompt for image generation"),
|
| 103 |
-
inpaint_input_image: Path = Input(
|
| 104 |
-
default=None,
|
| 105 |
-
description="Input image for inpaint or outpaint, keep None for not inpaint or outpaint. Please noticed, `uov_input_image` has bigger priority is not None."),
|
| 106 |
-
inpaint_input_mask: Path = Input(
|
| 107 |
-
default=None,
|
| 108 |
-
description="Input mask for inpaint"),
|
| 109 |
-
outpaint_selections: str = Input(
|
| 110 |
-
default='',
|
| 111 |
-
description="Outpaint expansion selections, literal 'Left', 'Right', 'Top', 'Bottom' separated by comma"),
|
| 112 |
-
outpaint_distance_left: int = Input(
|
| 113 |
-
default=0,
|
| 114 |
-
description="Outpaint expansion distance from Left of the image"),
|
| 115 |
-
outpaint_distance_top: int = Input(
|
| 116 |
-
default=0,
|
| 117 |
-
description="Outpaint expansion distance from Top of the image"),
|
| 118 |
-
outpaint_distance_right: int = Input(
|
| 119 |
-
default=0,
|
| 120 |
-
description="Outpaint expansion distance from Right of the image"),
|
| 121 |
-
outpaint_distance_bottom: int = Input(
|
| 122 |
-
default=0,
|
| 123 |
-
description="Outpaint expansion distance from Bottom of the image"),
|
| 124 |
-
cn_img1: Path = Input(
|
| 125 |
-
default=None,
|
| 126 |
-
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
| 127 |
-
cn_stop1: float = Input(
|
| 128 |
-
default=None,
|
| 129 |
-
ge=0, le=1,
|
| 130 |
-
description="Stop at for image prompt, None for default value"),
|
| 131 |
-
cn_weight1: float = Input(
|
| 132 |
-
default=None,
|
| 133 |
-
ge=0, le=2,
|
| 134 |
-
description="Weight for image prompt, None for default value"),
|
| 135 |
-
cn_type1: str = Input(
|
| 136 |
-
default='ImagePrompt',
|
| 137 |
-
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
| 138 |
-
description="ControlNet type for image prompt"),
|
| 139 |
-
cn_img2: Path = Input(
|
| 140 |
-
default=None,
|
| 141 |
-
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
| 142 |
-
cn_stop2: float = Input(
|
| 143 |
-
default=None,
|
| 144 |
-
ge=0, le=1,
|
| 145 |
-
description="Stop at for image prompt, None for default value"),
|
| 146 |
-
cn_weight2: float = Input(
|
| 147 |
-
default=None,
|
| 148 |
-
ge=0, le=2,
|
| 149 |
-
description="Weight for image prompt, None for default value"),
|
| 150 |
-
cn_type2: str = Input(
|
| 151 |
-
default='ImagePrompt',
|
| 152 |
-
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
| 153 |
-
description="ControlNet type for image prompt"),
|
| 154 |
-
cn_img3: Path = Input(
|
| 155 |
-
default=None,
|
| 156 |
-
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
| 157 |
-
cn_stop3: float = Input(
|
| 158 |
-
default=None,
|
| 159 |
-
ge=0, le=1,
|
| 160 |
-
description="Stop at for image prompt, None for default value"),
|
| 161 |
-
cn_weight3: float = Input(
|
| 162 |
-
default=None,
|
| 163 |
-
ge=0, le=2,
|
| 164 |
-
description="Weight for image prompt, None for default value"),
|
| 165 |
-
cn_type3: str = Input(
|
| 166 |
-
default='ImagePrompt',
|
| 167 |
-
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
| 168 |
-
description="ControlNet type for image prompt"),
|
| 169 |
-
cn_img4: Path = Input(
|
| 170 |
-
default=None,
|
| 171 |
-
description="Input image for image prompt. If all cn_img[n] are None, image prompt will not applied."),
|
| 172 |
-
cn_stop4: float = Input(
|
| 173 |
-
default=None,
|
| 174 |
-
ge=0, le=1,
|
| 175 |
-
description="Stop at for image prompt, None for default value"),
|
| 176 |
-
cn_weight4: float = Input(
|
| 177 |
-
default=None,
|
| 178 |
-
ge=0, le=2,
|
| 179 |
-
description="Weight for image prompt, None for default value"),
|
| 180 |
-
cn_type4: str = Input(
|
| 181 |
-
default='ImagePrompt',
|
| 182 |
-
choices=['ImagePrompt', 'FaceSwap', 'PyraCanny', 'CPDS'],
|
| 183 |
-
description="ControlNet type for image prompt")
|
| 184 |
-
) -> Output:
|
| 185 |
-
"""Run a single prediction on the model"""
|
| 186 |
-
from modules import flags
|
| 187 |
-
from modules.sdxl_styles import legal_style_names
|
| 188 |
-
from fooocusapi.worker import blocking_get_task_result, worker_queue
|
| 189 |
-
|
| 190 |
-
base_model_name = default_base_model_name
|
| 191 |
-
refiner_model_name = default_refiner_model_name
|
| 192 |
-
|
| 193 |
-
lora_manager = LoraManager()
|
| 194 |
-
|
| 195 |
-
# Use default loras if selected
|
| 196 |
-
loras = copy.copy(default_loras) if use_default_loras else []
|
| 197 |
-
|
| 198 |
-
# add custom user loras if provided
|
| 199 |
-
if loras_custom_urls:
|
| 200 |
-
urls = [url.strip() for url in loras_custom_urls.split(';')]
|
| 201 |
-
|
| 202 |
-
loras_with_weights = [url.split(',') for url in urls]
|
| 203 |
-
|
| 204 |
-
custom_lora_paths = lora_manager.check([lw[0] for lw in loras_with_weights])
|
| 205 |
-
custom_loras = [[path, float(lw[1]) if len(lw) > 1 else 1.0] for path, lw in
|
| 206 |
-
zip(custom_lora_paths, loras_with_weights)]
|
| 207 |
-
|
| 208 |
-
loras.extend(custom_loras)
|
| 209 |
-
|
| 210 |
-
style_selections_arr = []
|
| 211 |
-
for s in style_selections.strip().split(','):
|
| 212 |
-
style = s.strip()
|
| 213 |
-
if style in legal_style_names:
|
| 214 |
-
style_selections_arr.append(style)
|
| 215 |
-
|
| 216 |
-
if uov_input_image is not None:
|
| 217 |
-
im = Image.open(str(uov_input_image))
|
| 218 |
-
uov_input_image = np.array(im)
|
| 219 |
-
|
| 220 |
-
inpaint_input_image_dict = None
|
| 221 |
-
if inpaint_input_image is not None:
|
| 222 |
-
im = Image.open(str(inpaint_input_image))
|
| 223 |
-
inpaint_input_image = np.array(im)
|
| 224 |
-
|
| 225 |
-
if inpaint_input_mask is not None:
|
| 226 |
-
im = Image.open(str(inpaint_input_mask))
|
| 227 |
-
inpaint_input_mask = np.array(im)
|
| 228 |
-
|
| 229 |
-
inpaint_input_image_dict = {
|
| 230 |
-
'image': inpaint_input_image,
|
| 231 |
-
'mask': inpaint_input_mask
|
| 232 |
-
}
|
| 233 |
-
|
| 234 |
-
outpaint_selections_arr = []
|
| 235 |
-
for e in outpaint_selections.strip().split(','):
|
| 236 |
-
expansion = e.strip()
|
| 237 |
-
if expansion in outpaint_expansions:
|
| 238 |
-
outpaint_selections_arr.append(expansion)
|
| 239 |
-
|
| 240 |
-
image_prompts = []
|
| 241 |
-
image_prompt_config = [
|
| 242 |
-
(cn_img1, cn_stop1, cn_weight1, cn_type1),
|
| 243 |
-
(cn_img2, cn_stop2, cn_weight2, cn_type2),
|
| 244 |
-
(cn_img3, cn_stop3, cn_weight3, cn_type3),
|
| 245 |
-
(cn_img4, cn_stop4, cn_weight4, cn_type4)]
|
| 246 |
-
for config in image_prompt_config:
|
| 247 |
-
cn_img, cn_stop, cn_weight, cn_type = config
|
| 248 |
-
if cn_img is not None:
|
| 249 |
-
im = Image.open(str(cn_img))
|
| 250 |
-
cn_img = np.array(im)
|
| 251 |
-
if cn_stop is None:
|
| 252 |
-
cn_stop = flags.default_parameters[cn_type][0]
|
| 253 |
-
if cn_weight is None:
|
| 254 |
-
cn_weight = flags.default_parameters[cn_type][1]
|
| 255 |
-
image_prompts.append((cn_img, cn_stop, cn_weight, cn_type))
|
| 256 |
-
|
| 257 |
-
advanced_params = None
|
| 258 |
-
|
| 259 |
-
params = ImageGenerationParams(
|
| 260 |
-
prompt=prompt,
|
| 261 |
-
negative_prompt=negative_prompt,
|
| 262 |
-
style_selections=style_selections_arr,
|
| 263 |
-
performance_selection=performance_selection,
|
| 264 |
-
aspect_ratios_selection=aspect_ratios_selection,
|
| 265 |
-
image_number=image_number,
|
| 266 |
-
image_seed=image_seed,
|
| 267 |
-
sharpness=sharpness,
|
| 268 |
-
guidance_scale=guidance_scale,
|
| 269 |
-
base_model_name=base_model_name,
|
| 270 |
-
refiner_model_name=refiner_model_name,
|
| 271 |
-
refiner_switch=refiner_switch,
|
| 272 |
-
loras=loras,
|
| 273 |
-
uov_input_image=uov_input_image,
|
| 274 |
-
uov_method=uov_method,
|
| 275 |
-
upscale_value=uov_upscale_value,
|
| 276 |
-
outpaint_selections=outpaint_selections_arr,
|
| 277 |
-
inpaint_input_image=inpaint_input_image_dict,
|
| 278 |
-
image_prompts=image_prompts,
|
| 279 |
-
advanced_params=advanced_params,
|
| 280 |
-
inpaint_additional_prompt=inpaint_additional_prompt,
|
| 281 |
-
outpaint_distance_left=outpaint_distance_left,
|
| 282 |
-
outpaint_distance_top=outpaint_distance_top,
|
| 283 |
-
outpaint_distance_right=outpaint_distance_right,
|
| 284 |
-
outpaint_distance_bottom=outpaint_distance_bottom,
|
| 285 |
-
save_meta=True,
|
| 286 |
-
meta_scheme='fooocus',
|
| 287 |
-
save_extension='png',
|
| 288 |
-
save_name='',
|
| 289 |
-
require_base64=False,
|
| 290 |
-
)
|
| 291 |
-
|
| 292 |
-
print(f"[Predictor Predict] Params: {params.__dict__}")
|
| 293 |
-
|
| 294 |
-
async_task = worker_queue.add_task(
|
| 295 |
-
TaskType.text_2_img,
|
| 296 |
-
params)
|
| 297 |
-
|
| 298 |
-
if async_task is None:
|
| 299 |
-
print("[Task Queue] The task queue has reached limit")
|
| 300 |
-
raise Exception("The task queue has reached limit.")
|
| 301 |
-
|
| 302 |
-
results = blocking_get_task_result(async_task.job_id)
|
| 303 |
-
|
| 304 |
-
output_paths: List[Path] = []
|
| 305 |
-
output_seeds: List[str] = []
|
| 306 |
-
for r in results:
|
| 307 |
-
if r.finish_reason == GenerationFinishReason.success and r.im is not None:
|
| 308 |
-
output_seeds.append(r.seed)
|
| 309 |
-
output_paths.append(Path(os.path.join(output_dir, r.im)))
|
| 310 |
-
|
| 311 |
-
print(f"[Predictor Predict] Finished with {len(output_paths)} images")
|
| 312 |
-
|
| 313 |
-
if len(output_paths) == 0:
|
| 314 |
-
raise Exception("Process failed.")
|
| 315 |
-
|
| 316 |
-
return Output(seeds=output_seeds, paths=output_paths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Created By: ishwor subedi
|
| 3 |
-
Date: 2024-07-19
|
| 4 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/args_manager.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
import ldm_patched.modules.args_parser as args_parser
|
| 2 |
-
|
| 3 |
-
args_parser.parser.add_argument("--share", action='store_true', help="Set whether to share on Gradio.")
|
| 4 |
-
|
| 5 |
-
args_parser.parser.add_argument("--preset", type=str, default=None, help="Apply specified UI preset.")
|
| 6 |
-
args_parser.parser.add_argument("--disable-preset-selection", action='store_true',
|
| 7 |
-
help="Disables preset selection in Gradio.")
|
| 8 |
-
|
| 9 |
-
args_parser.parser.add_argument("--language", type=str, default='default',
|
| 10 |
-
help="Translate UI using json files in [language] folder. "
|
| 11 |
-
"For example, [--language example] will use [language/example.json] for translation.")
|
| 12 |
-
|
| 13 |
-
# For example, https://github.com/lllyasviel/Fooocus/issues/849
|
| 14 |
-
args_parser.parser.add_argument("--disable-offload-from-vram", action="store_true",
|
| 15 |
-
help="Force loading models to vram when the unload can be avoided. "
|
| 16 |
-
"Some Mac users may need this.")
|
| 17 |
-
|
| 18 |
-
args_parser.parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
| 19 |
-
args_parser.parser.add_argument("--disable-image-log", action='store_true',
|
| 20 |
-
help="Prevent writing images and logs to hard drive.")
|
| 21 |
-
|
| 22 |
-
args_parser.parser.add_argument("--disable-analytics", action='store_true',
|
| 23 |
-
help="Disables analytics for Gradio.")
|
| 24 |
-
|
| 25 |
-
args_parser.parser.add_argument("--disable-metadata", action='store_true',
|
| 26 |
-
help="Disables saving metadata to images.")
|
| 27 |
-
|
| 28 |
-
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
|
| 29 |
-
help="Disables downloading models for presets", default=False)
|
| 30 |
-
|
| 31 |
-
args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
|
| 32 |
-
help="Disables automatic description of uov images when prompt is empty", default=False)
|
| 33 |
-
|
| 34 |
-
args_parser.parser.add_argument("--always-download-new-model", action='store_true',
|
| 35 |
-
help="Always download newer models ", default=False)
|
| 36 |
-
|
| 37 |
-
args_parser.parser.set_defaults(
|
| 38 |
-
disable_cuda_malloc=True,
|
| 39 |
-
in_browser=True,
|
| 40 |
-
port=None
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
args_parser.args = args_parser.parser.parse_args()
|
| 44 |
-
|
| 45 |
-
# (Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
|
| 46 |
-
args_parser.args.always_offload_from_vram = not args_parser.args.disable_offload_from_vram
|
| 47 |
-
|
| 48 |
-
if args_parser.args.disable_analytics:
|
| 49 |
-
import os
|
| 50 |
-
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
| 51 |
-
|
| 52 |
-
if args_parser.args.disable_in_browser:
|
| 53 |
-
args_parser.args.in_browser = False
|
| 54 |
-
|
| 55 |
-
args = args_parser.args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/bert_config.json
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"architectures": [
|
| 3 |
-
"BertModel"
|
| 4 |
-
],
|
| 5 |
-
"attention_probs_dropout_prob": 0.1,
|
| 6 |
-
"hidden_act": "gelu",
|
| 7 |
-
"hidden_dropout_prob": 0.1,
|
| 8 |
-
"hidden_size": 768,
|
| 9 |
-
"initializer_range": 0.02,
|
| 10 |
-
"intermediate_size": 3072,
|
| 11 |
-
"layer_norm_eps": 1e-12,
|
| 12 |
-
"max_position_embeddings": 512,
|
| 13 |
-
"model_type": "bert",
|
| 14 |
-
"num_attention_heads": 12,
|
| 15 |
-
"num_hidden_layers": 12,
|
| 16 |
-
"pad_token_id": 0,
|
| 17 |
-
"type_vocab_size": 2,
|
| 18 |
-
"vocab_size": 30522,
|
| 19 |
-
"encoder_width": 768,
|
| 20 |
-
"add_cross_attention": true
|
| 21 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/caption_coco.yaml
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
image_root: '/export/share/datasets/vision/coco/images/'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
coco_gt_root: 'annotation/coco_gt'
|
| 4 |
-
|
| 5 |
-
# set pretrained as a file path or an url
|
| 6 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
| 7 |
-
|
| 8 |
-
# size of vit model; base or large
|
| 9 |
-
vit: 'base'
|
| 10 |
-
vit_grad_ckpt: False
|
| 11 |
-
vit_ckpt_layer: 0
|
| 12 |
-
batch_size: 32
|
| 13 |
-
init_lr: 1e-5
|
| 14 |
-
|
| 15 |
-
# vit: 'large'
|
| 16 |
-
# vit_grad_ckpt: True
|
| 17 |
-
# vit_ckpt_layer: 5
|
| 18 |
-
# batch_size: 16
|
| 19 |
-
# init_lr: 2e-6
|
| 20 |
-
|
| 21 |
-
image_size: 384
|
| 22 |
-
|
| 23 |
-
# generation configs
|
| 24 |
-
max_length: 20
|
| 25 |
-
min_length: 5
|
| 26 |
-
num_beams: 3
|
| 27 |
-
prompt: 'a picture of '
|
| 28 |
-
|
| 29 |
-
# optimizer
|
| 30 |
-
weight_decay: 0.05
|
| 31 |
-
min_lr: 0
|
| 32 |
-
max_epoch: 5
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/med_config.json
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"architectures": [
|
| 3 |
-
"BertModel"
|
| 4 |
-
],
|
| 5 |
-
"attention_probs_dropout_prob": 0.1,
|
| 6 |
-
"hidden_act": "gelu",
|
| 7 |
-
"hidden_dropout_prob": 0.1,
|
| 8 |
-
"hidden_size": 768,
|
| 9 |
-
"initializer_range": 0.02,
|
| 10 |
-
"intermediate_size": 3072,
|
| 11 |
-
"layer_norm_eps": 1e-12,
|
| 12 |
-
"max_position_embeddings": 512,
|
| 13 |
-
"model_type": "bert",
|
| 14 |
-
"num_attention_heads": 12,
|
| 15 |
-
"num_hidden_layers": 12,
|
| 16 |
-
"pad_token_id": 0,
|
| 17 |
-
"type_vocab_size": 2,
|
| 18 |
-
"vocab_size": 30524,
|
| 19 |
-
"encoder_width": 768,
|
| 20 |
-
"add_cross_attention": true
|
| 21 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/nlvr.yaml
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
image_root: '/export/share/datasets/vision/NLVR2/'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
|
| 4 |
-
# set pretrained as a file path or an url
|
| 5 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
|
| 6 |
-
|
| 7 |
-
#size of vit model; base or large
|
| 8 |
-
vit: 'base'
|
| 9 |
-
batch_size_train: 16
|
| 10 |
-
batch_size_test: 64
|
| 11 |
-
vit_grad_ckpt: False
|
| 12 |
-
vit_ckpt_layer: 0
|
| 13 |
-
max_epoch: 15
|
| 14 |
-
|
| 15 |
-
image_size: 384
|
| 16 |
-
|
| 17 |
-
# optimizer
|
| 18 |
-
weight_decay: 0.05
|
| 19 |
-
init_lr: 3e-5
|
| 20 |
-
min_lr: 0
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/nocaps.yaml
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
image_root: '/export/share/datasets/vision/nocaps/'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
|
| 4 |
-
# set pretrained as a file path or an url
|
| 5 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
| 6 |
-
|
| 7 |
-
vit: 'base'
|
| 8 |
-
batch_size: 32
|
| 9 |
-
|
| 10 |
-
image_size: 384
|
| 11 |
-
|
| 12 |
-
max_length: 20
|
| 13 |
-
min_length: 5
|
| 14 |
-
num_beams: 3
|
| 15 |
-
prompt: 'a picture of '
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/pretrain.yaml
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
|
| 2 |
-
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
|
| 3 |
-
]
|
| 4 |
-
laion_path: ''
|
| 5 |
-
|
| 6 |
-
# size of vit model; base or large
|
| 7 |
-
vit: 'base'
|
| 8 |
-
vit_grad_ckpt: False
|
| 9 |
-
vit_ckpt_layer: 0
|
| 10 |
-
|
| 11 |
-
image_size: 224
|
| 12 |
-
batch_size: 75
|
| 13 |
-
|
| 14 |
-
queue_size: 57600
|
| 15 |
-
alpha: 0.4
|
| 16 |
-
|
| 17 |
-
# optimizer
|
| 18 |
-
weight_decay: 0.05
|
| 19 |
-
init_lr: 3e-4
|
| 20 |
-
min_lr: 1e-6
|
| 21 |
-
warmup_lr: 1e-6
|
| 22 |
-
lr_decay_rate: 0.9
|
| 23 |
-
max_epoch: 20
|
| 24 |
-
warmup_steps: 3000
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/retrieval_coco.yaml
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
image_root: '/export/share/datasets/vision/coco/images/'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
dataset: 'coco'
|
| 4 |
-
|
| 5 |
-
# set pretrained as a file path or an url
|
| 6 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
|
| 7 |
-
|
| 8 |
-
# size of vit model; base or large
|
| 9 |
-
|
| 10 |
-
vit: 'base'
|
| 11 |
-
batch_size_train: 32
|
| 12 |
-
batch_size_test: 64
|
| 13 |
-
vit_grad_ckpt: True
|
| 14 |
-
vit_ckpt_layer: 4
|
| 15 |
-
init_lr: 1e-5
|
| 16 |
-
|
| 17 |
-
# vit: 'large'
|
| 18 |
-
# batch_size_train: 16
|
| 19 |
-
# batch_size_test: 32
|
| 20 |
-
# vit_grad_ckpt: True
|
| 21 |
-
# vit_ckpt_layer: 12
|
| 22 |
-
# init_lr: 5e-6
|
| 23 |
-
|
| 24 |
-
image_size: 384
|
| 25 |
-
queue_size: 57600
|
| 26 |
-
alpha: 0.4
|
| 27 |
-
k_test: 256
|
| 28 |
-
negative_all_rank: True
|
| 29 |
-
|
| 30 |
-
# optimizer
|
| 31 |
-
weight_decay: 0.05
|
| 32 |
-
min_lr: 0
|
| 33 |
-
max_epoch: 6
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/retrieval_flickr.yaml
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
image_root: '/export/share/datasets/vision/flickr30k/'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
dataset: 'flickr'
|
| 4 |
-
|
| 5 |
-
# set pretrained as a file path or an url
|
| 6 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
|
| 7 |
-
|
| 8 |
-
# size of vit model; base or large
|
| 9 |
-
|
| 10 |
-
vit: 'base'
|
| 11 |
-
batch_size_train: 32
|
| 12 |
-
batch_size_test: 64
|
| 13 |
-
vit_grad_ckpt: True
|
| 14 |
-
vit_ckpt_layer: 4
|
| 15 |
-
init_lr: 1e-5
|
| 16 |
-
|
| 17 |
-
# vit: 'large'
|
| 18 |
-
# batch_size_train: 16
|
| 19 |
-
# batch_size_test: 32
|
| 20 |
-
# vit_grad_ckpt: True
|
| 21 |
-
# vit_ckpt_layer: 10
|
| 22 |
-
# init_lr: 5e-6
|
| 23 |
-
|
| 24 |
-
image_size: 384
|
| 25 |
-
queue_size: 57600
|
| 26 |
-
alpha: 0.4
|
| 27 |
-
k_test: 128
|
| 28 |
-
negative_all_rank: False
|
| 29 |
-
|
| 30 |
-
# optimizer
|
| 31 |
-
weight_decay: 0.05
|
| 32 |
-
min_lr: 0
|
| 33 |
-
max_epoch: 6
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/retrieval_msrvtt.yaml
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
|
| 2 |
-
ann_root: 'annotation'
|
| 3 |
-
|
| 4 |
-
# set pretrained as a file path or an url
|
| 5 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
|
| 6 |
-
|
| 7 |
-
# size of vit model; base or large
|
| 8 |
-
vit: 'base'
|
| 9 |
-
batch_size: 64
|
| 10 |
-
k_test: 128
|
| 11 |
-
image_size: 384
|
| 12 |
-
num_frm_test: 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/configs/vqa.yaml
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
|
| 2 |
-
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
|
| 3 |
-
train_files: ['vqa_train','vqa_val','vg_qa']
|
| 4 |
-
ann_root: 'annotation'
|
| 5 |
-
|
| 6 |
-
# set pretrained as a file path or an url
|
| 7 |
-
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
|
| 8 |
-
|
| 9 |
-
# size of vit model; base or large
|
| 10 |
-
vit: 'base'
|
| 11 |
-
batch_size_train: 16
|
| 12 |
-
batch_size_test: 32
|
| 13 |
-
vit_grad_ckpt: False
|
| 14 |
-
vit_ckpt_layer: 0
|
| 15 |
-
init_lr: 2e-5
|
| 16 |
-
|
| 17 |
-
image_size: 480
|
| 18 |
-
|
| 19 |
-
k_test: 128
|
| 20 |
-
inference: 'rank'
|
| 21 |
-
|
| 22 |
-
# optimizer
|
| 23 |
-
weight_decay: 0.05
|
| 24 |
-
min_lr: 0
|
| 25 |
-
max_epoch: 10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/config.json
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"architectures": [
|
| 3 |
-
"BertForMaskedLM"
|
| 4 |
-
],
|
| 5 |
-
"attention_probs_dropout_prob": 0.1,
|
| 6 |
-
"gradient_checkpointing": false,
|
| 7 |
-
"hidden_act": "gelu",
|
| 8 |
-
"hidden_dropout_prob": 0.1,
|
| 9 |
-
"hidden_size": 768,
|
| 10 |
-
"initializer_range": 0.02,
|
| 11 |
-
"intermediate_size": 3072,
|
| 12 |
-
"layer_norm_eps": 1e-12,
|
| 13 |
-
"max_position_embeddings": 512,
|
| 14 |
-
"model_type": "bert",
|
| 15 |
-
"num_attention_heads": 12,
|
| 16 |
-
"num_hidden_layers": 12,
|
| 17 |
-
"pad_token_id": 0,
|
| 18 |
-
"position_embedding_type": "absolute",
|
| 19 |
-
"transformers_version": "4.6.0.dev0",
|
| 20 |
-
"type_vocab_size": 2,
|
| 21 |
-
"use_cache": true,
|
| 22 |
-
"vocab_size": 30522
|
| 23 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/tokenizer_config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"do_lower_case": true
|
| 3 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/bert_tokenizer/vocab.txt
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
repositories/Fooocus/extras/BLIP/models/blip.py
DELETED
|
@@ -1,239 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
-
* All rights reserved.
|
| 4 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
-
* By Junnan Li
|
| 7 |
-
'''
|
| 8 |
-
import warnings
|
| 9 |
-
warnings.filterwarnings("ignore")
|
| 10 |
-
|
| 11 |
-
from extras.BLIP.models.vit import VisionTransformer, interpolate_pos_embed
|
| 12 |
-
from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
|
| 13 |
-
from transformers import BertTokenizer
|
| 14 |
-
|
| 15 |
-
import torch
|
| 16 |
-
from torch import nn
|
| 17 |
-
import torch.nn.functional as F
|
| 18 |
-
|
| 19 |
-
import os
|
| 20 |
-
from urllib.parse import urlparse
|
| 21 |
-
from timm.models.hub import download_cached_file
|
| 22 |
-
|
| 23 |
-
class BLIP_Base(nn.Module):
|
| 24 |
-
def __init__(self,
|
| 25 |
-
med_config = 'configs/med_config.json',
|
| 26 |
-
image_size = 224,
|
| 27 |
-
vit = 'base',
|
| 28 |
-
vit_grad_ckpt = False,
|
| 29 |
-
vit_ckpt_layer = 0,
|
| 30 |
-
):
|
| 31 |
-
"""
|
| 32 |
-
Args:
|
| 33 |
-
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 34 |
-
image_size (int): input image size
|
| 35 |
-
vit (str): model size of vision transformer
|
| 36 |
-
"""
|
| 37 |
-
super().__init__()
|
| 38 |
-
|
| 39 |
-
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 40 |
-
self.tokenizer = init_tokenizer()
|
| 41 |
-
med_config = BertConfig.from_json_file(med_config)
|
| 42 |
-
med_config.encoder_width = vision_width
|
| 43 |
-
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def forward(self, image, caption, mode):
|
| 47 |
-
|
| 48 |
-
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
| 49 |
-
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
| 50 |
-
|
| 51 |
-
if mode=='image':
|
| 52 |
-
# return image features
|
| 53 |
-
image_embeds = self.visual_encoder(image)
|
| 54 |
-
return image_embeds
|
| 55 |
-
|
| 56 |
-
elif mode=='text':
|
| 57 |
-
# return text features
|
| 58 |
-
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 59 |
-
return_dict = True, mode = 'text')
|
| 60 |
-
return text_output.last_hidden_state
|
| 61 |
-
|
| 62 |
-
elif mode=='multimodal':
|
| 63 |
-
# return multimodel features
|
| 64 |
-
image_embeds = self.visual_encoder(image)
|
| 65 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 66 |
-
|
| 67 |
-
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
| 68 |
-
output = self.text_encoder(text.input_ids,
|
| 69 |
-
attention_mask = text.attention_mask,
|
| 70 |
-
encoder_hidden_states = image_embeds,
|
| 71 |
-
encoder_attention_mask = image_atts,
|
| 72 |
-
return_dict = True,
|
| 73 |
-
)
|
| 74 |
-
return output.last_hidden_state
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class BLIP_Decoder(nn.Module):
|
| 79 |
-
def __init__(self,
|
| 80 |
-
med_config = 'configs/med_config.json',
|
| 81 |
-
image_size = 384,
|
| 82 |
-
vit = 'base',
|
| 83 |
-
vit_grad_ckpt = False,
|
| 84 |
-
vit_ckpt_layer = 0,
|
| 85 |
-
prompt = 'a picture of ',
|
| 86 |
-
):
|
| 87 |
-
"""
|
| 88 |
-
Args:
|
| 89 |
-
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 90 |
-
image_size (int): input image size
|
| 91 |
-
vit (str): model size of vision transformer
|
| 92 |
-
"""
|
| 93 |
-
super().__init__()
|
| 94 |
-
|
| 95 |
-
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 96 |
-
self.tokenizer = init_tokenizer()
|
| 97 |
-
med_config = BertConfig.from_json_file(med_config)
|
| 98 |
-
med_config.encoder_width = vision_width
|
| 99 |
-
self.text_decoder = BertLMHeadModel(config=med_config)
|
| 100 |
-
|
| 101 |
-
self.prompt = prompt
|
| 102 |
-
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def forward(self, image, caption):
|
| 106 |
-
|
| 107 |
-
image_embeds = self.visual_encoder(image)
|
| 108 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 109 |
-
|
| 110 |
-
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
| 111 |
-
|
| 112 |
-
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
| 113 |
-
|
| 114 |
-
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
| 115 |
-
decoder_targets[:,:self.prompt_length] = -100
|
| 116 |
-
|
| 117 |
-
decoder_output = self.text_decoder(text.input_ids,
|
| 118 |
-
attention_mask = text.attention_mask,
|
| 119 |
-
encoder_hidden_states = image_embeds,
|
| 120 |
-
encoder_attention_mask = image_atts,
|
| 121 |
-
labels = decoder_targets,
|
| 122 |
-
return_dict = True,
|
| 123 |
-
)
|
| 124 |
-
loss_lm = decoder_output.loss
|
| 125 |
-
|
| 126 |
-
return loss_lm
|
| 127 |
-
|
| 128 |
-
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
| 129 |
-
image_embeds = self.visual_encoder(image)
|
| 130 |
-
|
| 131 |
-
if not sample:
|
| 132 |
-
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
| 133 |
-
|
| 134 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 135 |
-
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
| 136 |
-
|
| 137 |
-
prompt = [self.prompt] * image.size(0)
|
| 138 |
-
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
| 139 |
-
input_ids[:,0] = self.tokenizer.bos_token_id
|
| 140 |
-
input_ids = input_ids[:, :-1]
|
| 141 |
-
|
| 142 |
-
if sample:
|
| 143 |
-
#nucleus sampling
|
| 144 |
-
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 145 |
-
max_length=max_length,
|
| 146 |
-
min_length=min_length,
|
| 147 |
-
do_sample=True,
|
| 148 |
-
top_p=top_p,
|
| 149 |
-
num_return_sequences=1,
|
| 150 |
-
eos_token_id=self.tokenizer.sep_token_id,
|
| 151 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 152 |
-
repetition_penalty=1.1,
|
| 153 |
-
**model_kwargs)
|
| 154 |
-
else:
|
| 155 |
-
#beam search
|
| 156 |
-
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 157 |
-
max_length=max_length,
|
| 158 |
-
min_length=min_length,
|
| 159 |
-
num_beams=num_beams,
|
| 160 |
-
eos_token_id=self.tokenizer.sep_token_id,
|
| 161 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 162 |
-
repetition_penalty=repetition_penalty,
|
| 163 |
-
**model_kwargs)
|
| 164 |
-
|
| 165 |
-
captions = []
|
| 166 |
-
for output in outputs:
|
| 167 |
-
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
| 168 |
-
captions.append(caption[len(self.prompt):])
|
| 169 |
-
return captions
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def blip_decoder(pretrained='',**kwargs):
|
| 173 |
-
model = BLIP_Decoder(**kwargs)
|
| 174 |
-
if pretrained:
|
| 175 |
-
model,msg = load_checkpoint(model,pretrained)
|
| 176 |
-
assert(len(msg.missing_keys)==0)
|
| 177 |
-
return model
|
| 178 |
-
|
| 179 |
-
def blip_feature_extractor(pretrained='',**kwargs):
|
| 180 |
-
model = BLIP_Base(**kwargs)
|
| 181 |
-
if pretrained:
|
| 182 |
-
model,msg = load_checkpoint(model,pretrained)
|
| 183 |
-
assert(len(msg.missing_keys)==0)
|
| 184 |
-
return model
|
| 185 |
-
|
| 186 |
-
def init_tokenizer():
|
| 187 |
-
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bert_tokenizer")
|
| 188 |
-
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
| 189 |
-
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
| 190 |
-
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
| 191 |
-
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
| 192 |
-
return tokenizer
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
| 196 |
-
|
| 197 |
-
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
| 198 |
-
if vit=='base':
|
| 199 |
-
vision_width = 768
|
| 200 |
-
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
| 201 |
-
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 202 |
-
drop_path_rate=0 or drop_path_rate
|
| 203 |
-
)
|
| 204 |
-
elif vit=='large':
|
| 205 |
-
vision_width = 1024
|
| 206 |
-
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
| 207 |
-
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 208 |
-
drop_path_rate=0.1 or drop_path_rate
|
| 209 |
-
)
|
| 210 |
-
return visual_encoder, vision_width
|
| 211 |
-
|
| 212 |
-
def is_url(url_or_filename):
|
| 213 |
-
parsed = urlparse(url_or_filename)
|
| 214 |
-
return parsed.scheme in ("http", "https")
|
| 215 |
-
|
| 216 |
-
def load_checkpoint(model,url_or_filename):
|
| 217 |
-
if is_url(url_or_filename):
|
| 218 |
-
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 219 |
-
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 220 |
-
elif os.path.isfile(url_or_filename):
|
| 221 |
-
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 222 |
-
else:
|
| 223 |
-
raise RuntimeError('checkpoint url or path is invalid')
|
| 224 |
-
|
| 225 |
-
state_dict = checkpoint['model']
|
| 226 |
-
|
| 227 |
-
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 228 |
-
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
| 229 |
-
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
| 230 |
-
model.visual_encoder_m)
|
| 231 |
-
for key in model.state_dict().keys():
|
| 232 |
-
if key in state_dict.keys():
|
| 233 |
-
if state_dict[key].shape!=model.state_dict()[key].shape:
|
| 234 |
-
del state_dict[key]
|
| 235 |
-
|
| 236 |
-
msg = model.load_state_dict(state_dict,strict=False)
|
| 237 |
-
print('load checkpoint from %s'%url_or_filename)
|
| 238 |
-
return model,msg
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/blip_itm.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
| 1 |
-
from extras.BLIP.models.med import BertConfig, BertModel
|
| 2 |
-
from transformers import BertTokenizer
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 9 |
-
|
| 10 |
-
class BLIP_ITM(nn.Module):
|
| 11 |
-
def __init__(self,
|
| 12 |
-
med_config = 'configs/med_config.json',
|
| 13 |
-
image_size = 384,
|
| 14 |
-
vit = 'base',
|
| 15 |
-
vit_grad_ckpt = False,
|
| 16 |
-
vit_ckpt_layer = 0,
|
| 17 |
-
embed_dim = 256,
|
| 18 |
-
):
|
| 19 |
-
"""
|
| 20 |
-
Args:
|
| 21 |
-
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 22 |
-
image_size (int): input image size
|
| 23 |
-
vit (str): model size of vision transformer
|
| 24 |
-
"""
|
| 25 |
-
super().__init__()
|
| 26 |
-
|
| 27 |
-
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 28 |
-
self.tokenizer = init_tokenizer()
|
| 29 |
-
med_config = BertConfig.from_json_file(med_config)
|
| 30 |
-
med_config.encoder_width = vision_width
|
| 31 |
-
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 32 |
-
|
| 33 |
-
text_width = self.text_encoder.config.hidden_size
|
| 34 |
-
|
| 35 |
-
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 36 |
-
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 37 |
-
|
| 38 |
-
self.itm_head = nn.Linear(text_width, 2)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def forward(self, image, caption, match_head='itm'):
|
| 42 |
-
|
| 43 |
-
image_embeds = self.visual_encoder(image)
|
| 44 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 45 |
-
|
| 46 |
-
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
|
| 47 |
-
return_tensors="pt").to(image.device)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
if match_head=='itm':
|
| 51 |
-
output = self.text_encoder(text.input_ids,
|
| 52 |
-
attention_mask = text.attention_mask,
|
| 53 |
-
encoder_hidden_states = image_embeds,
|
| 54 |
-
encoder_attention_mask = image_atts,
|
| 55 |
-
return_dict = True,
|
| 56 |
-
)
|
| 57 |
-
itm_output = self.itm_head(output.last_hidden_state[:,0,:])
|
| 58 |
-
return itm_output
|
| 59 |
-
|
| 60 |
-
elif match_head=='itc':
|
| 61 |
-
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 62 |
-
return_dict = True, mode = 'text')
|
| 63 |
-
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
| 64 |
-
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
| 65 |
-
|
| 66 |
-
sim = image_feat @ text_feat.t()
|
| 67 |
-
return sim
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def blip_itm(pretrained='',**kwargs):
|
| 71 |
-
model = BLIP_ITM(**kwargs)
|
| 72 |
-
if pretrained:
|
| 73 |
-
model,msg = load_checkpoint(model,pretrained)
|
| 74 |
-
assert(len(msg.missing_keys)==0)
|
| 75 |
-
return model
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/blip_nlvr.py
DELETED
|
@@ -1,105 +0,0 @@
|
|
| 1 |
-
from extras.BLIP.models.med import BertConfig
|
| 2 |
-
from extras.BLIP.models.nlvr_encoder import BertModel
|
| 3 |
-
from extras.BLIP.models.vit import interpolate_pos_embed
|
| 4 |
-
from extras.BLIP.models.blip import create_vit, init_tokenizer, is_url
|
| 5 |
-
|
| 6 |
-
from timm.models.hub import download_cached_file
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from torch import nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
from transformers import BertTokenizer
|
| 12 |
-
import numpy as np
|
| 13 |
-
import os
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BLIP_NLVR(nn.Module):
|
| 17 |
-
def __init__(self,
|
| 18 |
-
med_config = 'configs/med_config.json',
|
| 19 |
-
image_size = 480,
|
| 20 |
-
vit = 'base',
|
| 21 |
-
vit_grad_ckpt = False,
|
| 22 |
-
vit_ckpt_layer = 0,
|
| 23 |
-
):
|
| 24 |
-
"""
|
| 25 |
-
Args:
|
| 26 |
-
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 27 |
-
image_size (int): input image size
|
| 28 |
-
vit (str): model size of vision transformer
|
| 29 |
-
"""
|
| 30 |
-
super().__init__()
|
| 31 |
-
|
| 32 |
-
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
|
| 33 |
-
self.tokenizer = init_tokenizer()
|
| 34 |
-
med_config = BertConfig.from_json_file(med_config)
|
| 35 |
-
med_config.encoder_width = vision_width
|
| 36 |
-
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 37 |
-
|
| 38 |
-
self.cls_head = nn.Sequential(
|
| 39 |
-
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
|
| 40 |
-
nn.ReLU(),
|
| 41 |
-
nn.Linear(self.text_encoder.config.hidden_size, 2)
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
def forward(self, image, text, targets, train=True):
|
| 45 |
-
|
| 46 |
-
image_embeds = self.visual_encoder(image)
|
| 47 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 48 |
-
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
|
| 49 |
-
|
| 50 |
-
text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
|
| 51 |
-
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
| 52 |
-
|
| 53 |
-
output = self.text_encoder(text.input_ids,
|
| 54 |
-
attention_mask = text.attention_mask,
|
| 55 |
-
encoder_hidden_states = [image0_embeds,image1_embeds],
|
| 56 |
-
encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
|
| 57 |
-
image_atts[image0_embeds.size(0):]],
|
| 58 |
-
return_dict = True,
|
| 59 |
-
)
|
| 60 |
-
hidden_state = output.last_hidden_state[:,0,:]
|
| 61 |
-
prediction = self.cls_head(hidden_state)
|
| 62 |
-
|
| 63 |
-
if train:
|
| 64 |
-
loss = F.cross_entropy(prediction, targets)
|
| 65 |
-
return loss
|
| 66 |
-
else:
|
| 67 |
-
return prediction
|
| 68 |
-
|
| 69 |
-
def blip_nlvr(pretrained='',**kwargs):
|
| 70 |
-
model = BLIP_NLVR(**kwargs)
|
| 71 |
-
if pretrained:
|
| 72 |
-
model,msg = load_checkpoint(model,pretrained)
|
| 73 |
-
print("missing keys:")
|
| 74 |
-
print(msg.missing_keys)
|
| 75 |
-
return model
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def load_checkpoint(model,url_or_filename):
|
| 79 |
-
if is_url(url_or_filename):
|
| 80 |
-
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 81 |
-
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 82 |
-
elif os.path.isfile(url_or_filename):
|
| 83 |
-
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 84 |
-
else:
|
| 85 |
-
raise RuntimeError('checkpoint url or path is invalid')
|
| 86 |
-
state_dict = checkpoint['model']
|
| 87 |
-
|
| 88 |
-
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 89 |
-
|
| 90 |
-
for key in list(state_dict.keys()):
|
| 91 |
-
if 'crossattention.self.' in key:
|
| 92 |
-
new_key0 = key.replace('self','self0')
|
| 93 |
-
new_key1 = key.replace('self','self1')
|
| 94 |
-
state_dict[new_key0] = state_dict[key]
|
| 95 |
-
state_dict[new_key1] = state_dict[key]
|
| 96 |
-
elif 'crossattention.output.dense.' in key:
|
| 97 |
-
new_key0 = key.replace('dense','dense0')
|
| 98 |
-
new_key1 = key.replace('dense','dense1')
|
| 99 |
-
state_dict[new_key0] = state_dict[key]
|
| 100 |
-
state_dict[new_key1] = state_dict[key]
|
| 101 |
-
|
| 102 |
-
msg = model.load_state_dict(state_dict,strict=False)
|
| 103 |
-
print('load checkpoint from %s'%url_or_filename)
|
| 104 |
-
return model,msg
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/blip_pretrain.py
DELETED
|
@@ -1,339 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
-
* All rights reserved.
|
| 4 |
-
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
-
* By Junnan Li
|
| 7 |
-
'''
|
| 8 |
-
from extras.BLIP.models.med import BertConfig, BertModel, BertLMHeadModel
|
| 9 |
-
from transformers import BertTokenizer
|
| 10 |
-
import transformers
|
| 11 |
-
transformers.logging.set_verbosity_error()
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
from torch import nn
|
| 15 |
-
import torch.nn.functional as F
|
| 16 |
-
|
| 17 |
-
from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 18 |
-
|
| 19 |
-
class BLIP_Pretrain(nn.Module):
|
| 20 |
-
def __init__(self,
|
| 21 |
-
med_config = 'configs/bert_config.json',
|
| 22 |
-
image_size = 224,
|
| 23 |
-
vit = 'base',
|
| 24 |
-
vit_grad_ckpt = False,
|
| 25 |
-
vit_ckpt_layer = 0,
|
| 26 |
-
embed_dim = 256,
|
| 27 |
-
queue_size = 57600,
|
| 28 |
-
momentum = 0.995,
|
| 29 |
-
):
|
| 30 |
-
"""
|
| 31 |
-
Args:
|
| 32 |
-
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 33 |
-
image_size (int): input image size
|
| 34 |
-
vit (str): model size of vision transformer
|
| 35 |
-
"""
|
| 36 |
-
super().__init__()
|
| 37 |
-
|
| 38 |
-
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
| 39 |
-
|
| 40 |
-
if vit=='base':
|
| 41 |
-
checkpoint = torch.hub.load_state_dict_from_url(
|
| 42 |
-
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 43 |
-
map_location="cpu", check_hash=True)
|
| 44 |
-
state_dict = checkpoint["model"]
|
| 45 |
-
msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
|
| 46 |
-
elif vit=='large':
|
| 47 |
-
from timm.models.helpers import load_custom_pretrained
|
| 48 |
-
from timm.models.vision_transformer import default_cfgs
|
| 49 |
-
load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
|
| 50 |
-
|
| 51 |
-
self.tokenizer = init_tokenizer()
|
| 52 |
-
encoder_config = BertConfig.from_json_file(med_config)
|
| 53 |
-
encoder_config.encoder_width = vision_width
|
| 54 |
-
self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
|
| 55 |
-
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
| 56 |
-
|
| 57 |
-
text_width = self.text_encoder.config.hidden_size
|
| 58 |
-
|
| 59 |
-
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 60 |
-
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 61 |
-
|
| 62 |
-
self.itm_head = nn.Linear(text_width, 2)
|
| 63 |
-
|
| 64 |
-
# create momentum encoders
|
| 65 |
-
self.visual_encoder_m, vision_width = create_vit(vit,image_size)
|
| 66 |
-
self.vision_proj_m = nn.Linear(vision_width, embed_dim)
|
| 67 |
-
self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
|
| 68 |
-
self.text_proj_m = nn.Linear(text_width, embed_dim)
|
| 69 |
-
|
| 70 |
-
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
|
| 71 |
-
[self.vision_proj,self.vision_proj_m],
|
| 72 |
-
[self.text_encoder,self.text_encoder_m],
|
| 73 |
-
[self.text_proj,self.text_proj_m],
|
| 74 |
-
]
|
| 75 |
-
self.copy_params()
|
| 76 |
-
|
| 77 |
-
# create the queue
|
| 78 |
-
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
|
| 79 |
-
self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
|
| 80 |
-
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
| 81 |
-
|
| 82 |
-
self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
|
| 83 |
-
self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
|
| 84 |
-
|
| 85 |
-
self.queue_size = queue_size
|
| 86 |
-
self.momentum = momentum
|
| 87 |
-
self.temp = nn.Parameter(0.07*torch.ones([]))
|
| 88 |
-
|
| 89 |
-
# create the decoder
|
| 90 |
-
decoder_config = BertConfig.from_json_file(med_config)
|
| 91 |
-
decoder_config.encoder_width = vision_width
|
| 92 |
-
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
|
| 93 |
-
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
|
| 94 |
-
tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def forward(self, image, caption, alpha):
|
| 98 |
-
with torch.no_grad():
|
| 99 |
-
self.temp.clamp_(0.001,0.5)
|
| 100 |
-
|
| 101 |
-
image_embeds = self.visual_encoder(image)
|
| 102 |
-
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 103 |
-
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
|
| 104 |
-
|
| 105 |
-
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
|
| 106 |
-
return_tensors="pt").to(image.device)
|
| 107 |
-
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 108 |
-
return_dict = True, mode = 'text')
|
| 109 |
-
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
|
| 110 |
-
|
| 111 |
-
# get momentum features
|
| 112 |
-
with torch.no_grad():
|
| 113 |
-
self._momentum_update()
|
| 114 |
-
image_embeds_m = self.visual_encoder_m(image)
|
| 115 |
-
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
|
| 116 |
-
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
|
| 117 |
-
|
| 118 |
-
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
|
| 119 |
-
return_dict = True, mode = 'text')
|
| 120 |
-
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
|
| 121 |
-
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
|
| 122 |
-
|
| 123 |
-
sim_i2t_m = image_feat_m @ text_feat_all / self.temp
|
| 124 |
-
sim_t2i_m = text_feat_m @ image_feat_all / self.temp
|
| 125 |
-
|
| 126 |
-
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
|
| 127 |
-
sim_targets.fill_diagonal_(1)
|
| 128 |
-
|
| 129 |
-
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
|
| 130 |
-
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
|
| 131 |
-
|
| 132 |
-
sim_i2t = image_feat @ text_feat_all / self.temp
|
| 133 |
-
sim_t2i = text_feat @ image_feat_all / self.temp
|
| 134 |
-
|
| 135 |
-
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
|
| 136 |
-
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
|
| 137 |
-
|
| 138 |
-
loss_ita = (loss_i2t+loss_t2i)/2
|
| 139 |
-
|
| 140 |
-
self._dequeue_and_enqueue(image_feat_m, text_feat_m)
|
| 141 |
-
|
| 142 |
-
###============== Image-text Matching ===================###
|
| 143 |
-
encoder_input_ids = text.input_ids.clone()
|
| 144 |
-
encoder_input_ids[:,0] = self.tokenizer.enc_token_id
|
| 145 |
-
|
| 146 |
-
# forward the positve image-text pair
|
| 147 |
-
bs = image.size(0)
|
| 148 |
-
output_pos = self.text_encoder(encoder_input_ids,
|
| 149 |
-
attention_mask = text.attention_mask,
|
| 150 |
-
encoder_hidden_states = image_embeds,
|
| 151 |
-
encoder_attention_mask = image_atts,
|
| 152 |
-
return_dict = True,
|
| 153 |
-
)
|
| 154 |
-
with torch.no_grad():
|
| 155 |
-
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
|
| 156 |
-
weights_t2i.fill_diagonal_(0)
|
| 157 |
-
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
|
| 158 |
-
weights_i2t.fill_diagonal_(0)
|
| 159 |
-
|
| 160 |
-
# select a negative image for each text
|
| 161 |
-
image_embeds_neg = []
|
| 162 |
-
for b in range(bs):
|
| 163 |
-
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
|
| 164 |
-
image_embeds_neg.append(image_embeds[neg_idx])
|
| 165 |
-
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
|
| 166 |
-
|
| 167 |
-
# select a negative text for each image
|
| 168 |
-
text_ids_neg = []
|
| 169 |
-
text_atts_neg = []
|
| 170 |
-
for b in range(bs):
|
| 171 |
-
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
|
| 172 |
-
text_ids_neg.append(encoder_input_ids[neg_idx])
|
| 173 |
-
text_atts_neg.append(text.attention_mask[neg_idx])
|
| 174 |
-
|
| 175 |
-
text_ids_neg = torch.stack(text_ids_neg,dim=0)
|
| 176 |
-
text_atts_neg = torch.stack(text_atts_neg,dim=0)
|
| 177 |
-
|
| 178 |
-
text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
|
| 179 |
-
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
|
| 180 |
-
|
| 181 |
-
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
|
| 182 |
-
image_atts_all = torch.cat([image_atts,image_atts],dim=0)
|
| 183 |
-
|
| 184 |
-
output_neg = self.text_encoder(text_ids_all,
|
| 185 |
-
attention_mask = text_atts_all,
|
| 186 |
-
encoder_hidden_states = image_embeds_all,
|
| 187 |
-
encoder_attention_mask = image_atts_all,
|
| 188 |
-
return_dict = True,
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
|
| 192 |
-
vl_output = self.itm_head(vl_embeddings)
|
| 193 |
-
|
| 194 |
-
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
|
| 195 |
-
dim=0).to(image.device)
|
| 196 |
-
loss_itm = F.cross_entropy(vl_output, itm_labels)
|
| 197 |
-
|
| 198 |
-
##================= LM ========================##
|
| 199 |
-
decoder_input_ids = text.input_ids.clone()
|
| 200 |
-
decoder_input_ids[:,0] = self.tokenizer.bos_token_id
|
| 201 |
-
decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
|
| 202 |
-
|
| 203 |
-
decoder_output = self.text_decoder(decoder_input_ids,
|
| 204 |
-
attention_mask = text.attention_mask,
|
| 205 |
-
encoder_hidden_states = image_embeds,
|
| 206 |
-
encoder_attention_mask = image_atts,
|
| 207 |
-
labels = decoder_targets,
|
| 208 |
-
return_dict = True,
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
loss_lm = decoder_output.loss
|
| 212 |
-
return loss_ita, loss_itm, loss_lm
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
@torch.no_grad()
|
| 217 |
-
def copy_params(self):
|
| 218 |
-
for model_pair in self.model_pairs:
|
| 219 |
-
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
| 220 |
-
param_m.data.copy_(param.data) # initialize
|
| 221 |
-
param_m.requires_grad = False # not update by gradient
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
@torch.no_grad()
|
| 225 |
-
def _momentum_update(self):
|
| 226 |
-
for model_pair in self.model_pairs:
|
| 227 |
-
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
|
| 228 |
-
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
@torch.no_grad()
|
| 232 |
-
def _dequeue_and_enqueue(self, image_feat, text_feat):
|
| 233 |
-
# gather keys before updating queue
|
| 234 |
-
image_feats = concat_all_gather(image_feat)
|
| 235 |
-
text_feats = concat_all_gather(text_feat)
|
| 236 |
-
|
| 237 |
-
batch_size = image_feats.shape[0]
|
| 238 |
-
|
| 239 |
-
ptr = int(self.queue_ptr)
|
| 240 |
-
assert self.queue_size % batch_size == 0 # for simplicity
|
| 241 |
-
|
| 242 |
-
# replace the keys at ptr (dequeue and enqueue)
|
| 243 |
-
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
|
| 244 |
-
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
|
| 245 |
-
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
| 246 |
-
|
| 247 |
-
self.queue_ptr[0] = ptr
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def blip_pretrain(**kwargs):
|
| 251 |
-
model = BLIP_Pretrain(**kwargs)
|
| 252 |
-
return model
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
@torch.no_grad()
|
| 256 |
-
def concat_all_gather(tensor):
|
| 257 |
-
"""
|
| 258 |
-
Performs all_gather operation on the provided tensors.
|
| 259 |
-
*** Warning ***: torch.distributed.all_gather has no gradient.
|
| 260 |
-
"""
|
| 261 |
-
tensors_gather = [torch.ones_like(tensor)
|
| 262 |
-
for _ in range(torch.distributed.get_world_size())]
|
| 263 |
-
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
| 264 |
-
|
| 265 |
-
output = torch.cat(tensors_gather, dim=0)
|
| 266 |
-
return output
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
from typing import List
|
| 270 |
-
def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
    """Make the encoder share the decoder's ``weight``/``bias`` parameters.

    Both module trees are walked in parallel. Whenever a decoder module has a
    ``weight`` attribute and ``skip_key`` does not occur in its path, the
    matching encoder module is pointed at the decoder's ``weight`` (and
    ``bias``, if the decoder module has one) and ``"<path> is tied"`` is
    printed.

    Args:
        encoder: module whose parameters get replaced by the decoder's.
        decoder: module that supplies the shared parameters.
        base_model_prefix: starting path; child names are appended with "/".
        skip_key: modules whose path contains this substring are not tied.

    Returns:
        None. The encoder is modified in place.
    """
    # Encoder sub-module paths that found no decoder counterpart. Collected
    # during the walk but not reported or returned.
    untied: List[str] = []
    if decoder.__class__ != encoder.__class__:
        print(
            f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
        )

    def _tie(dec_node: nn.Module, enc_node: nn.Module, path: str, pending: List[str], skip: str, depth=0):
        assert isinstance(dec_node, nn.Module) and isinstance(
            enc_node, nn.Module
        ), f"{dec_node} and {enc_node} have to be of type torch.nn.Module"

        # A module that owns a weight (and is not skipped) is tied directly.
        tie_here = hasattr(dec_node, "weight") and skip not in path
        if tie_here:
            assert hasattr(enc_node, "weight")
            enc_node.weight = dec_node.weight
            if hasattr(dec_node, "bias"):
                assert hasattr(enc_node, "bias")
                enc_node.bias = dec_node.bias
            print(path+' is tied')
            return

        enc_children = enc_node._modules
        dec_children = dec_node._modules
        if not dec_children:
            return

        assert (
            len(enc_children) > 0
        ), f"Encoder module {enc_node} does not match decoder module {dec_node}"

        # Every encoder child starts out "unmatched" and is removed once visited.
        unmatched = {path + "/" + child_name for child_name in enc_children}
        offset = 0
        for name, dec_child in dec_children.items():
            if name.isdigit():
                enc_name = str(int(name) + offset)
                type_differs = not isinstance(dec_child, type(enc_children[enc_name]))
                if type_differs and len(enc_children) != len(dec_children):
                    # this can happen if the name corresponds to the position in a list module list of layers
                    # in this case the decoder has added a cross-attention that the encoder does not have
                    # thus skip this step and subtract one layer pos from encoder
                    offset -= 1
                    continue
            elif name not in enc_children:
                continue
            elif depth > 500:
                raise ValueError(
                    "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
                )
            else:
                enc_name = name

            _tie(
                dec_child,
                enc_children[enc_name],
                path + "/" + name,
                pending,
                skip,
                depth=depth + 1,
            )
            unmatched.remove(path + "/" + enc_name)

        pending.extend(unmatched)

    # Start the parallel walk at the two root modules.
    _tie(decoder, encoder, base_model_prefix, untied, skip_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repositories/Fooocus/extras/BLIP/models/blip_retrieval.py
DELETED
|
@@ -1,319 +0,0 @@
|
|
| 1 |
-
from extras.BLIP.models.med import BertConfig, BertModel
|
| 2 |
-
from transformers import BertTokenizer
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from extras.BLIP.models.blip import create_vit, init_tokenizer, load_checkpoint
|
| 9 |
-
|
| 10 |
-
class BLIP_Retrieval(nn.Module):
    """Image-text model whose forward pass returns a contrastive loss and a matching loss.

    The module holds a vision encoder and a text encoder with linear
    projection heads, parameter-copied "momentum" twins of all four, a
    2-way image-text matching head, and fixed-size feature queues
    (``image_queue``, ``text_queue``, ``idx_queue``) that supply extra
    columns for the contrastive similarity matrices.
    """

    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 queue_size = 57600,
                 momentum = 0.995,
                 negative_all_rank = False,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            vit_grad_ckpt: forwarded to create_vit for the main vision encoder
                (presumably toggles gradient checkpointing -- confirm in create_vit)
            vit_ckpt_layer: forwarded to create_vit for the main vision encoder
            embed_dim (int): output width of the projection heads and number of rows in the queues
            queue_size (int): number of columns (stored samples) in each queue
            momentum (float): weight kept from the momentum parameters on each _momentum_update
            negative_all_rank (bool): if True, matching negatives are drawn from
                features gathered across all ranks; otherwise from the local batch only
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        # The path argument is replaced by the parsed config object from here on.
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        # Project both modalities into the shared embed_dim space.
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        # Two-class head used for the matching loss in forward().
        self.itm_head = nn.Linear(text_width, 2)

        # create momentum encoders
        self.visual_encoder_m, vision_width = create_vit(vit,image_size)
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
        self.text_proj_m = nn.Linear(text_width, embed_dim)

        # [online, momentum] pairs walked by copy_params() and _momentum_update().
        self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
                            [self.vision_proj,self.vision_proj_m],
                            [self.text_encoder,self.text_encoder_m],
                            [self.text_proj,self.text_proj_m],
                           ]
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
        # Filled with -100 until real sample ids are enqueued
        # (presumably a value no real idx takes -- confirm against the dataset).
        self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
        self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))

        # Normalize each queue column (dim=0 runs over the embed_dim rows).
        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

        self.queue_size = queue_size
        self.momentum = momentum
        # Learnable temperature; clamped to [0.001, 0.5] at the start of forward().
        self.temp = nn.Parameter(0.07*torch.ones([]))

        self.negative_all_rank = negative_all_rank


    def forward(self, image, caption, alpha, idx):
        """Compute the contrastive (ITA) and matching (ITM) losses for one batch.

        Args:
            image: batch of images fed to the vision encoders.
            caption: text passed to the tokenizer (padded/truncated to 35 tokens).
            alpha: mixing weight between the momentum model's soft targets
                (weight ``alpha``) and the idx-based hard targets (weight ``1 - alpha``).
            idx: per-sample identifiers; samples sharing an idx are treated as
                positives for the contrastive targets and are never sampled as
                matching negatives.

        Returns:
            Tuple ``(loss_ita, loss_itm)``.
        """
        with torch.no_grad():
            self.temp.clamp_(0.001,0.5)

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
        # Only position 0 of the encoder output is projected
        # (presumably the [CLS] token -- confirm in create_vit).
        image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                        return_dict = True, mode = 'text')
        text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)

        ###============== Image-text Contrastive Learning ===================###
        idx = idx.view(-1,1)
        # Columns: current batch ids followed by the ids stored in the queue.
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        # Each row is divided by its number of positives, so it sums to 1.
        sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
            image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)

            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
            text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)

            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp

            # Blend the momentum model's distribution with the hard targets.
            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        # Online features are scored against momentum features plus the queues.
        sim_i2t = image_feat @ text_feat_m_all / self.temp
        sim_t2i = text_feat @ image_feat_m_all / self.temp

        # Cross-entropy against the soft targets, averaged over the batch.
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()

        loss_ita = (loss_i2t+loss_t2i)/2

        # The queues are updated only after this step's targets were computed from them.
        idxs = concat_all_gather(idx)
        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)

        ###============== Image-text Matching ===================###
        encoder_input_ids = text.input_ids.clone()
        # Replace the first token of every sequence with the tokenizer's enc_token_id.
        encoder_input_ids[:,0] = self.tokenizer.enc_token_id

        # forward the positive image-text pair
        bs = image.size(0)
        output_pos = self.text_encoder(encoder_input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                      )


        if self.negative_all_rank:
            # compute sample similarity
            with torch.no_grad():
                # True where a gathered sample shares the row sample's idx.
                mask = torch.eq(idx, idxs.t())

                image_feat_world = concat_all_gather(image_feat)
                text_feat_world = concat_all_gather(text_feat)

                sim_i2t = image_feat @ text_feat_world.t() / self.temp
                sim_t2i = text_feat @ image_feat_world.t() / self.temp

                # Sampling weights: more similar candidates are more likely;
                # same-idx candidates get weight 0.
                weights_i2t = F.softmax(sim_i2t,dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i,dim=1)
                weights_t2i.masked_fill_(mask, 0)

            image_embeds_world = all_gather_with_grad(image_embeds)

            # select a negative image (from all ranks) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds_world[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)

            # select a negative text (from all ranks) for each image
            input_ids_world = concat_all_gather(encoder_input_ids)
            att_mask_world = concat_all_gather(text.attention_mask)

            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(input_ids_world[neg_idx])
                text_atts_neg.append(att_mask_world[neg_idx])

        else:
            with torch.no_grad():
                # True where two samples of the local batch share an idx.
                mask = torch.eq(idx, idx.t())

                sim_i2t = image_feat @ text_feat.t() / self.temp
                sim_t2i = text_feat @ image_feat.t() / self.temp

                weights_i2t = F.softmax(sim_i2t,dim=1)
                weights_i2t.masked_fill_(mask, 0)

                weights_t2i = F.softmax(sim_t2i,dim=1)
                weights_t2i.masked_fill_(mask, 0)

            # select a negative image (from same rank) for each text
            image_embeds_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_t2i[b], 1).item()
                image_embeds_neg.append(image_embeds[neg_idx])
            image_embeds_neg = torch.stack(image_embeds_neg,dim=0)

            # select a negative text (from same rank) for each image
            text_ids_neg = []
            text_atts_neg = []
            for b in range(bs):
                neg_idx = torch.multinomial(weights_i2t[b], 1).item()
                text_ids_neg.append(encoder_input_ids[neg_idx])
                text_atts_neg.append(text.attention_mask[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg,dim=0)
        text_atts_neg = torch.stack(text_atts_neg,dim=0)

        # 2*bs mismatched pairs: (original text, negative image) followed by
        # (negative text, original image).
        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)

        image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
        image_atts_all = torch.cat([image_atts,image_atts],dim=0)

        output_neg = self.text_encoder(text_ids_all,
                                       attention_mask = text_atts_all,
                                       encoder_hidden_states = image_embeds_all,
                                       encoder_attention_mask = image_atts_all,
                                       return_dict = True,
                                      )


        vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
        vl_output = self.itm_head(vl_embeddings)

        # Label 1 for the bs matched pairs, 0 for the 2*bs mismatched pairs.
        itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        return loss_ita, loss_itm


    @torch.no_grad()
    def copy_params(self):
        """Copy every online parameter into its momentum twin and freeze the twin."""
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient


    @torch.no_grad()
    def _momentum_update(self):
        """Move each momentum parameter toward its online twin:
        ``param_m = param_m * momentum + param * (1 - momentum)``.
        """
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)


    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
        """Overwrite the oldest queue slots with this step's features and ids.

        Args:
            image_feat: local image features; gathered across ranks here.
            text_feat: local text features; gathered across ranks here.
            idxs: sample ids written as given (forward() passes ids it has
                already gathered).
        """
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)


        batch_size = image_feats.shape[0]

        ptr = int(self.ptr_queue)
        # The queue must hold a whole number of gathered batches so a write never wraps.
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.ptr_queue[0] = ptr
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
def blip_retrieval(pretrained='',**kwargs):
    """Build a BLIP_Retrieval model and optionally load a checkpoint into it.

    Args:
        pretrained: checkpoint location handed to load_checkpoint; when falsy
            (the default) the freshly constructed model is returned as is.
        **kwargs: forwarded to the BLIP_Retrieval constructor.

    Returns:
        The constructed (and, if requested, checkpoint-loaded) model.
    """
    retrieval_model = BLIP_Retrieval(**kwargs)
    if not pretrained:
        return retrieval_model

    retrieval_model, load_msg = load_checkpoint(retrieval_model, pretrained)
    # Report which keys the checkpoint did not provide.
    print("missing keys:")
    print(load_msg.missing_keys)
    return retrieval_model
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every rank and concatenate the results along dim 0.

    When torch.distributed is unavailable, or no default process group has
    been initialized (a plain single-process run), there is nothing to gather
    from: a detached copy of `tensor` is returned instead of raising, which is
    the same result a one-rank gather would give.

    *** Warning ***: torch.distributed.all_gather has no gradient.

    Args:
        tensor: the local tensor to share with the other ranks.

    Returns:
        A new tensor holding every rank's `tensor` stacked along dim 0,
        detached from the autograd graph.
    """
    dist = torch.distributed
    # is_available() is checked first: builds without distributed support
    # do not expose is_initialized().
    if not (dist.is_available() and dist.is_initialized()):
        return tensor.detach().clone()

    # One receive buffer per rank, each shaped like the local tensor.
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.

    When torch.distributed is unavailable or no default process group has been
    initialized, the layer behaves like a one-rank gather: forward returns a
    1-tuple holding a copy of the input and backward passes the incoming
    gradient straight through.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tuple with one tensor per rank, each shaped like `x`."""
        dist = torch.distributed
        # is_available() first: builds without distributed support do not
        # expose is_initialized().
        if not (dist.is_available() and dist.is_initialized()):
            return (x.clone(),)

        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        """Sum the per-rank gradients across workers and return this rank's slice."""
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            # Single process: the only output gradient belongs to the input.
            return grads[0]

        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.

    Args:
        tensors: the local tensor to gather from every rank.

    Returns:
        `tensors` itself when there is only one process (torch.distributed
        unavailable, no process group initialized, or world size 1);
        otherwise every rank's tensor concatenated along dim 0.
    """
    dist = torch.distributed
    # Without an initialized process group get_world_size() would raise, so
    # treat that situation like the single-process case below.
    if not (dist.is_available() and dist.is_initialized()):
        return tensors

    # There is no need for reduction in the single-proc case
    if dist.get_world_size() == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|