# Cityscape_Studio / handler.py
# Author: Texttra
# Commit: ad55b48 "Update handler.py" (verified)
# Size: 1.58 kB
from typing import Dict
import torch
from diffusers import DiffusionPipeline
from compel import Compel
from io import BytesIO
import base64
class EndpointHandler:
    """Inference handler serving FLUX.1-dev with the c1t3 LoRA attached.

    ``__init__`` loads the base pipeline and the LoRA once. Each ``__call__``
    turns a text prompt into a single PNG image returned as a base64 string.
    """

    # Base model pulled from the Hugging Face Hub.
    MODEL_ID = "black-forest-labs/FLUX.1-dev"
    # LoRA weights file expected in the directory given by ``path``.
    LORA_WEIGHT_NAME = "c1t3_v1.safetensors"

    def __init__(self, path: str = ""):
        """Load the FLUX pipeline, attach the LoRA and choose a device.

        Args:
            path: Local directory holding this repository's files, where the
                LoRA weights live. An empty string means the current working
                directory, matching the previous ``./c1t3_v1.safetensors``
                lookup.
        """
        # FLUX.1-dev ships bfloat16 weights and no "fp16" variant; the
        # diffusers FLUX docs load it with torch.bfloat16. float16 is
        # reported to produce NaN/black images with FLUX.
        self.pipe = DiffusionPipeline.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.bfloat16,
        )

        # Resolve the LoRA relative to the repository directory rather than
        # whatever the process CWD happens to be.
        self.pipe.load_lora_weights(path or ".", weight_name=self.LORA_WEIGHT_NAME)

        if torch.cuda.is_available():
            # Model CPU offload keeps sub-models on the CPU and moves each one
            # to the GPU only while it runs. Do not call .to("cuda") first:
            # that would place the entire model on the GPU up front.
            self.pipe.enable_model_cpu_offload()
        else:
            # Offloading needs an accelerator; on a CPU-only host keep the
            # whole pipeline on the CPU.
            self.pipe.to("cpu")

    def __call__(self, data: Dict[str, str]) -> Dict:
        """Generate one image for the prompt carried in *data*.

        Args:
            data: Request payload. The prompt is read from ``"prompt"``,
                falling back to ``"inputs"`` (the key the Hugging Face
                inference toolkit uses for the request body).

        Returns:
            ``{"image": <base64-encoded PNG>}`` on success, or
            ``{"error": "No prompt provided."}`` when the prompt is missing,
            not a string, or blank.
        """
        payload = data or {}
        prompt = payload.get("prompt") or payload.get("inputs")
        if not isinstance(prompt, str) or not prompt.strip():
            return {"error": "No prompt provided."}

        # FLUX conditions on two encoders (CLIP pooled + T5 sequence). Compel
        # only drives one CLIP encoder and cannot supply the
        # pooled_prompt_embeds FluxPipeline requires alongside prompt_embeds,
        # so let the pipeline encode the raw prompt itself.
        image = self.pipe(prompt=prompt).images[0]

        # Serialize to PNG in memory and base64-encode for a JSON response.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return {"image": base64.b64encode(buffer.getvalue()).decode("utf-8")}