File size: 1,575 Bytes
128e289 002d875 128e289 ad55b48 002d875 128e289 ad55b48 128e289 ad55b48 128e289 ad55b48 128e289 ad55b48 128e289 ad55b48 128e289 ad55b48 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | from typing import Dict
import torch
from diffusers import DiffusionPipeline
from compel import Compel
from io import BytesIO
import base64
class EndpointHandler:
    """Inference-endpoint handler serving FLUX.1-dev with a local LoRA applied.

    The instance is constructed once with ``path`` and then called with each
    request payload; every call returns one generated image as base64 PNG.
    """

    # LoRA weight file expected inside the directory given by ``path``.
    LORA_WEIGHT_NAME = "c1t3_v1.safetensors"

    def __init__(self, path: str = ""):
        """Load the FLUX pipeline, attach the LoRA and choose device placement.

        Args:
            path: Directory containing ``LORA_WEIGHT_NAME``. The default ``""``
                resolves to the current working directory, i.e. the same
                ``./c1t3_v1.safetensors`` location used previously.
        """
        # FLUX.1-dev is distributed in bfloat16 and has no "fp16" variant;
        # float16 is also prone to overflow in its T5 text encoder.
        self.pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        )
        # Resolve the LoRA inside ``path`` rather than wherever the process
        # happens to have been started.
        self.pipe.load_lora_weights(path or ".", weight_name=self.LORA_WEIGHT_NAME)

        if torch.cuda.is_available():
            # Model CPU offload handles placement itself (sub-models are moved
            # to the GPU only while they run), so no prior .to("cuda").
            self.pipe.enable_model_cpu_offload()
        else:
            # Offload hooks require an accelerator; keep everything on CPU.
            self.pipe.to("cpu")
        # Compel is intentionally not used: it only wraps one CLIP
        # tokenizer/encoder, whereas FluxPipeline needs T5 ``prompt_embeds``
        # plus CLIP ``pooled_prompt_embeds``. The pipeline encodes the raw
        # prompt with both encoders on its own.

    def __call__(self, data: Dict[str, object]) -> Dict[str, str]:
        """Generate one image for a request and return it base64-encoded.

        Args:
            data: Request payload. The prompt is taken from ``"prompt"`` or,
                if that is missing/empty, from ``"inputs"``. An optional
                ``"parameters"`` dict may set ``num_inference_steps``,
                ``height``, ``width``, ``guidance_scale`` and ``seed``.

        Returns:
            ``{"image": <base64-encoded PNG>}`` on success, otherwise
            ``{"error": <message>}`` describing the invalid request.
        """
        prompt = data.get("prompt") or data.get("inputs")
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            return {"error": "No prompt provided."}
        if not isinstance(prompt, str):
            return {"error": "Prompt must be a string."}

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return {"error": "'parameters' must be an object."}
        try:
            pipe_kwargs = self._generation_kwargs(params)
        except (TypeError, ValueError) as err:
            return {"error": f"Invalid parameters: {err}"}

        # Hand the plain prompt to FluxPipeline (FLUX + LoRA).
        image = self.pipe(prompt=prompt, **pipe_kwargs).images[0]

        # Encode the image as a base64 PNG string for the JSON response.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        return {"image": base64.b64encode(buffer.getvalue()).decode("utf-8")}

    @staticmethod
    def _generation_kwargs(params: Dict[str, object]) -> Dict[str, object]:
        """Convert whitelisted request parameters into pipeline keyword args.

        Raises:
            TypeError, ValueError: if a value cannot be converted.
        """
        kwargs: Dict[str, object] = {}
        for key, cast in (
            ("num_inference_steps", int),
            ("height", int),
            ("width", int),
            ("guidance_scale", float),
        ):
            if params.get(key) is not None:
                kwargs[key] = cast(params[key])
        seed = params.get("seed")
        if seed is not None:
            # A CPU generator works regardless of where offload runs modules.
            kwargs["generator"] = torch.Generator("cpu").manual_seed(int(seed))
        return kwargs
|