| |
| |
|
|
import base64
import io
import logging
import os
import sys

import numpy as np
import torch
from PIL import Image

from handler_template import BaseHandler
|
|
| |
| sys.path.append("/app/model") |
|
|
class Handler(BaseHandler):
    """Serving handler for the DiffSketcher text-to-SVG model.

    Pipeline: ``initialize()`` loads the CLIP text encoder and the sketch
    generator, ``preprocess()`` encodes the request prompt, ``inference()``
    generates SVG data from the embedding, and ``postprocess()`` packages
    the SVG content for the response.
    """

    # Module-scoped logger bound once; replaces ad-hoc print() error reporting.
    _logger = logging.getLogger(__name__)

    def initialize(self):
        """Load the DiffSketcher model and its CLIP text encoder.

        Loads weights from ``/app/model/weights/diffsketcher_model.pth`` and
        moves both modules to ``self.device`` in eval mode, then sets
        ``self.initialized = True``.

        Raises:
            FileNotFoundError: if the weights file is missing.
            Exception: any other failure is logged and re-raised.
        """
        try:
            # Imported lazily: these modules live under /app/model, which is
            # appended to sys.path at module import time.
            from models.clip_text_encoder import CLIPTextEncoder
            from models.sketch_generator import SketchGenerator

            self.text_encoder = CLIPTextEncoder()
            self.text_encoder.to(self.device)
            self.text_encoder.eval()

            self.model = SketchGenerator()
            weights_path = os.path.join("/app/model/weights", "diffsketcher_model.pth")
            # Guard clause: fail fast with a clear message if weights are absent.
            if not os.path.exists(weights_path):
                raise FileNotFoundError(f"Model weights not found at {weights_path}")
            # NOTE(review): torch.load unpickles arbitrary objects. The weights
            # file appears baked into the image (trusted), but consider
            # weights_only=True if the installed torch version supports it.
            state_dict = torch.load(weights_path, map_location=self.device)
            self.model.load_state_dict(state_dict)

            self.model.to(self.device)
            self.model.eval()

            self.initialized = True
            self._logger.info("DiffSketcher model initialized successfully")
        except Exception:
            # Log with traceback, then propagate so the server fails startup.
            self._logger.exception("Error initializing DiffSketcher model")
            raise

    def preprocess(self, data):
        """Encode the request prompt into a text embedding.

        Args:
            data: request payload mapping; must contain a non-empty
                ``"prompt"`` key.

        Returns:
            dict with ``"text_embedding"`` (encoder output) and ``"prompt"``.

        Raises:
            ValueError: if no prompt is provided.
        """
        try:
            prompt = data.get("prompt", "")
            if not prompt:
                raise ValueError("No prompt provided in the request")

            # Inference-only encoding: no gradients needed.
            with torch.no_grad():
                text_embedding = self.text_encoder.encode_text(prompt)

            return {
                "text_embedding": text_embedding,
                "prompt": prompt,
            }
        except Exception:
            self._logger.exception("Error in preprocessing")
            raise

    def inference(self, inputs):
        """Generate SVG data from the preprocessed text embedding.

        Args:
            inputs: dict produced by ``preprocess()``; reads
                ``"text_embedding"``.

        Returns:
            whatever ``self.model.generate`` produces — presumably a mapping
            containing ``"svg_content"``, as consumed by ``postprocess()``.
        """
        try:
            text_embedding = inputs["text_embedding"]

            with torch.no_grad():
                svg_data = self.model.generate(text_embedding)

            return svg_data
        except Exception:
            self._logger.exception("Error during inference")
            raise

    def postprocess(self, inference_output):
        """Format the model output for the response.

        Unlike the other stages, failures here are NOT re-raised: an
        ``{"error": ...}`` dict is returned instead. Kept as-is for
        backward compatibility with existing callers of this handler.

        Args:
            inference_output: mapping with an ``"svg_content"`` entry.

        Returns:
            dict with ``"svg_content"`` and its base64 encoding under
            ``"svg_base64"``, or ``{"error": <message>}`` on failure.
        """
        try:
            svg_content = inference_output["svg_content"]

            return {
                "svg_content": svg_content,
                # svg_to_base64 is provided by BaseHandler (not visible here).
                "svg_base64": self.svg_to_base64(svg_content),
            }
        except Exception as e:
            self._logger.exception("Error in postprocessing")
            return {"error": str(e)}