# Page header (scraped): yanis9351 -- "Update handler.py" -- 40b5f55 (verified)
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
import os
class EndpointHandler:
    """Request handler that upscales an image fetched from a URL.

    The handler loads the ``stabilityai/stable-diffusion-x4-upscaler``
    checkpoint once at construction time and is then called once per
    request with the request payload as a dict.
    """

    # Checkpoint pulled from the Hub; the ``model_dir`` handed to
    # ``__init__`` is only logged, never loaded from.
    MODEL_ID = "stabilityai/stable-diffusion-x4-upscaler"
    # Prompt used when the request does not carry one.
    DEFAULT_PROMPT = "a generic image"
    # Every input is squashed to this size before upscaling (aspect ratio
    # is NOT preserved).
    INPUT_SIZE = (128, 128)
    # NOTE(review): a single fixed path means concurrent requests overwrite
    # each other's result, and the path is only meaningful on the machine
    # running the handler. Kept as-is to preserve the response contract.
    OUTPUT_PATH = "/tmp/upscaled_image.png"
    # Seconds to wait for the remote host; requests applies no timeout by
    # default, so without this a stalled server would block the worker.
    DOWNLOAD_TIMEOUT = 30

    def __init__(self, model_dir: str):
        """Load the upscaling pipeline onto the best available device.

        Args:
            model_dir: Directory supplied by the serving runtime. It is
                printed for diagnostics only; weights come from ``MODEL_ID``.
        """
        print("Initializing model from directory:", model_dir)
        # Half precision is only used on CUDA; on a CPU-only host fall back
        # to float32 instead of crashing on ``.to("cuda")``.
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            device, dtype = "cpu", torch.float32
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            self.MODEL_ID, torch_dtype=dtype
        )
        self.pipeline = self.pipeline.to(device)
        print("Model loaded successfully.")

    def __call__(self, data):
        """Upscale the image referenced by the request and save it to disk.

        Args:
            data: Request payload. ``image_url`` (required) and ``prompt``
                (optional) are read from the top level; if a key is absent
                there and ``data["inputs"]`` is a dict, it is read from
                that nested dict instead.

        Returns:
            ``{"output_image_path": <path of the saved PNG>}``.

        Raises:
            ValueError: If no image URL is supplied, or the image cannot be
                downloaded/decoded.
        """
        # Accept the conventional {"inputs": {...}} envelope as well as the
        # flat layout; top-level keys win so existing callers are unaffected.
        nested = data.get("inputs")
        payload = {**nested, **data} if isinstance(nested, dict) else data

        prompt = payload.get("prompt")
        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        image_url = payload.get("image_url")
        if not image_url:
            raise ValueError("No image URL provided in the input data.")

        low_res_img = self.download_image(image_url)
        if low_res_img is None:
            raise ValueError("Failed to load image from provided URL.")

        # Always resize to the fixed input size before running the pipeline.
        low_res_img = low_res_img.resize(self.INPUT_SIZE)
        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        upscaled_image.save(self.OUTPUT_PATH)
        return {"output_image_path": self.OUTPUT_PATH}

    def download_image(self, url: str):
        """Fetch ``url`` and decode it as an RGB image.

        Args:
            url: Location of the image to download. The value comes straight
                from the request, so the handler will contact whatever host
                the caller names -- NOTE(review): consider an allow-list if
                the endpoint is reachable by untrusted clients.

        Returns:
            The decoded image converted to RGB, or ``None`` if the download
            or decoding failed (the error is printed).
        """
        try:
            response = requests.get(url, timeout=self.DOWNLOAD_TIMEOUT)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Deliberately broad: any network or decoding failure is reported
            # to the caller as "no image", which __call__ turns into a
            # ValueError.
            print(f"Error downloading image: {e}")
            return None