| import os |
| import sys |
| import json |
| import torch |
| from pathlib import Path |
|
|
| |
def get_model_type(repo_root="/repository"):
    """Detect which model handler to use from the repository's git config.

    Reads ``<repo_root>/.git/config`` and searches its text,
    case-insensitively, for a known model name. ``svgdreamer`` is checked
    first, then ``diffsketcher_edit`` / ``diffsketcher-edit``.

    Args:
        repo_root: Directory expected to contain the ``.git`` folder.
            Defaults to ``/repository``.

    Returns:
        ``"svgdreamer"``, ``"diffsketcher_edit"``, or ``"diffsketcher"``.
        ``"diffsketcher"`` is the fallback when nothing matches or the
        config file is missing or unreadable.
    """
    model_type = "diffsketcher"
    config_path = Path(repo_root) / ".git" / "config"

    try:
        # errors="replace": a stray undecodable byte must not defeat what is
        # only a plain substring search.
        config = config_path.read_text(encoding="utf-8", errors="replace").lower()
    except OSError:
        # Detection is best-effort: a missing or unreadable config simply
        # leaves the default in place. Only I/O failures are swallowed, so
        # KeyboardInterrupt and genuine bugs still propagate.
        config = ""

    if "svgdreamer" in config:
        model_type = "svgdreamer"
    elif "diffsketcher_edit" in config or "diffsketcher-edit" in config:
        model_type = "diffsketcher_edit"

    print(f"Detected model type: {model_type}")
    return model_type
|
|
| |
def import_handler():
    """Instantiate the handler class matching the detected model type.

    The handler module is imported lazily, inside the selected branch, so
    only the module for the detected model is loaded.

    Returns:
        A new ``SVGDreamerHandler``, ``DiffSketcherEditHandler``, or
        ``DiffSketcherHandler`` instance; the last is the fallback for any
        other model type.
    """
    detected = get_model_type()

    if detected == "svgdreamer":
        from svgdreamer_handler import SVGDreamerHandler as handler_cls
    elif detected == "diffsketcher_edit":
        from diffsketcher_edit_handler import DiffSketcherEditHandler as handler_cls
    else:
        from diffsketcher_handler import DiffSketcherHandler as handler_cls

    return handler_cls()
|
|
| |
# Module-level handler: created once when this module is imported, then
# initialized with None as its only argument before any request is served.
handler = import_handler()
handler.initialize(None)
|
|
| |
def inference(model_inputs):
    """Run the module-level handler on ``model_inputs`` and return its result.

    Args:
        model_inputs: Request payload, passed to ``handler.handle``
            unchanged as the first argument (the second is ``None``).

    Returns:
        Whatever ``handler.handle`` returns.
    """
    # `handler` is only read here, so no `global` declaration is required.
    return handler.handle(model_inputs, None)
|
|
| |
if __name__ == "__main__":
    # Manual smoke test: run one sample prompt through the handler and save
    # the resulting SVG next to the script.
    sample_input = {
        "inputs": "a beautiful mountain landscape",
        "parameters": {},
    }

    result = inference(sample_input)
    svg = result["svg"]
    print(f"Generated SVG with {len(svg)} characters")

    # Write explicitly as UTF-8: with the locale's default encoding, non-ASCII
    # characters in the SVG could fail to encode or yield a file that XML
    # parsers (which assume UTF-8 absent a declaration) cannot read.
    with open("output.svg", "w", encoding="utf-8") as f:
        f.write(svg)

    print("SVG saved to output.svg")