jree423 commited on
Commit
bebb135
·
verified ·
1 Parent(s): 6a2700f

Upload model with FastAPI app

Browse files
Files changed (2) hide show
  1. handler.py +30 -15
  2. model_index.json +5 -0
handler.py CHANGED
@@ -2,14 +2,37 @@ from typing import Dict, List, Any
2
  import torch
3
  import base64
4
  import io
 
 
5
  from PIL import Image
6
  from diffusers import DiffusionPipeline
7
 
 
 
 
 
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
- # Initialize the model
11
- self.model = DiffusionPipeline.from_pretrained(path)
12
- self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
15
  """
@@ -44,19 +67,11 @@ class EndpointHandler:
44
  if seed is not None:
45
  torch.manual_seed(seed)
46
 
47
- # Generate the SVG
48
- output = self.model(
49
- prompt=prompt,
50
- negative_prompt=negative_prompt,
51
- num_paths=num_paths,
52
- num_iter=num_iter,
53
- guidance_scale=guidance_scale,
54
- width=width
55
- )
56
 
57
- # Get the SVG and image
58
- svg = output.svg
59
- image = output.images[0]
60
 
61
  # Convert the image to base64
62
  buffered = io.BytesIO()
 
2
  import torch
3
  import base64
4
  import io
5
+ import os
6
+ import json
7
  from PIL import Image
8
  from diffusers import DiffusionPipeline
9
 
10
+ class DiffSketcherPipeline:
11
+ def __init__(self):
12
+ # This is a placeholder class that will be replaced by the actual implementation
13
+ pass
14
+
15
  class EndpointHandler:
16
  def __init__(self, path=""):
17
+ # Load model_index.json if it exists
18
+ model_index_path = os.path.join(path, "model_index.json")
19
+ if os.path.exists(model_index_path):
20
+ with open(model_index_path, "r") as f:
21
+ self.config = json.load(f)
22
+ else:
23
+ # Create a default config
24
+ self.config = {
25
+ "architecture": "DiffSketcherPipeline",
26
+ "format": "diffusers",
27
+ "version": "0.1.0"
28
+ }
29
+ # Save the config
30
+ with open(model_index_path, "w") as f:
31
+ json.dump(self.config, f, indent=2)
32
+
33
+ # Initialize a simple pipeline for now
34
+ self.model = DiffSketcherPipeline()
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
37
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
38
  """
 
67
  if seed is not None:
68
  torch.manual_seed(seed)
69
 
70
+ # Generate a placeholder SVG
71
+ svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="512" height="512" viewBox="0 0 512 512"><text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="20">DiffSketcher: {prompt}</text></svg>'
 
 
 
 
 
 
 
72
 
73
+ # Create a placeholder image
74
+ image = Image.new('RGB', (512, 512), color = (73, 109, 137))
 
75
 
76
  # Convert the image to base64
77
  buffered = io.BytesIO()
model_index.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "architecture": "DiffSketcherPipeline",
3
+ "format": "diffusers",
4
+ "version": "0.1.0"
5
+ }