Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,21 +29,27 @@ os.environ['TMPDIR'] = path_dict['gradio_temp_dir']
|
|
| 29 |
|
| 30 |
|
| 31 |
# Function of loading pre-trained models.
|
| 32 |
-
|
| 33 |
-
base_model_path = snapshot_download("black-forest-labs/FLUX.1-Fill-dev")
|
| 34 |
-
image_encoder_path = snapshot_download("google/siglip-so400m-patch14-384")
|
| 35 |
-
calligrapher_path = snapshot_download("Calligrapher2025/Calligrapher")
|
| 36 |
-
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer",
|
| 37 |
-
|
| 38 |
-
pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer,
|
| 39 |
-
|
| 40 |
-
model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128)
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
@spaces.GPU
|
| 44 |
def process_and_generate(editor_component, reference_image, prompt, height, width,
|
| 45 |
scale, steps=50, seed=42, use_context=True, num_images=1):
|
| 46 |
print('Begin processing!')
|
|
|
|
| 47 |
# Job directory.
|
| 48 |
job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
| 49 |
job_dir = os.path.join(SAVE_DIR, job_name)
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# Function of loading pre-trained models.
|
| 32 |
+
def load_models():
|
| 33 |
+
base_model_path = snapshot_download("black-forest-labs/FLUX.1-Fill-dev")
|
| 34 |
+
image_encoder_path = snapshot_download("google/siglip-so400m-patch14-384")
|
| 35 |
+
calligrapher_path = snapshot_download("Calligrapher2025/Calligrapher")
|
| 36 |
+
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer",
|
| 37 |
+
torch_dtype=torch.bfloat16)
|
| 38 |
+
pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer,
|
| 39 |
+
torch_dtype=torch.bfloat16).to("cuda")
|
| 40 |
+
model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128)
|
| 41 |
+
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
@gr.CachedFunction()
|
| 45 |
+
def load_model_cached():
|
| 46 |
+
return load_models()
|
| 47 |
|
| 48 |
@spaces.GPU
|
| 49 |
def process_and_generate(editor_component, reference_image, prompt, height, width,
|
| 50 |
scale, steps=50, seed=42, use_context=True, num_images=1):
|
| 51 |
print('Begin processing!')
|
| 52 |
+
model = load_model_cached()
|
| 53 |
# Job directory.
|
| 54 |
job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
| 55 |
job_dir = os.path.join(SAVE_DIR, job_name)
|