jree423 commited on
Commit
ea933c9
·
verified ·
1 Parent(s): 6d2a86e

Update: Add original model implementation

Browse files
Files changed (2) hide show
  1. Dockerfile +15 -3
  2. handler.py +11 -20
Dockerfile CHANGED
@@ -17,20 +17,32 @@ RUN pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://downl
17
  # Install CLIP
18
  RUN pip install git+https://github.com/openai/CLIP.git
19
 
 
 
 
20
  # Install cairosvg and other dependencies
21
  RUN pip install cairosvg cairocffi cssselect2 defusedxml tinycss2
22
 
23
  # Install FastAPI and other dependencies
24
  RUN pip install fastapi uvicorn pydantic pillow numpy requests
25
 
 
 
 
26
  # Copy the model files
27
  COPY . /code/
28
 
 
 
 
 
 
 
29
  # Make sure the handler and model are available
30
- RUN if [ -f /code/simplified_diffsketcher.py ]; then \
31
- echo "Simplified DiffSketcher found"; \
32
  else \
33
- echo "Simplified DiffSketcher not found, using placeholder"; \
34
  fi
35
 
36
  # Set environment variables
 
17
  # Install CLIP
18
  RUN pip install git+https://github.com/openai/CLIP.git
19
 
20
+ # Install diffusers and other dependencies
21
+ RUN pip install diffusers transformers accelerate xformers omegaconf einops kornia
22
+
23
  # Install cairosvg and other dependencies
24
  RUN pip install cairosvg cairocffi cssselect2 defusedxml tinycss2
25
 
26
  # Install FastAPI and other dependencies
27
  RUN pip install fastapi uvicorn pydantic pillow numpy requests
28
 
29
+ # Install SVG dependencies
30
+ RUN pip install svgwrite svgpathtools cssutils numba
31
+
32
  # Copy the model files
33
  COPY . /code/
34
 
35
+ # Download model weights if they don't exist
36
+ RUN if [ ! -f /code/ViT-B-32.pt ]; then \
37
+ pip install gdown && \
38
+ python -c "import clip; clip.load('ViT-B-32')" ; \
39
+ fi
40
+
41
  # Make sure the handler and model are available
42
+ RUN if [ -f /code/diffsketcher_endpoint.py ]; then \
43
+ echo "DiffSketcher endpoint found"; \
44
  else \
45
+ echo "DiffSketcher endpoint not found, using placeholder"; \
46
  fi
47
 
48
  # Set environment variables
handler.py CHANGED
@@ -15,21 +15,12 @@ except ImportError:
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
- # Safely import clip with fallback
19
  try:
20
- import clip
21
  except ImportError:
22
- print("Warning: clip not found. Installing...")
23
- import subprocess
24
- subprocess.check_call(["pip", "install", "git+https://github.com/openai/CLIP.git"])
25
- import clip
26
-
27
- # Import the versatile SVG generator
28
- try:
29
- from versatile_svg_generator import VersatileSVGGenerator
30
- except ImportError:
31
- print("Warning: versatile_svg_generator not found. Using placeholder.")
32
- VersatileSVGGenerator = None
33
 
34
  class EndpointHandler:
35
  def __init__(self, model_dir):
@@ -38,14 +29,14 @@ class EndpointHandler:
38
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
  print(f"Initializing model on device: {self.device}")
40
 
41
- # Initialize the versatile SVG generator if available
42
- if VersatileSVGGenerator is not None:
43
  try:
44
- self.model = VersatileSVGGenerator(model_dir)
45
  self.use_model = True
46
- print("Versatile SVG generator initialized successfully")
47
  except Exception as e:
48
- print(f"Error initializing versatile SVG generator: {e}")
49
  self.use_model = False
50
  else:
51
  self.use_model = False
@@ -76,11 +67,11 @@ class EndpointHandler:
76
  # Generate SVG using the model or placeholder
77
  if self.use_model:
78
  try:
79
- # Use the versatile SVG generator
80
  result = self.model(prompt)
81
  image = result["image"]
82
  except Exception as e:
83
- print(f"Error using versatile SVG generator: {e}")
84
  # Fall back to placeholder
85
  svg_content = self.generate_placeholder_svg(prompt)
86
  png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))
 
15
  subprocess.check_call(["pip", "install", "cairosvg", "cairocffi", "cssselect2", "defusedxml", "tinycss2"])
16
  import cairosvg
17
 
18
+ # Import the DiffSketcher endpoint
19
  try:
20
+ from diffsketcher_endpoint import DiffSketcherEndpoint
21
  except ImportError:
22
+ print("Warning: diffsketcher_endpoint not found. Using placeholder.")
23
+ DiffSketcherEndpoint = None
 
 
 
 
 
 
 
 
 
24
 
25
  class EndpointHandler:
26
  def __init__(self, model_dir):
 
29
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  print(f"Initializing model on device: {self.device}")
31
 
32
+ # Initialize the DiffSketcher endpoint if available
33
+ if DiffSketcherEndpoint is not None:
34
  try:
35
+ self.model = DiffSketcherEndpoint(model_dir)
36
  self.use_model = True
37
+ print("DiffSketcher endpoint initialized successfully")
38
  except Exception as e:
39
+ print(f"Error initializing DiffSketcher endpoint: {e}")
40
  self.use_model = False
41
  else:
42
  self.use_model = False
 
67
  # Generate SVG using the model or placeholder
68
  if self.use_model:
69
  try:
70
+ # Use the DiffSketcher endpoint
71
  result = self.model(prompt)
72
  image = result["image"]
73
  except Exception as e:
74
+ print(f"Error using DiffSketcher endpoint: {e}")
75
  # Fall back to placeholder
76
  svg_content = self.generate_placeholder_svg(prompt)
77
  png_data = cairosvg.svg2png(bytestring=svg_content.encode("utf-8"))