Texttra committed on
Commit
17119de
·
verified ·
1 Parent(s): 94b3b96

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -4
handler.py CHANGED
@@ -9,13 +9,23 @@ class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
  print("🚀 Initializing Flux Kontext pipeline...")
11
 
12
- # Load Flux Kontext model
13
  self.pipe = FluxKontextPipeline.from_pretrained(
14
  "black-forest-labs/FLUX.1-Kontext-dev",
15
  torch_dtype=torch.float16,
16
  )
 
 
 
 
 
 
 
 
 
 
17
  self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
18
- print("✅ Model ready.")
19
 
20
  def __call__(self, data: Dict) -> Dict:
21
  print("🔧 Received raw data type:", type(data))
@@ -23,11 +33,11 @@ class EndpointHandler:
23
 
24
  # Defensive parsing
25
  if isinstance(data, dict):
26
- # Some endpoints send data directly as prompt/image dict
27
  prompt = data.get("prompt")
28
  image_input = data.get("image")
29
 
30
- # If 'inputs' key is used (as per HF Inference default schema)
31
  if prompt is None and image_input is None:
32
  inputs = data.get("inputs")
33
  if isinstance(inputs, dict):
 
9
  def __init__(self, path: str = ""):
10
  print("🚀 Initializing Flux Kontext pipeline...")
11
 
12
+ # Load base model from Hugging Face
13
  self.pipe = FluxKontextPipeline.from_pretrained(
14
  "black-forest-labs/FLUX.1-Kontext-dev",
15
  torch_dtype=torch.float16,
16
  )
17
+
18
+ # Load your local LoRA file
19
+ try:
20
+ lora_path = "./Bh0r1.safetensors" # relative path within the container
21
+ self.pipe.load_lora_weights(lora_path)
22
+ print(f"✅ LoRA weights loaded from {lora_path}.")
23
+ except Exception as e:
24
+ print(f"⚠️ Failed to load LoRA weights: {str(e)}")
25
+
26
+ # Move pipeline to GPU if available
27
  self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
28
+ print("✅ Model ready with LoRA applied.")
29
 
30
  def __call__(self, data: Dict) -> Dict:
31
  print("🔧 Received raw data type:", type(data))
 
33
 
34
  # Defensive parsing
35
  if isinstance(data, dict):
36
+ # Direct prompt/image dict
37
  prompt = data.get("prompt")
38
  image_input = data.get("image")
39
 
40
+ # If 'inputs' key is used (HF Inference schema)
41
  if prompt is None and image_input is None:
42
  inputs = data.get("inputs")
43
  if isinstance(inputs, dict):