File size: 5,572 Bytes
ffc93fc f8b22af 8d52022 f8b22af ffc93fc c198636 8d52022 55dc40f 8d52022 55dc40f 8d52022 55dc40f 8d52022 1d1055f 54f01ed 8d52022 1d1055f 8d52022 1d1055f 8d52022 1d1055f 8d52022 1d1055f 54f01ed 8d52022 54f01ed f8b22af d87b721 54f01ed 8d52022 54f01ed 7739491 8d52022 7739491 55dc40f 8d52022 7739491 8d52022 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import io
import sys
import base64
import json
import torch
import numpy as np
from PIL import Image
import requests
import tempfile
import shutil
import subprocess
from pathlib import Path
# Add debug logging
def debug_log(message):
    """Write ``DEBUG: <message>`` as one line to stdout and flush it right away.

    Flushing immediately keeps the line visible even when stdout is buffered
    (for example when it is piped to a log collector).
    """
    stream = sys.stdout
    stream.write(f"DEBUG: {message}\n")
    stream.flush()
debug_log("Starting handler initialization")

# Import cairosvg (used to rasterize SVG to PNG), installing it on first use
# when it is missing from the environment.
try:
    import cairosvg
    debug_log("Successfully imported cairosvg")
except ImportError:
    debug_log("cairosvg not found. Installing...")
    # Invoke pip through the running interpreter so the packages land in *this*
    # environment; a bare "pip" on PATH may belong to a different Python.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2",
        ]
    )
    import importlib

    # Packages installed after start-up may be hidden by the import system's
    # cached directory listings until the caches are invalidated.
    importlib.invalidate_caches()
    import cairosvg
    debug_log("Installed and imported cairosvg")
class EndpointHandler:
    """Inference handler that currently answers every request with a placeholder sketch.

    ``__init__`` probes ``model_dir`` for ``.pth`` weight files and logs what it
    finds, but no model is loaded (``use_model`` is always ``False``). Each call
    renders an SVG sketch containing the prompt and returns it as a PIL image.
    """

    def __init__(self, model_dir):
        """Initialize the handler and log which weight files are present.

        Args:
            model_dir: Directory checked for ``checkpoint.pth``; when that file
                is absent, the directory tree is walked and every ``*.pth`` file
                found is logged. Weights are only located, never loaded.
        """
        # Assigned before the try block so these attributes exist even if the
        # probing below raises.
        self.model_dir = model_dir
        self.use_model = False
        self.device = torch.device("cpu")
        try:
            debug_log(f"Initializing handler with model_dir: {model_dir}")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            debug_log(f"Using device: {self.device}")

            # Check if model weights exist
            weights_path = os.path.join(model_dir, "checkpoint.pth")
            if os.path.exists(weights_path):
                debug_log(f"Found model weights at {weights_path}")
                debug_log(f"Weights file size: {os.path.getsize(weights_path)} bytes")
            else:
                debug_log(f"Model weights not found at {weights_path}")
                # Try to find weights in other locations
                for root, _dirs, files in os.walk(model_dir):
                    for file in files:
                        if file.endswith(".pth"):
                            debug_log(f"Found weights file: {os.path.join(root, file)}")

            # For now, we'll just use a placeholder implementation
            debug_log("Using placeholder implementation")
        except Exception as e:
            # Best-effort probing: log and keep the placeholder implementation.
            debug_log(f"Error in handler initialization: {e}")
            import traceback
            debug_log(traceback.format_exc())

    def generate_svg(self, prompt, width=512, height=512):
        """Build a placeholder SVG sketch with *prompt* written in the centre.

        Args:
            prompt: Text to display. It is converted with ``str()`` and
                XML-escaped before being embedded, because it comes from the
                request and must not be able to alter the document structure.
            width: Canvas width used for the ``<svg>`` element and the shapes.
            height: Canvas height used for the ``<svg>`` element and the shapes.

        Returns:
            The SVG document as a ``str``.
        """
        from xml.sax.saxutils import escape

        debug_log(f"Generating SVG for prompt: {prompt}")
        # Escape &, < and > so prompts such as "a < b" or "</text><image .../>"
        # neither break XML parsing nor inject extra SVG elements.
        safe_prompt = escape(str(prompt))
        # Create a more interesting placeholder that looks like a sketch
        svg_content = f"""<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">
    <rect width="100%" height="100%" fill="#ffffff"/>
    <g stroke="#000000" fill="none">
        <!-- Draw a simple sketch based on the prompt -->
        <circle cx="{width/2}" cy="{height/2}" r="{min(width, height)/4}" stroke-width="2"/>
        <ellipse cx="{width/2}" cy="{height/2}" rx="{width/3}" ry="{height/4}" stroke-width="1.5"/>
        <path d="M {width/4} {height/4} Q {width/2} {height/8} {3*width/4} {height/4}" stroke-width="2"/>
        <path d="M {width/4} {3*height/4} Q {width/2} {7*height/8} {3*width/4} {3*height/4}" stroke-width="2"/>
    </g>
    <text x="50%" y="50%" font-family="Arial" font-size="20" text-anchor="middle" fill="#333333">{safe_prompt}</text>
</svg>"""
        debug_log("Generated SVG content")
        return svg_content

    def __call__(self, data):
        """Handle one request and return a PIL image.

        Args:
            data: Request payload. The prompt is ``data["inputs"]`` when that is
                a string, or ``data["inputs"]["text"]`` when inputs is a dict
                with a ``"text"`` key; otherwise ``"No prompt provided"`` is used.

        Returns:
            A ``PIL.Image.Image``. This is the rasterized SVG on success, a grey
            512x512 placeholder with the prompt if rasterization fails, or a red
            512x512 error image if anything else fails.
        """
        try:
            debug_log(f"Handling request: {data}")
            prompt = self._extract_prompt(data)
            debug_log(f"Extracted prompt: {prompt}")

            svg_content = self.generate_svg(prompt)

            # Convert SVG to PNG
            try:
                png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
                image = Image.open(io.BytesIO(png_data))
                # Image.open is lazy; decode now so a corrupt PNG is handled
                # here instead of surfacing later when the caller serializes it.
                image.load()
                debug_log("Generated image from SVG")
            except Exception as e:
                debug_log(f"Error converting SVG to PNG: {e}")
                import traceback
                debug_log(traceback.format_exc())
                image = self._text_image(prompt, background="#f0f0f0", fill="black")
                debug_log("Created placeholder image")

            # Return the PIL Image directly
            debug_log("Returning image")
            return image
        except Exception as e:
            debug_log(f"Error in handler: {e}")
            import traceback
            debug_log(traceback.format_exc())
            # Last-resort path: it must not raise, so if even drawing the error
            # text fails, return the plain red background.
            try:
                image = self._text_image(f"Error: {str(e)}", background="#ff0000", fill="white")
            except Exception:
                image = Image.new("RGB", (512, 512), color="#ff0000")
            debug_log("Returning error image")
            return image

    @staticmethod
    def _extract_prompt(data):
        """Return the prompt carried by *data*, or ``"No prompt provided"``."""
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict) and "text" in inputs:
            # Coerce so non-string values render the same in every code path.
            return str(inputs["text"])
        return "No prompt provided"

    @staticmethod
    def _text_image(text, background, fill, size=(512, 512)):
        """Return an RGB image of *size* filled with *background*, with *text* centred in *fill*."""
        from PIL import ImageDraw

        image = Image.new("RGB", size, color=background)
        draw = ImageDraw.Draw(image)
        draw.text((size[0] // 2, size[1] // 2), text, fill=fill, anchor="mm")
        return image
|