File size: 5,572 Bytes
ffc93fc
f8b22af
8d52022
f8b22af
 
ffc93fc
 
c198636
8d52022
 
 
 
 
 
 
 
 
 
 
 
55dc40f
 
 
 
8d52022
55dc40f
8d52022
55dc40f
 
8d52022
1d1055f
54f01ed
 
 
8d52022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d1055f
 
8d52022
 
 
 
 
1d1055f
8d52022
 
 
 
 
 
 
 
 
1d1055f
8d52022
 
1d1055f
54f01ed
 
 
 
8d52022
 
54f01ed
 
 
 
 
 
 
 
f8b22af
d87b721
54f01ed
8d52022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54f01ed
7739491
8d52022
7739491
55dc40f
8d52022
 
 
7739491
 
 
 
 
8d52022
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import base64
import io
import json
import os
import shutil
import subprocess
import sys
import tempfile
import traceback
from pathlib import Path
from xml.sax.saxutils import escape

import numpy as np
import requests
import torch
from PIL import Image

# Add debug logging
def debug_log(message):
    """Write ``message`` to standard output with a ``DEBUG: `` prefix.

    The stream is flushed after every call so the line is emitted
    immediately instead of waiting in the output buffer.
    """
    stream = sys.stdout
    text = f"DEBUG: {message}"
    print(text, file=stream)
    stream.flush()

debug_log("Starting handler initialization")

# Import cairosvg, installing it (together with the listed companion
# packages) on the fly if the import fails.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    # Run pip through the interpreter executing this module so the packages
    # are installed into the environment this process imports from; a bare
    # "pip" found on PATH may belong to a different Python installation.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # The import system caches directory listings; refresh them so a package
    # installed while the process is running can be found.
    import importlib
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")

class EndpointHandler:
    """Request handler that answers a text prompt with a placeholder sketch.

    No model is loaded. The handler only logs which weight files are present
    in the model directory, then answers each request with a fixed SVG
    drawing labelled with the prompt and rasterised to a PIL image.
    """

    # Prompt used when the request carries no usable text.
    _DEFAULT_PROMPT = "No prompt provided"
    # Edge length, in pixels, of the fallback and error images.
    _FALLBACK_SIZE = 512

    def __init__(self, model_dir):
        """Record the model directory and device, and log which weights exist.

        Args:
            model_dir: Directory that is checked for ``checkpoint.pth``.

        Exceptions raised during set-up are logged, not propagated.
        ``use_model`` is always ``False`` because only the placeholder
        implementation exists.
        """
        # Assign defaults first so every attribute exists even if set-up
        # fails part-way through.
        self.model_dir = model_dir
        self.device = None
        self.use_model = False
        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")

            weights_path = os.path.join(model_dir, "checkpoint.pth")
            if os.path.exists(weights_path):
                debug_log(f"Found model weights at {weights_path}")
                debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
            else:
                debug_log(f"Model weights not found at {weights_path}")
                # Report any other .pth file under model_dir so a checkpoint
                # stored under a different name or folder shows up in the log.
                for root, _dirs, files in os.walk(model_dir):
                    for name in files:
                        if name.endswith(".pth"):
                            debug_log(f"Found weights file: {os.path.join(root, name)}")

            debug_log("Using placeholder implementation")
        except Exception as e:
            self._log_failure("Error in handler initialization", e)

    def generate_svg(self, prompt, width=512, height=512):
        """Return SVG markup for a placeholder sketch labelled with ``prompt``.

        Args:
            prompt: Text shown in the centre of the drawing. Non-string
                values are converted with ``str``.
            width: Width of the SVG canvas.
            height: Height of the SVG canvas.

        Returns:
            The SVG document as a string.
        """
        debug_log(f"Generating SVG for prompt: {prompt}")

        # The prompt is request-supplied text placed inside an XML element:
        # escape &, < and > so it can neither break the document nor inject
        # markup of its own.
        label = escape(str(prompt))

        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
            <rect width="100%" height="100%" fill="#ffffff"/>
            <g stroke="#000000" fill="none">
                <!-- Draw a simple sketch based on the prompt -->
                <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
                <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
                <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
                <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
            </g>
            <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{label}</text>
        </svg>"""

        debug_log("Generated SVG content")
        return svg_content

    def __call__(self, data):
        """Handle one request and return a PIL image.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``
                when that is a string, or from ``data["inputs"]["text"]``
                when it is a dict; otherwise a default prompt is used.

        Returns:
            The rasterised SVG sketch. If SVG conversion fails, a grey image
            with the prompt drawn on it is returned instead; if anything else
            in the handler fails, a red image with the error message.
        """
        try:
            debug_log(f"Handling request: {data}")

            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            try:
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                image = Image.open(io.BytesIO(png_data))
                debug_log("Generated image from SVG")
            except Exception as e:
                self._log_failure("Error converting SVG to PNG", e)
                # str() because the prompt taken from the payload is not
                # guaranteed to be a string.
                image = self._text_image(str(prompt), background="#f0f0f0", fill="black")
                debug_log("Created placeholder image")

            debug_log("Returning image")
            return image
        except Exception as e:
            self._log_failure("Error in handler", e)
            image = self._text_image(f"Error: {str(e)}", background="#ff0000", fill="white")
            debug_log("Returning error image")
            return image

    def _extract_prompt(self, data):
        """Return the prompt carried by ``data``, or the default prompt."""
        if isinstance(data, dict) and "inputs" in data:
            inputs = data["inputs"]
            if isinstance(inputs, str):
                return inputs
            if isinstance(inputs, dict) and "text" in inputs:
                return inputs["text"]
        return self._DEFAULT_PROMPT

    def _text_image(self, text, background, fill):
        """Return a square RGB image with ``text`` drawn around its centre."""
        from PIL import ImageDraw

        size = self._FALLBACK_SIZE
        image = Image.new("RGB", (size, size), color=background)
        centre = (size // 2, size // 2)
        ImageDraw.Draw(image).text(centre, text, fill=fill, anchor="mm")
        return image

    @staticmethod
    def _log_failure(context, error):
        """Log ``context`` with the error, followed by the active traceback."""
        debug_log(f"{context}: {error}")
        debug_log(traceback.format_exc())