|
|
import os |
|
|
import io |
|
|
import sys |
|
|
import base64 |
|
|
import json |
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import requests |
|
|
import tempfile |
|
|
import shutil |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
def debug_log(message):
    """Write *message* to stdout with a ``DEBUG: `` prefix, then flush.

    The explicit flush pushes each line out as soon as it is logged rather
    than leaving it in the stream's buffer.
    """
    line = f"DEBUG: {message}"
    print(line)
    out = sys.stdout
    out.flush()
|
|
|
|
|
debug_log("Starting handler initialization")

# cairosvg is needed at module level by EndpointHandler.__call__. If it is
# missing, try to install it at import time.
try:
    import cairosvg

    debug_log("Successfully imported cairosvg")
except ImportError:
    import importlib

    debug_log("cairosvg not found. Installing...")
    # Invoke pip through the running interpreter so the packages land in
    # *this* environment, not whichever `pip` happens to be first on PATH.
    # NOTE: cairocffi also needs the native cairo shared library, which pip
    # cannot install; that must already be present on the host.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    # Modules installed after the interpreter started may be missed by the
    # import system's cached path listings until the caches are cleared.
    importlib.invalidate_caches()
    import cairosvg

    debug_log("Installed and imported cairosvg")
|
|
|
|
|
class EndpointHandler:
    """Endpoint handler that turns a text prompt into an image.

    No model is run: ``generate_svg`` draws a fixed placeholder sketch
    captioned with the prompt, which ``__call__`` rasterises with cairosvg.
    ``__init__`` only inspects ``model_dir`` for ``.pth`` files and logs them.
    """

    def __init__(self, model_dir):
        """Initialize the handler with model directory.

        Args:
            model_dir: Directory searched for ``checkpoint.pth`` and any other
                ``.pth`` files. The files are only logged, never loaded.

        Errors are logged instead of raised; ``self.use_model`` is always
        ``False`` afterwards.
        """
        # Set before anything that can fail so these attributes always exist.
        self.model_dir = model_dir
        self.use_model = False
        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")

            weights_path = os.path.join(model_dir, "checkpoint.pth")
            if os.path.exists(weights_path):
                debug_log(f"Found model weights at {weights_path}")
                debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
            else:
                debug_log(f"Model weights not found at {weights_path}")

            # Log every .pth file anywhere under model_dir (diagnostics only).
            for root, _dirs, files in os.walk(model_dir):
                for file in files:
                    if file.endswith(".pth"):
                        debug_log(f"Found weights file: {os.path.join(root, file)}")

            debug_log("Using placeholder implementation")
        except Exception as e:
            import traceback

            debug_log(f"Error in handler initialization: {e}")
            debug_log(traceback.format_exc())
            self.use_model = False

    def generate_svg(self, prompt, width=512, height=512):
        """Generate a placeholder SVG sketch captioned with *prompt*.

        Args:
            prompt: Caption text. It is converted with ``str()`` and
                XML-escaped, so characters such as ``&``, ``<`` and ``>`` can
                neither break the document nor inject extra SVG markup.
            width: Canvas width in SVG user units.
            height: Canvas height in SVG user units.

        Returns:
            The SVG document as a string.
        """
        from xml.sax.saxutils import escape

        debug_log(f"Generating SVG for prompt: {prompt}")

        # The prompt comes from the request body and is placed in XML text
        # content, so it must be escaped before interpolation.
        caption = escape(str(prompt))

        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
    <rect width="100%" height="100%" fill="#ffffff"/>
    <g stroke="#000000" fill="none">
        <!-- Draw a simple sketch based on the prompt -->
        <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
        <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
        <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
        <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
    </g>
    <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{caption}</text>
</svg>"""

        debug_log("Generated SVG content")
        return svg_content

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt from a request payload, or a fixed default.

        Accepts ``{"inputs": "<text>"}`` or ``{"inputs": {"text": ...}}``;
        anything else yields ``"No prompt provided"``.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            return inputs["text"]
        return "No prompt provided"

    @staticmethod
    def _text_image(text, background, fill, size=(512, 512)):
        """Return a solid *background* RGB image with *text* centred in *fill*."""
        from PIL import ImageDraw

        image = Image.new("RGB", size, color=background)
        draw = ImageDraw.Draw(image)
        draw.text((size[0] // 2, size[1] // 2), str(text), fill=fill, anchor="mm")
        return image

    def __call__(self, data):
        """Handle a request to the model.

        Args:
            data: Request payload; see ``_extract_prompt`` for accepted shapes.

        Returns:
            A PIL image. On success this is the rasterised SVG. If SVG-to-PNG
            conversion fails, it is a grey 512x512 image showing the prompt.
            If anything else fails, it is a red 512x512 image showing the error.
        """
        import traceback

        try:
            debug_log(f"Handling request: {data}")

            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            try:
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                image = Image.open(io.BytesIO(png_data))
                # Image.open is lazy: decode now so a bad PNG is caught by
                # this fallback instead of failing later in the caller.
                image.load()
                debug_log("Generated image from SVG")
            except Exception as e:
                debug_log(f"Error converting SVG to PNG: {e}")
                debug_log(traceback.format_exc())
                image = self._text_image(prompt, background="#f0f0f0", fill="black")
                debug_log("Created placeholder image")

            debug_log("Returning image")
            return image

        except Exception as e:
            debug_log(f"Error in handler: {e}")
            debug_log(traceback.format_exc())
            image = self._text_image(f"Error: {str(e)}", background="#ff0000", fill="white")
            debug_log("Returning error image")
            return image
|
|
|