jree423 commited on
Commit
5560bac
·
verified ·
1 Parent(s): a2fd1ce

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +109 -0
handler.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import sys
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ import traceback
8
+ import json
9
+ import logging
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO,
13
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Add the model directory to the path
17
+ sys.path.append('/code/diffsketcher_edit')
18
+
19
+ # Safely import cairosvg with fallback
20
+ try:
21
+ import cairosvg
22
+ logger.info("Successfully imported cairosvg")
23
+ except ImportError:
24
+ logger.warning("cairosvg not found. Installing...")
25
+ import subprocess
26
+ subprocess.check_call(["pip", "install", "cairosvg"])
27
+ import cairosvg
28
+ logger.info("Successfully installed and imported cairosvg")
29
+
30
+ class EndpointHandler:
31
+ def __init__(self, model_dir):
32
+ # Initialize the handler with model directory
33
+ logger.info(f"Initializing handler with model_dir: {model_dir}")
34
+ self.model_dir = model_dir
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ logger.info(f"Using device: {self.device}")
37
+
38
+ # Initialize the model
39
+ logger.info("Initializing diffsketcher_edit model...")
40
+ self.model = self._initialize_model()
41
+ logger.info("diffsketcher_edit model initialized")
42
+
43
+ def _initialize_model(self):
44
+ # Initialize the diffsketcher_edit model
45
+ # This is a placeholder for the actual model initialization
46
+ # In a real implementation, you would load the model weights and initialize the model
47
+ return None
48
+
49
+ def generate_svg(self, prompt, width=512, height=512, num_paths=512, seed=None):
50
+ # Generate an SVG from a text prompt
51
+ logger.info(f"Generating SVG for prompt: {prompt}")
52
+
53
+ # This is a placeholder for the actual SVG generation
54
+ # In a real implementation, you would use the model to generate an SVG
55
+ svg_content = f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">{prompt}</text></svg>'
56
+
57
+ return svg_content
58
+
59
+ def __call__(self, data):
60
+ # Handle a request to the model
61
+ try:
62
+ logger.info(f"Handling request with data: {data}")
63
+
64
+ # Extract the prompt and parameters
65
+ if isinstance(data, dict):
66
+ if "inputs" in data:
67
+ if isinstance(data["inputs"], str):
68
+ prompt = data["inputs"]
69
+ params = {}
70
+ elif isinstance(data["inputs"], dict):
71
+ prompt = data["inputs"].get("text", "No prompt provided")
72
+ params = {k: v for k, v in data["inputs"].items() if k != "text"}
73
+ else:
74
+ prompt = "No prompt provided"
75
+ params = {}
76
+ else:
77
+ prompt = "No prompt provided"
78
+ params = {}
79
+ else:
80
+ prompt = "No prompt provided"
81
+ params = {}
82
+
83
+ logger.info(f"Extracted prompt: {prompt}")
84
+ logger.info(f"Extracted parameters: {params}")
85
+
86
+ # Extract parameters
87
+ width = params.get("width", 512)
88
+ height = params.get("height", 512)
89
+ num_paths = params.get("num_paths", 512)
90
+ seed = params.get("seed", None)
91
+
92
+ # Generate SVG
93
+ svg_content = self.generate_svg(prompt, width, height, num_paths, seed)
94
+ logger.info("SVG content generated")
95
+
96
+ # Convert SVG to PNG
97
+ logger.info("Converting SVG to PNG")
98
+ png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
99
+ image = Image.open(io.BytesIO(png_data))
100
+ logger.info(f"Converted to PNG with size: {image.size}")
101
+
102
+ # Return the PIL Image directly
103
+ return image
104
+ except Exception as e:
105
+ logger.error(f"Error in handler: {e}")
106
+ logger.error(traceback.format_exc())
107
+ # Return an error image
108
+ error_image = Image.new('RGB', (512, 512), color='red')
109
+ return error_image