File size: 4,034 Bytes
f647f0a
 
427d785
f647f0a
b99cb5c
698cda9
427d785
 
 
 
 
 
 
 
b99cb5c
698cda9
427d785
 
 
 
6cbfe6a
 
427d785
 
 
 
 
 
6cbfe6a
427d785
6cbfe6a
427d785
6cbfe6a
427d785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cbfe6a
427d785
 
 
 
6cbfe6a
 
427d785
6cbfe6a
427d785
6cbfe6a
427d785
6cbfe6a
427d785
6cbfe6a
 
427d785
6cbfe6a
 
427d785
6cbfe6a
427d785
 
 
6cbfe6a
 
427d785
6cbfe6a
 
 
 
427d785
6cbfe6a
 
 
427d785
 
6cbfe6a
f647f0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from transformers import pipeline
from PIL import Image
import io, base64, requests, sys, traceback

class EndpointHandler:
    """Inference handler that runs a Swin2SR image-to-image pipeline.

    The handler is called with a request payload, extracts an image from its
    ``"inputs"`` entry, runs the pipeline on it and returns the result as a
    base64-encoded PNG.

    Accepted ``inputs`` values:

    * an ``http://`` / ``https://`` URL pointing at an image,
    * a base64 string (optionally wrapped in a ``data:...;base64,`` URI),
    * raw ``bytes`` / ``bytearray`` of an encoded image,
    * an already-decoded ``PIL.Image.Image``.

    Every failure is reported as ``{"error": "<message>"}`` instead of being
    raised, so the caller always receives a dict.

    NOTE(review): several log prefixes below look mis-encoded (mojibake) in
    this copy of the file; they are runtime text and are kept verbatim.
    """

    # Hub id of the checkpoint loaded at start-up.
    MODEL_ID = "sergeipetrov/swin2SR-classical-sr-x2-64"

    # Seconds to wait on a remote image; without a timeout a stalled server
    # would block the worker indefinitely.
    URL_TIMEOUT = 30

    # Single message used for every "nothing to process" situation.
    _MISSING_INPUT_ERROR = (
        "Missing or empty 'inputs': expected an image URL, base64 string, "
        "raw bytes, or PIL image"
    )

    def __init__(self, model_dir: str = "", **kwargs):
        """Load the image-to-image pipeline.

        Args:
            model_dir: Accepted for interface compatibility; not used, the
                checkpoint is always ``MODEL_ID``.
            **kwargs: Ignored.

        Raises:
            Exception: Whatever the pipeline constructor raised, re-raised
                unchanged after being logged.
        """
        print("πŸ”Ή [INIT] Loading Swin2SR model ...")
        try:
            self.model = pipeline(task="image-to-image", model=self.MODEL_ID)
            print("βœ… [INIT] Model loaded successfully")
        except Exception as e:
            print("❌ [INIT] Model load failed:", str(e))
            traceback.print_exc(file=sys.stdout)
            # Bare raise keeps the original traceback intact.
            raise

    def __call__(self, data):
        """Process one request.

        Args:
            data: Normally a dict with an ``"inputs"`` entry (which may itself
                be a dict wrapping another ``"inputs"``). A non-dict payload
                is treated as the input value itself.

        Returns:
            ``{"image_base64": <str>}`` on success, ``{"error": <str>}``
            otherwise.
        """
        print("\n🟒 [CALL] Received request in handler")
        print(f"πŸ”Ή [DEBUG] Raw data type: {type(data)}")
        print(f"πŸ”Ή [DEBUG] Raw data keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}")

        try:
            # Tolerate a bare payload instead of failing on ``.get``.
            image_input = data.get("inputs") if isinstance(data, dict) else data
            print(f"πŸ”Ή [DEBUG] image_input type: {type(image_input)}")

            # Handle nested dicts (double-wrapped inputs)
            if isinstance(image_input, dict) and "inputs" in image_input:
                print("⚠️ [DEBUG] Nested 'inputs' dict detected β€” unwrapping...")
                image_input = image_input["inputs"]

            image, error = self._load_image(image_input)
            if error is not None:
                return {"error": error}

            print("βœ… [INFO] Image successfully loaded and converted to RGB")

            # Run inference
            print("πŸš€ [INFER] Running Swin2SR model inference...")
            output = self.model(image)
            print("βœ… [INFER] Inference complete")

            # Normalize output format. The two steps are sequential so that a
            # list whose first element is a dict is unwrapped as well.
            if isinstance(output, (list, tuple)):
                if not output:
                    msg = "Model returned an empty result"
                    print("❌ [ERROR]", msg)
                    return {"error": msg}
                print("πŸ”„ [DEBUG] Output is list/tuple β€” taking first element")
                output = output[0]
            if isinstance(output, dict) and "image" in output:
                print("πŸ”„ [DEBUG] Output is dict with 'image' key")
                output = output["image"]

            # Validate output type
            if not isinstance(output, Image.Image):
                msg = f"Unexpected model output type: {type(output)}"
                print("❌ [ERROR]", msg)
                return {"error": msg}

            # Encode to base64 for API response
            print("πŸ’Ύ [ENCODE] Encoding result image to base64...")
            buffer = io.BytesIO()
            output.save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

            print("βœ… [RETURN] Returning base64-encoded image")
            return {"image_base64": encoded}

        except Exception as e:
            # Top-level boundary: log everything and answer with an error dict.
            print("❌ [FATAL] Inference failed with exception:")
            traceback.print_exc(file=sys.stdout)
            return {"error": f"Inference failed: {str(e)}"}

    def _load_image(self, image_input):
        """Turn a request input into an RGB image.

        Returns:
            ``(image, None)`` on success or ``(None, error_message)`` when the
            input is missing, malformed or of an unsupported type. Network and
            decoding errors outside the base64 path propagate to the caller.
        """
        if isinstance(image_input, str):
            image_input = image_input.strip()

        # Reject "nothing to do" up front with one clear message.
        if image_input is None or (
            isinstance(image_input, (str, bytes, bytearray)) and len(image_input) == 0
        ):
            print("⚠️ [WARN] Missing or empty 'inputs'")
            return None, self._MISSING_INPUT_ERROR

        if isinstance(image_input, str):
            # Case 1: URL input. Require a full scheme so that a base64
            # payload that merely begins with "http" is not fetched.
            if image_input.lower().startswith(("http://", "https://")):
                print(f"🌐 [INFO] Fetching image from URL: {image_input[:60]}...")
                # NOTE(review): the URL comes straight from the request; if
                # callers are untrusted, restrict reachable hosts (SSRF).
                response = requests.get(image_input, timeout=self.URL_TIMEOUT)
                # Fail on 4xx/5xx rather than trying to parse an error page.
                response.raise_for_status()
                # ``content`` is fully read and content-decoded, unlike ``raw``.
                return Image.open(io.BytesIO(response.content)).convert("RGB"), None

            # Case 2: Base64-encoded string
            print(f"🧬 [INFO] Detected base64 string (len={len(image_input)})")
            payload = image_input
            # Accept "data:image/png;base64,...." by dropping the header.
            if payload[:5].lower() == "data:" and "," in payload:
                payload = payload.split(",", 1)[1]
            try:
                image_bytes = base64.b64decode(payload)
                return Image.open(io.BytesIO(image_bytes)).convert("RGB"), None
            except Exception as e:
                print("❌ [ERROR] Base64 decode failed:", str(e))
                traceback.print_exc(file=sys.stdout)
                return None, f"Failed to decode base64 image: {str(e)}"

        # Case 3: Raw bytes
        if isinstance(image_input, (bytes, bytearray)):
            print(f"πŸ“¦ [INFO] Detected raw bytes input (len={len(image_input)})")
            return Image.open(io.BytesIO(image_input)).convert("RGB"), None

        # Case 4: already-decoded image. Presumably what a serving layer hands
        # over for image/* request bodies - confirm for the deployment in use.
        if isinstance(image_input, Image.Image):
            print("[INFO] Detected PIL image input")
            return image_input.convert("RGB"), None

        print(f"⚠️ [WARN] Unsupported input type: {type(image_input)}")
        return None, f"Unsupported input type: {type(image_input)}"