File size: 1,871 Bytes
a370950
 
 
d96c5fb
 
a370950
d96c5fb
40b5f55
 
 
 
 
 
 
 
 
d96c5fb
40b5f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a370950
 
40b5f55
 
a370950
40b5f55
 
 
 
 
d96c5fb
40b5f55
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import tempfile
from io import BytesIO

import requests
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

class EndpointHandler:
    """Inference handler that upscales an image fetched from a URL.

    A Stable Diffusion upscaling pipeline is loaded once at construction.
    Each call downloads the image named by ``data["image_url"]``, resizes it
    to ``INPUT_SIZE``, runs the pipeline, and saves the result as a PNG in a
    uniquely named file in the system temp directory.
    """

    # Every input image is resized to exactly this (width, height) before
    # upscaling, regardless of its original size or aspect ratio.
    # NOTE(review): presumably chosen to bound GPU memory; confirm.
    INPUT_SIZE = (128, 128)

    def __init__(
        self,
        model_dir: str,
        model_id: str = "stabilityai/stable-diffusion-x4-upscaler",
    ):
        """Load the upscaling pipeline in float16 and move it to CUDA.

        Args:
            model_dir: Directory supplied by the serving runtime. It is only
                printed; weights are loaded from ``model_id``.
            model_id: Hub id or local path passed to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
        """
        print("Initializing model from directory:", model_dir)

        # NOTE(review): model_dir is not used to load weights; the pipeline
        # comes from model_id. Confirm this is intended for the deployment.
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.pipeline = self.pipeline.to("cuda")
        print("Model loaded successfully.")

    def __call__(self, data: dict) -> dict:
        """Upscale the image referenced by ``data`` and save it as a PNG.

        Args:
            data: Request payload. ``"image_url"`` (required) is the URL of
                the image to upscale; ``"prompt"`` (optional, default
                ``"a generic image"``) is forwarded to the pipeline.

        Returns:
            ``{"output_image_path": path}``, where ``path`` is a newly created
            PNG file in the system temp directory. This handler never deletes
            that file.

        Raises:
            ValueError: If ``image_url`` is missing or empty, or if the image
                cannot be downloaded and decoded.
        """
        prompt = data.get("prompt", "a generic image")
        image_url = data.get("image_url")

        # An empty string is as useless as a missing URL; reject both here
        # instead of letting "" surface later as a download failure.
        if not image_url:
            raise ValueError("No image URL provided in the input data.")

        low_res_img = self.download_image(image_url)
        if low_res_img is None:
            raise ValueError("Failed to load image from provided URL.")

        # Unconditional resize. It does not preserve aspect ratio and will
        # enlarge images smaller than INPUT_SIZE.
        low_res_img = low_res_img.resize(self.INPUT_SIZE)

        upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]

        # One unique file per request. A fixed path would let concurrent
        # requests overwrite each other's result before it is read.
        with tempfile.NamedTemporaryFile(
            prefix="upscaled_image_", suffix=".png", delete=False
        ) as out_file:
            upscaled_image.save(out_file, format="PNG")
            output_path = out_file.name

        return {"output_image_path": output_path}

    def download_image(self, url: str, timeout: float = 30.0):
        """Fetch ``url`` and decode the response body as an RGB PIL image.

        This is best-effort. Any failure is printed and returned as ``None``
        instead of being raised, so the caller can turn it into its own
        error. Failures include network errors, HTTP error statuses,
        timeouts and undecodable content.

        Args:
            url: Address of the image to download.
            timeout: Seconds ``requests`` waits for the connection and for
                each read. Without a timeout, a stalled server could block
                the request indefinitely.

        Returns:
            The decoded image converted to RGB, or ``None`` on any failure.
        """
        # NOTE(review): url comes straight from the request payload, so this
        # fetches any address the caller names, including internal hosts.
        # Consider an allow-list if the endpoint is exposed to untrusted users.
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"Error downloading image: {e}")
            return None