# Cityscape_Studio / handler.py
# Author: Texttra — "Update handler.py" (commit bf7ff83, verified), ~2 kB
from typing import Dict
import torch
from diffusers import DiffusionPipeline
from compel import Compel
from io import BytesIO
import base64
class EndpointHandler:
    """Inference-endpoint wrapper around FLUX.1-dev with a Cityscape LoRA.

    Loads the base pipeline and LoRA weights once at startup, then serves
    text-to-image requests, returning each image as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the diffusion pipeline, LoRA weights, and prompt helper.

        Args:
            path: Model path handed in by the serving runtime. Unused here —
                the base model is always pulled from the Hub.
        """
        print(f"Initializing model from: {path}")
        self.pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.float16,
            # Gated repo: requires an HF access token in the environment.
            use_auth_token=True,
        )
        print("Loading LoRA weights from: Texttra/Cityscape_Studio")
        self.pipe.load_lora_weights(
            "Texttra/Cityscape_Studio", weight_name="c1t3_v1.safetensors"
        )
        if torch.cuda.is_available():
            self.pipe.to("cuda")
        else:
            # BUGFIX: the original also called enable_model_cpu_offload()
            # here. Model offload requires an accelerator device and raises
            # on a CUDA-less host — and it is contradictory immediately
            # after .to("cpu"). Plain CPU placement is the correct fallback.
            self.pipe.to("cpu")
        # NOTE(review): FLUX pipelines expose two text encoders
        # (text_encoder / text_encoder_2) and their __call__ expects
        # pooled_prompt_embeds in addition to prompt_embeds. Compel wired to
        # only the first encoder may not satisfy this pipeline — confirm
        # against the diffusers FluxPipeline API before relying on it.
        self.compel = Compel(
            tokenizer=self.pipe.tokenizer,
            text_encoder=self.pipe.text_encoder,
        )
        print("Model initialized successfully.")

    def __call__(self, data: Dict) -> Dict:
        """Handle a single inference request.

        Args:
            data: Request payload. ``data["inputs"]`` is either the prompt
                string itself or a dict carrying a ``"prompt"`` key.

        Returns:
            ``{"image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` when no prompt is given or when
            generation fails.
        """
        print("Received data:", data)
        try:
            inputs = data.get("inputs", {})
            if isinstance(inputs, str):
                # Raw string payloads (e.g., Postman tests) ARE the prompt.
                prompt = inputs
            else:
                prompt = inputs.get("prompt", "")
            print("Extracted prompt:", prompt)
            if not prompt:
                return {"error": "No prompt provided"}
            conditioning = self.compel(prompt)
            print("Conditioning complete.")
            image = self.pipe(prompt_embeds=conditioning).images[0]
            print("Image generated.")
            buffer = BytesIO()
            image.save(buffer, format="PNG")
            base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
            print("Returning image.")
            return {"image": base64_image}
        except Exception as e:
            # Endpoint boundary: report the failure to the client rather
            # than letting the serving worker crash.
            print(f"Error occurred: {str(e)}")
            return {"error": str(e)}