vinoku89 commited on
Commit
b3a9f3b
·
verified ·
1 Parent(s): 4770c27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -19
app.py CHANGED
@@ -1,45 +1,163 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
  import spaces
 
 
5
 
6
- # Load model
7
- tokenizer = AutoTokenizer.from_pretrained("vinoku89/qwen3-4B-svg-code-gen")
8
- model = AutoModelForCausalLM.from_pretrained("vinoku89/qwen3-4B-svg-code-gen")
9
 
10
- # Move model to GPU if available
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- model.to(device)
13
 
14
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def generate_svg(prompt):
16
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
17
-
18
- outputs = model.generate(
19
- **inputs,
20
- max_length=200,
21
- temperature=0.7,
22
- do_sample=True,
23
- pad_token_id=tokenizer.eos_token_id
 
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
- svg_code = generated_text[len(prompt):].strip()
28
- svg_display = f"<svg xmlns='http://www.w3.org/2000/svg' width='200' height='200'>{svg_code}</svg>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  return svg_code, svg_display
31
 
 
 
 
 
 
 
 
32
  gradio_app = gr.Interface(
33
  fn=generate_svg,
34
- inputs=gr.Textbox(lines=2, placeholder="Describe the SVG you want..."),
 
 
 
35
  outputs=[
36
  gr.Code(label="Generated SVG Code", language="html"),
37
  gr.HTML(label="SVG Preview")
38
  ],
39
  title="SVG Code Generator",
40
- description="Generate SVG code from natural language using a fine-tuned LLM."
 
41
  )
42
 
43
  if __name__ == "__main__":
44
  gradio_app.launch()
 
45
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
+ import gc
5
  import spaces
6
+ import xml.etree.ElementTree as ET
7
+ import re
8
 
9
+ # Clear GPU memory
10
+ torch.cuda.empty_cache()
11
+ gc.collect()
12
 
13
+ # Alpaca prompt template
14
+ alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
15
 
16
+ ### Instruction:
17
+ {}
18
+
19
+ ### Input:
20
+ {}
21
+
22
+ ### Response:
23
+ {}"""
24
+
25
+ # Load model with memory optimizations
26
+ model_path = "vinoku89/qwen3-4B-svg-code-gen"
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
29
+
30
+ model = AutoModelForCausalLM.from_pretrained(
31
+ model_path,
32
+ torch_dtype=torch.float16,
33
+ device_map="auto",
34
+ low_cpu_mem_usage=True,
35
+ trust_remote_code=True # Add this if needed for custom models
36
+ )
37
+
38
+ def validate_svg(svg_content):
39
+ """
40
+ Validate if SVG content is properly formatted and renderable
41
+ """
42
+ try:
43
+ # Clean up the SVG content
44
+ svg_content = svg_content.strip()
45
+
46
+ # If it doesn't start with <svg, try to extract SVG content
47
+ if not svg_content.startswith('<svg'):
48
+ # Look for SVG tags in the content
49
+ svg_match = re.search(r'<svg[^>]*>.*?</svg>', svg_content, re.DOTALL | re.IGNORECASE)
50
+ if svg_match:
51
+ svg_content = svg_match.group(0)
52
+ else:
53
+ # If no complete SVG found, wrap content in SVG tags
54
+ if any(tag in svg_content.lower() for tag in ['<circle', '<rect', '<path', '<line', '<polygon', '<ellipse', '<text']):
55
+ svg_content = f'<svg xmlns="http://www.w3.org/2000/svg" width="250" height="250">{svg_content}</svg>'
56
+ else:
57
+ raise ValueError("No valid SVG elements found")
58
+
59
+ # Parse XML to validate structure
60
+ ET.fromstring(svg_content)
61
+
62
+ return True, svg_content
63
+
64
+ except ET.ParseError as e:
65
+ return False, f"XML Parse Error: {str(e)}"
66
+ except Exception as e:
67
+ return False, f"Validation Error: {str(e)}"
68
+
69
+ @spaces.GPU(duration=60) # Add duration limit
70
  def generate_svg(prompt):
71
+ # Clear cache before generation
72
+ torch.cuda.empty_cache()
73
+
74
+ # Format the prompt using Alpaca template
75
+ instruction = "Generate SVG code based on the given description."
76
+ formatted_prompt = alpaca_prompt.format(
77
+ instruction,
78
+ prompt,
79
+ "" # Empty response - model will fill this
80
  )
81
 
82
+ inputs = tokenizer(formatted_prompt, return_tensors="pt")
83
+
84
+ # Move inputs to the same device as model
85
+ if hasattr(model, 'device'):
86
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
87
+
88
+ with torch.no_grad(): # Disable gradient computation to save memory
89
+ outputs = model.generate(
90
+ **inputs,
91
+ max_length=1024,
92
+ temperature=0.7,
93
+ do_sample=True,
94
+ pad_token_id=tokenizer.eos_token_id,
95
+ max_new_tokens=512 # Limit new tokens instead of total length
96
+ )
97
+
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
+
100
+ # Extract only the response part (after "### Response:")
101
+ response_start = generated_text.find("### Response:")
102
+ if response_start != -1:
103
+ svg_code = generated_text[response_start + len("### Response:"):].strip()
104
+ else:
105
+ # Fallback: remove the original formatted prompt
106
+ svg_code = generated_text[len(formatted_prompt):].strip()
107
+
108
+ # Validate SVG
109
+ is_valid, result = validate_svg(svg_code)
110
+
111
+ if is_valid:
112
+ # SVG is valid
113
+ validated_svg = result
114
+ # Ensure the SVG has proper dimensions for display (keep moderate size)
115
+ if 'width=' not in validated_svg or 'height=' not in validated_svg:
116
+ validated_svg = validated_svg.replace('<svg', '<svg width="250" height="250"', 1)
117
+ svg_display = validated_svg
118
+ else:
119
+ # SVG is invalid, show error message
120
+ svg_display = f"""
121
+ <div style="width: 250px; height: 200px; border: 2px dashed #ff6b6b;
122
+ display: flex; align-items: center; justify-content: center;
123
+ background-color: #fff5f5; border-radius: 8px; padding: 15px;
124
+ text-align: center; color: #e03131; font-family: Arial, sans-serif;">
125
+ <div>
126
+ <h4 style="margin: 0 0 8px 0; color: #e03131;">🚫 Preview Not Available</h4>
127
+ <p style="margin: 0; font-size: 12px;">Generated SVG contains errors:<br>
128
+ <em style="font-size: 11px;">{result}</em></p>
129
+ </div>
130
+ </div>
131
+ """
132
+
133
+ # Clear cache after generation
134
+ torch.cuda.empty_cache()
135
 
136
  return svg_code, svg_display
137
 
138
+ # Minimal CSS for slightly larger HTML preview only
139
+ custom_css = """
140
+ div[data-testid="HTML"] {
141
+ min-height: 320px !important;
142
+ }
143
+ """
144
+
145
  gradio_app = gr.Interface(
146
  fn=generate_svg,
147
+ inputs=gr.Textbox(
148
+ lines=2,
149
+ placeholder="Describe the SVG you want (e.g., 'a red circle with blue border')..."
150
+ ),
151
  outputs=[
152
  gr.Code(label="Generated SVG Code", language="html"),
153
  gr.HTML(label="SVG Preview")
154
  ],
155
  title="SVG Code Generator",
156
+ description="Generate SVG code from natural language using a fine-tuned LLM.",
157
+ css=custom_css
158
  )
159
 
160
  if __name__ == "__main__":
161
  gradio_app.launch()
162
+
163