Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| from starvector.data.util import process_and_rasterize_svg | |
# Hub repository that supplies both the model weights and the processor.
MODEL_ID = "starvector/starvector-8b-im2svg"

# Load the weights in half precision. trust_remote_code allows transformers
# to execute the modelling code shipped inside the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Move the model to the GPU (requires a CUDA device at import time).
model = model.cuda()

# Processor used to turn an input image into `pixel_values` for the model.
processor = AutoProcessor.from_pretrained(MODEL_ID)
def generate_svg(input_data, input_type, max_length=4000):
    """Generate SVG code and a rasterized preview from an image or a text prompt.

    Args:
        input_data: The user input. On the image path it is handed to the
            module-level ``processor``; on the text path it is passed
            unchanged to ``model.generate_text2svg``.
        input_type: ``"image"`` selects the image-to-SVG path; any other
            value selects the text-to-SVG path.
        max_length: Generation length limit forwarded to the model's
            generate call. Defaults to 4000.

    Returns:
        A ``(svg_code, raster_image)`` tuple: the two values produced by
        ``process_and_rasterize_svg`` for the first generated SVG.

    Raises:
        ValueError: If ``input_data`` is ``None`` or a blank string.
    """
    # Gradio hands over None for an image component with nothing uploaded and
    # "" for an empty textbox. Fail early with a readable message instead of
    # an opaque error from inside the processor or the model.
    is_blank_text = isinstance(input_data, str) and not input_data.strip()
    if input_data is None or is_blank_text:
        if input_type == "image":
            raise ValueError("Please upload an image before converting to SVG.")
        raise ValueError("Please enter a text prompt before generating an SVG.")

    # Pure inference: disable autograd so no graph is recorded on the GPU.
    with torch.no_grad():
        if input_type == "image":
            pixel_values = processor(input_data, return_tensors="pt")["pixel_values"].cuda()
            generated = model.generate_im2svg({"image": pixel_values}, max_length=max_length)
        else:
            generated = model.generate_text2svg(input_data, max_length=max_length)

    # Only the first generated sample is used.
    raw_svg = generated[0]
    svg_text, preview = process_and_rasterize_svg(raw_svg)
    return svg_text, preview
# Two-tab Gradio UI: one tab converts an uploaded image to SVG, the other
# generates an SVG from a text prompt. Both tabs call generate_svg and show
# the SVG source next to a rasterized preview.
# NOTE(review): the original indentation was lost; the code box and button are
# assumed to sit below each tab's input/preview row — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 💫 StarVector SVG Generator")

    with gr.Tab("Image to SVG"):
        gr.Markdown("Upload an image to convert to SVG")
        with gr.Row():
            image_input = gr.Image(type="pil", label="Input Image")
            image_output = gr.Image(label="SVG Preview")
        svg_code = gr.Code(label="Generated SVG Code")
        image_button = gr.Button("Convert to SVG")

    with gr.Tab("Text to SVG"):
        gr.Markdown("Enter text to generate SVG")
        with gr.Row():
            text_input = gr.Textbox(label="Text Prompt")
            text_output = gr.Image(label="SVG Preview")
        text_svg_code = gr.Code(label="Generated SVG Code")
        text_button = gr.Button("Generate SVG")

    # generate_svg returns (svg_code, raster_image), so the code component is
    # listed before the preview image in `outputs`.
    image_button.click(
        fn=lambda x: generate_svg(x, "image"),
        inputs=image_input,
        outputs=[svg_code, image_output],
    )
    text_button.click(
        fn=lambda x: generate_svg(x, "text"),
        inputs=text_input,
        outputs=[text_svg_code, text_output],
    )

demo.launch()