SahilCarterr commited on
Commit
62af58d
·
verified ·
1 Parent(s): ebb93d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -29,21 +29,27 @@ os.environ['TMPDIR'] = path_dict['gradio_temp_dir']
29
 
30
 
31
  # Function of loading pre-trained models.
32
-
33
- base_model_path = snapshot_download("black-forest-labs/FLUX.1-Fill-dev")
34
- image_encoder_path = snapshot_download("google/siglip-so400m-patch14-384")
35
- calligrapher_path = snapshot_download("Calligrapher2025/Calligrapher")
36
- transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer",
37
- torch_dtype=torch.bfloat16)
38
- pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer,
39
- torch_dtype=torch.bfloat16).to("cuda")
40
- model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128)
41
-
 
 
 
 
 
42
 
43
  @spaces.GPU
44
  def process_and_generate(editor_component, reference_image, prompt, height, width,
45
  scale, steps=50, seed=42, use_context=True, num_images=1):
46
  print('Begin processing!')
 
47
  # Job directory.
48
  job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
49
  job_dir = os.path.join(SAVE_DIR, job_name)
 
29
 
30
 
31
  # Function of loading pre-trained models.
32
+ def load_models():
33
+ base_model_path = snapshot_download("black-forest-labs/FLUX.1-Fill-dev")
34
+ image_encoder_path = snapshot_download("google/siglip-so400m-patch14-384")
35
+ calligrapher_path = snapshot_download("Calligrapher2025/Calligrapher")
36
+ transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer",
37
+ torch_dtype=torch.bfloat16)
38
+ pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer,
39
+ torch_dtype=torch.bfloat16).to("cuda")
40
+ model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128)
41
+
42
+ return model
43
+
44
+ @gr.CachedFunction()
45
+ def load_model_cached():
46
+ return load_models()
47
 
48
  @spaces.GPU
49
  def process_and_generate(editor_component, reference_image, prompt, height, width,
50
  scale, steps=50, seed=42, use_context=True, num_images=1):
51
  print('Begin processing!')
52
+ model = load_model_cached()
53
  # Job directory.
54
  job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
55
  job_dir = os.path.join(SAVE_DIR, job_name)