Cityscape_Studio / handler.py
Texttra's picture
Update handler.py
128e289 verified
raw
history blame
1.36 kB
from typing import Dict
import torch
from diffusers import DiffusionPipeline
from compel import Compel
class EndpointHandler:
    """Text-to-image request handler: FLUX.1-dev with the c1t3 LoRA applied.

    The pipeline is built once in ``__init__``; calling the instance with a
    request dict generates one image from the request's prompt.
    """

    # File name of the LoRA weights shipped alongside this handler.
    LORA_WEIGHT_NAME = "c1t3_v1.safetensors"
    # Generation options a request may override via its "parameters" dict.
    # Anything else is ignored so callers cannot inject arbitrary keyword
    # arguments into the pipeline call.
    ALLOWED_PARAMETERS = frozenset(
        {"height", "width", "num_inference_steps", "guidance_scale", "max_sequence_length"}
    )

    def __init__(self, path: str = ""):
        """Load the FLUX pipeline, apply the LoRA and choose device placement.

        Args:
            path: Directory holding this repository's files (including the
                LoRA weights). The default ``""`` resolves relative to the
                current working directory.
        """
        import os

        use_cuda = torch.cuda.is_available()
        # FLUX.1-dev ships no "fp16" variant files, so ``variant="fp16"`` makes
        # loading fail. Its weights are published in bfloat16, and fp16 can
        # overflow in the text encoders. On CPU, float32 is the safe choice.
        dtype = torch.bfloat16 if use_cuda else torch.float32
        self.pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=dtype,
        )

        # Resolve the LoRA file against the repository directory instead of
        # the process CWD; with path="" this is the same relative path as a
        # plain file name.
        self.pipe.load_lora_weights(os.path.join(path, self.LORA_WEIGHT_NAME))

        if use_cuda:
            # Model CPU offload moves each sub-model onto the GPU only while it
            # runs and handles placement itself. Calling ``.to("cuda")`` as
            # well would defeat it. Offload needs an accelerator, so it is only
            # enabled when CUDA is present.
            self.pipe.enable_model_cpu_offload()
        else:
            self.pipe.to("cpu")

        # NOTE: prompts are encoded by the pipeline itself, which uses both of
        # FLUX's text encoders (CLIP for the pooled embedding, T5 for the
        # sequence embedding). A Compel instance built only from the first
        # tokenizer/encoder produces CLIP-only ``prompt_embeds`` without
        # ``pooled_prompt_embeds``, which FluxPipeline rejects.

    @staticmethod
    def _extract_prompt(data) -> str:
        """Return the stripped prompt from a request, or "" if there is none.

        Reads ``data["prompt"]`` and falls back to ``data["inputs"]``, the key
        used by the standard Inference Endpoints payload. Non-dict requests
        and non-string prompts yield "".
        """
        if not isinstance(data, dict):
            return ""
        prompt = data.get("prompt") or data.get("inputs")
        if not isinstance(prompt, str):
            return ""
        return prompt.strip()

    @classmethod
    def _generation_kwargs(cls, data) -> Dict[str, object]:
        """Build extra pipeline keyword arguments from ``data["parameters"]``.

        Only keys in ``ALLOWED_PARAMETERS`` are forwarded. An integer
        ``seed`` becomes a seeded CPU ``torch.Generator`` for reproducible
        output; any other seed value is ignored.
        """
        if not isinstance(data, dict):
            return {}
        params = data.get("parameters")
        if not isinstance(params, dict):
            return {}
        kwargs: Dict[str, object] = {
            key: value for key, value in params.items() if key in cls.ALLOWED_PARAMETERS
        }
        seed = params.get("seed")
        if isinstance(seed, int) and not isinstance(seed, bool):
            kwargs["generator"] = torch.Generator("cpu").manual_seed(seed)
        return kwargs

    def __call__(self, data: Dict[str, object]) -> Dict:
        """Generate one image for the request's prompt.

        Args:
            data: Request payload containing ``"prompt"`` (or ``"inputs"``)
                and optionally a ``"parameters"`` dict.

        Returns:
            ``{"image": image}`` holding the first image of the pipeline output,
            or ``{"error": "No prompt provided."}`` when the prompt is missing
            or blank.
        """
        prompt = self._extract_prompt(data)
        if not prompt:
            return {"error": "No prompt provided."}

        # Let the pipeline encode the prompt with both of its text encoders.
        output = self.pipe(prompt=prompt, **self._generation_kwargs(data))

        # NOTE(review): with the default output_type this is a PIL image, which
        # may not be JSON-serializable by the serving layer when wrapped in a
        # dict. Confirm how the endpoint serializes responses.
        return {"image": output.images[0]}