|
|
import os
import tempfile
import uuid
from io import BytesIO

import requests
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image
|
|
|
|
|
class EndpointHandler:
    """Inference handler that upscales an image fetched from a URL.

    A ``StableDiffusionUpscalePipeline`` is loaded once at construction time.
    Each call then:

    1. downloads the requested image;
    2. resizes it to ``low_res_size``;
    3. runs the pipeline with the supplied text prompt;
    4. writes the result to a PNG file on local disk.
    """

    # Hub id (or local path) loaded when no ``model_id`` is passed to __init__.
    DEFAULT_MODEL_ID = "stabilityai/stable-diffusion-x4-upscaler"

    # (width, height) the input is resized to before upscaling. This ignores
    # the source aspect ratio, so non-square inputs are distorted.
    low_res_size = (128, 128)

    # Seconds passed as ``timeout=`` to ``requests.get``. Without a timeout a
    # slow or unresponsive server can block the handler indefinitely.
    download_timeout = 30

    def __init__(self, model_dir: str, model_id: str = DEFAULT_MODEL_ID):
        """Load the upscaling pipeline onto the available device.

        Args:
            model_dir: Directory supplied by the serving runtime. It is only
                logged here; the weights are loaded from ``model_id``.
            model_id: Identifier passed to ``from_pretrained``. It defaults to
                the x4 upscaler id that was previously hard-coded.
        """
        print("Initializing model from directory:", model_dir)

        # Keep the original fp16-on-CUDA configuration when a GPU is present.
        # Otherwise fall back to fp32 on CPU instead of crashing in
        # ``.to("cuda")``; half precision is generally unsuitable for CPU
        # inference.
        if torch.cuda.is_available():
            device, dtype = "cuda", torch.float16
        else:
            device, dtype = "cpu", torch.float32

        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, torch_dtype=dtype
        )
        self.pipeline = self.pipeline.to(device)

        print("Model loaded successfully.")

    def __call__(self, data):
        """Upscale the image referenced by ``data["image_url"]``.

        Args:
            data: Mapping with keys ``"image_url"`` (required, must be
                non-empty) and ``"prompt"`` (optional, defaults to
                ``"a generic image"``).

        Returns:
            ``{"output_image_path": path}``, where ``path`` is a PNG file on the
            local filesystem holding the pipeline's first output image.

        Raises:
            ValueError: If no image URL is given or the image cannot be loaded.
        """
        prompt = data.get("prompt", "a generic image")
        image_url = data.get("image_url")

        # Treat an empty string the same as a missing key rather than
        # issuing a request that cannot succeed.
        if not image_url:
            raise ValueError("No image URL provided in the input data.")

        low_res_img = self.download_image(image_url)
        if low_res_img is None:
            raise ValueError("Failed to load image from provided URL.")

        low_res_img = low_res_img.resize(self.low_res_size)

        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        # Use a fresh file name per request so concurrent calls cannot
        # overwrite each other's result (the original used one fixed path).
        # Files are not deleted here; cleanup is left to the runtime.
        output_path = self._new_output_path()
        upscaled_image.save(output_path)

        return {"output_image_path": output_path}

    def download_image(self, url: str):
        """Fetch ``url`` and decode the response body as an RGB PIL image.

        Returns:
            The decoded image. Returns ``None`` if the request fails, the
            server responds with an HTTP error status, or the body cannot be
            opened as an image. The error is printed rather than raised.
        """
        try:
            response = requests.get(url, timeout=self.download_timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Deliberate best-effort: callers check for None and raise their
            # own ValueError.
            print(f"Error downloading image: {e}")
            return None

    @staticmethod
    def _new_output_path() -> str:
        """Return a PNG path with a random uuid4-based name in the temp dir."""
        return os.path.join(
            tempfile.gettempdir(), f"upscaled_image_{uuid.uuid4().hex}.png"
        )