jree423 commited on
Commit
2f521ae
·
verified ·
1 Parent(s): dd065be

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +80 -10
handler.py CHANGED
@@ -1,16 +1,86 @@
1
 
2
- from typing import Dict, Any
3
- import torch
4
  import base64
5
  import io
6
- import os
 
 
7
  import json
8
- from PIL import Image
9
- from pipeline import Pipeline
10
 
11
- class EndpointHandler:
12
- def __init__(self, path=""):
13
- self.pipeline = Pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
16
- return self.pipeline(data)
 
 
 
 
 
1
 
2
+ from typing import Dict, List, Any
 
3
  import base64
4
  import io
5
+ from PIL import Image, ImageDraw
6
+ import numpy as np
7
+ import torch
8
  import json
 
 
9
 
10
+ class VectorGraphicsHandler:
11
+ def __init__(self):
12
+ self.initialized = False
13
+
14
+ def initialize(self, context):
15
+ """Initialize the handler."""
16
+ self.initialized = True
17
+
18
+ def preprocess(self, request):
19
+ """Process the input request."""
20
+ inputs = request.pop("inputs", {})
21
+ if isinstance(inputs, str):
22
+ # Single prompt
23
+ prompt = inputs
24
+ payload = {"prompt": prompt}
25
+ else:
26
+ # Full payload
27
+ payload = inputs
28
+
29
+ return payload
30
+
31
+ def inference(self, inputs):
32
+ """Generate vector graphics from the inputs."""
33
+ # This is a placeholder implementation
34
+ # In a real scenario, this would call the actual model
35
+
36
+ # Create a simple SVG based on the prompt
37
+ prompt = inputs.get("prompt", "")
38
+ if not prompt:
39
+ prompts = inputs.get("prompts", [""])
40
+ prompt = prompts[0] if prompts else ""
41
+
42
+ # Generate a simple SVG
43
+ svg = f"""
44
+ <svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512">
45
+ <rect width="512" height="512" fill="#f0f0f0"/>
46
+ <text x="256" y="50" font-family="Arial" font-size="20" text-anchor="middle" fill="#333">Generated from: "{prompt}"</text>
47
+ <g transform="translate(256, 256)">
48
+ <circle cx="0" cy="0" r="100" fill="#3498db" opacity="0.7"/>
49
+ <rect x="-50" y="-50" width="100" height="100" fill="#e74c3c" opacity="0.7"/>
50
+ <path d="M-100,-100 L100,100 M-100,100 L100,-100" stroke="#2c3e50" stroke-width="5"/>
51
+ </g>
52
+ </svg>
53
+ """
54
+
55
+ # Create a simple PNG image
56
+ img = Image.new("RGB", (512, 512), color="#f0f0f0")
57
+ draw = ImageDraw.Draw(img)
58
+ draw.ellipse((156, 156, 356, 356), fill="#3498db", outline="#3498db")
59
+ draw.rectangle((206, 206, 306, 306), fill="#e74c3c", outline="#e74c3c")
60
+ draw.line((156, 156, 356, 356), fill="#2c3e50", width=5)
61
+ draw.line((156, 356, 356, 156), fill="#2c3e50", width=5)
62
+
63
+ # Convert image to base64
64
+ buffered = io.BytesIO()
65
+ img.save(buffered, format="PNG")
66
+ img_str = base64.b64encode(buffered.getvalue()).decode()
67
+
68
+ return {"svg": svg, "image": img_str}
69
+
70
+ def postprocess(self, inference_output):
71
+ """Return the output as JSON."""
72
+ return inference_output
73
+
74
+ _service = VectorGraphicsHandler()
75
+
76
+ def handle(data, context):
77
+ """Handle a request to the model."""
78
+ if not _service.initialized:
79
+ _service.initialize(context)
80
 
81
+ if data is None:
82
+ return None
83
+
84
+ inputs = _service.preprocess(data)
85
+ outputs = _service.inference(inputs)
86
+ return _service.postprocess(outputs)