vinoku89 commited on
Commit
b8e846c
·
verified ·
1 Parent(s): efababe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -112
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
  import torch
4
  import gc
5
  import spaces
@@ -11,173 +12,260 @@ import os
11
  torch.cuda.empty_cache()
12
  gc.collect()
13
 
14
- # Alpaca prompt template
15
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
 
16
 
17
- ### Instruction:
18
- {}
 
 
19
 
20
- ### Input:
21
- {}
 
22
 
23
- ### Response:
24
- {}"""
25
-
26
- # Load model with memory optimizations
27
- model_path = "vinoku89/qwen3-4B-svg-code-gen"
28
-
29
- tokenizer = AutoTokenizer.from_pretrained(model_path)
30
-
31
- model = AutoModelForCausalLM.from_pretrained(
32
- model_path,
33
  torch_dtype=torch.float16,
34
  device_map="auto",
35
  low_cpu_mem_usage=True,
36
- trust_remote_code=True # Add this if needed for custom models
 
 
 
 
 
 
 
37
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- def validate_svg(svg_content):
 
 
 
 
 
40
  """
41
- Validate if SVG content is properly formatted and renderable
42
  """
43
  try:
44
- # Clean up the SVG content
45
  svg_content = svg_content.strip()
46
-
47
  # If it doesn't start with <svg, try to extract SVG content
48
  if not svg_content.startswith('<svg'):
49
- # Look for SVG tags in the content
50
- svg_match = re.search(r'<svg[^>]*>.*?</svg>', svg_content, re.DOTALL | re.IGNORECASE)
51
- if svg_match:
52
- svg_content = svg_match.group(0)
53
- else:
54
- # If no complete SVG found, wrap content in SVG tags
55
- if any(tag in svg_content.lower() for tag in ['<circle', '<rect', '<path', '<line', '<polygon', '<ellipse', '<text']):
56
- svg_content = f'<svg xmlns="http://www.w3.org/2000/svg" width="250" height="250">{svg_content}</svg>'
57
- else:
58
- raise ValueError("No valid SVG elements found")
59
-
60
  # Parse XML to validate structure
61
  ET.fromstring(svg_content)
62
-
63
  return True, svg_content
64
-
65
  except ET.ParseError as e:
66
  return False, f"XML Parse Error: {str(e)}"
67
  except Exception as e:
68
  return False, f"Validation Error: {str(e)}"
69
 
70
- @spaces.GPU(duration=60) # Add duration limit
71
- def generate_svg(prompt):
 
 
 
 
72
  # Clear cache before generation
73
  torch.cuda.empty_cache()
74
-
75
- # Format the prompt using Alpaca template
76
- instruction = "Generate SVG code based on the given description."
77
- formatted_prompt = alpaca_prompt.format(
78
- instruction,
79
- prompt,
80
- "" # Empty response - model will fill this
81
- )
82
-
83
- inputs = tokenizer(formatted_prompt, return_tensors="pt")
84
-
85
- # Move inputs to the same device as model
86
  if hasattr(model, 'device'):
87
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
88
-
89
- with torch.no_grad(): # Disable gradient computation to save memory
 
 
 
90
  outputs = model.generate(
91
  **inputs,
92
- max_length=1024,
93
- temperature=0.7,
94
  do_sample=True,
 
95
  pad_token_id=tokenizer.eos_token_id,
96
- max_new_tokens=512 # Limit new tokens instead of total length
97
  )
98
-
 
99
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
-
101
- # Extract only the response part (after "### Response:")
102
- response_start = generated_text.find("### Response:")
103
- if response_start != -1:
104
- svg_code = generated_text[response_start + len("### Response:"):].strip()
 
105
  else:
106
- # Fallback: remove the original formatted prompt
107
- svg_code = generated_text[len(formatted_prompt):].strip()
108
-
 
 
 
109
  # Validate SVG
110
  is_valid, result = validate_svg(svg_code)
111
-
112
  if is_valid:
113
- # SVG is valid
114
  validated_svg = result
115
- # Ensure the SVG has proper dimensions for display (keep moderate size)
116
- if 'width=' not in validated_svg or 'height=' not in validated_svg:
117
- validated_svg = validated_svg.replace('<svg', '<svg width="250" height="250"', 1)
118
  svg_display = validated_svg
119
  else:
120
- # SVG is invalid, show error message
121
  svg_display = f"""
122
- <div style="width: 250px; height: 200px; border: 2px dashed #ff6b6b;
123
- display: flex; align-items: center; justify-content: center;
124
- background-color: #fff5f5; border-radius: 8px; padding: 15px;
125
  text-align: center; color: #e03131; font-family: Arial, sans-serif;">
126
  <div>
127
- <h4 style="margin: 0 0 8px 0; color: #e03131;">🚫 Preview Not Available</h4>
128
  <p style="margin: 0; font-size: 12px;">Generated SVG contains errors:<br>
129
  <em style="font-size: 11px;">{result}</em></p>
130
  </div>
131
  </div>
132
  """
133
-
134
  # Clear cache after generation
135
  torch.cuda.empty_cache()
136
-
137
  return svg_code, svg_display
138
 
139
- # Authentication function using HF Space secrets
140
- def authenticate(username, password):
141
- """
142
- Authentication function for Gradio using HF Space secrets
143
- Returns True if credentials are valid, False otherwise
144
- """
145
- # Get credentials from HF Space secrets
146
- valid_username = os.getenv("user") # This matches your secret name "user"
147
- valid_password = os.getenv("password") # This matches your secret name "password"
148
-
149
- # Fallback credentials if secrets are not available (for local testing)
150
- if valid_username is None:
151
- valid_username = "user"
152
- print("Warning: 'user' secret not found, using fallback")
153
-
154
- if valid_password is None:
155
- valid_password = "password"
156
- print("Warning: 'password' secret not found, using fallback")
157
-
158
- return username == valid_username and password == valid_password
159
-
160
- # Minimal CSS for slightly larger HTML preview only
161
  custom_css = """
162
  div[data-testid="HTML"] {
163
- min-height: 320px !important;
 
 
 
164
  }
165
  """
166
 
167
- gradio_app = gr.Interface(
168
- fn=generate_svg,
169
- inputs=gr.Textbox(
170
- lines=2,
171
- placeholder="Describe the SVG you want (e.g., 'a red circle with blue border')..."
172
- ),
173
- outputs=[
174
- gr.Code(label="Generated SVG Code", language="html"),
175
- gr.HTML(label="SVG Preview")
176
- ],
177
- title="SVG Code Generator",
178
- description="Generate SVG code from natural language using a fine-tuned LLM.",
179
- css=custom_css
180
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  if __name__ == "__main__":
183
- gradio_app.launch(auth=(os.getenv("user"), os.getenv("password")), share=True, ssr_mode=False)
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from peft import PeftModel
4
  import torch
5
  import gc
6
  import spaces
 
12
  torch.cuda.empty_cache()
13
  gc.collect()
14
 
15
+ # Model configuration
16
+ BASE_MODEL = "Qwen/Qwen3-1.7B"
17
+ ADAPTER_REPO = "vinoku89/svg-lora-sft-1.7b"
18
+ ADAPTER_REVISION = "best" # or specific tag like "qwen1.7b-sft-e2-14.2k-loss0.34-20260124-001"
19
 
20
+ # Prompt configuration (from svg-sft-1.7b.yaml)
21
+ SYSTEM_PROMPT = "You are an expert SVG code generator. Generate precise, well-formed SVG code that accurately matches the given description."
22
+ USER_PREFIX = "Generate SVG code for the following description:\n\n"
23
+ USER_SUFFIX = "\n\nProvide only the SVG code, starting with <svg and ending with </svg>."
24
 
25
+ # Load model with LoRA adapter
26
+ print(f"Loading base model: {BASE_MODEL}")
27
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
28
 
29
+ base_model = AutoModelForCausalLM.from_pretrained(
30
+ BASE_MODEL,
 
 
 
 
 
 
 
 
31
  torch_dtype=torch.float16,
32
  device_map="auto",
33
  low_cpu_mem_usage=True,
34
+ trust_remote_code=True
35
+ )
36
+
37
+ print(f"Loading LoRA adapter: {ADAPTER_REPO} (revision: {ADAPTER_REVISION})")
38
+ model = PeftModel.from_pretrained(
39
+ base_model,
40
+ ADAPTER_REPO,
41
+ revision=ADAPTER_REVISION,
42
  )
43
+ model.eval()
44
+ print("Model loaded successfully!")
45
+
46
+
47
+ def format_prompt(description: str) -> str:
48
+ """
49
+ Format the prompt using Qwen3 chat template.
50
+ """
51
+ user_content = f"{USER_PREFIX}{description}{USER_SUFFIX}"
52
+
53
+ messages = [
54
+ {"role": "system", "content": SYSTEM_PROMPT},
55
+ {"role": "user", "content": user_content}
56
+ ]
57
+
58
+ # Use tokenizer's chat template
59
+ prompt = tokenizer.apply_chat_template(
60
+ messages,
61
+ tokenize=False,
62
+ add_generation_prompt=True
63
+ )
64
+
65
+ return prompt
66
+
67
+
68
+ def extract_svg(text: str) -> str:
69
+ """
70
+ Extract SVG code from generated text.
71
+ """
72
+ # Try to find complete SVG tags
73
+ svg_match = re.search(r'<svg[^>]*>.*?</svg>', text, re.DOTALL | re.IGNORECASE)
74
+ if svg_match:
75
+ return svg_match.group(0)
76
+
77
+ # If no complete SVG, try to find partial SVG and clean up
78
+ if '<svg' in text.lower():
79
+ # Find the start of SVG
80
+ start_idx = text.lower().find('<svg')
81
+ svg_content = text[start_idx:]
82
+
83
+ # If it doesn't end with </svg>, try to add it
84
+ if '</svg>' not in svg_content.lower():
85
+ # Find a good stopping point
86
+ svg_content = svg_content.split('<|')[0] # Stop at chat tokens
87
+ svg_content = svg_content.split('\n\n')[0] # Stop at double newline
88
+ if not svg_content.strip().endswith('</svg>'):
89
+ svg_content += '</svg>'
90
 
91
+ return svg_content
92
+
93
+ return text
94
+
95
+
96
+ def validate_svg(svg_content: str):
97
  """
98
+ Validate if SVG content is properly formatted and renderable.
99
  """
100
  try:
 
101
  svg_content = svg_content.strip()
102
+
103
  # If it doesn't start with <svg, try to extract SVG content
104
  if not svg_content.startswith('<svg'):
105
+ svg_content = extract_svg(svg_content)
106
+
107
+ # Ensure xmlns is present
108
+ if 'xmlns=' not in svg_content:
109
+ svg_content = svg_content.replace('<svg', '<svg xmlns="http://www.w3.org/2000/svg"', 1)
110
+
 
 
 
 
 
111
  # Parse XML to validate structure
112
  ET.fromstring(svg_content)
113
+
114
  return True, svg_content
115
+
116
  except ET.ParseError as e:
117
  return False, f"XML Parse Error: {str(e)}"
118
  except Exception as e:
119
  return False, f"Validation Error: {str(e)}"
120
 
121
+
122
+ @spaces.GPU(duration=120)
123
+ def generate_svg(description: str, temperature: float = 0.7, max_tokens: int = 2048):
124
+ """
125
+ Generate SVG code from a text description.
126
+ """
127
  # Clear cache before generation
128
  torch.cuda.empty_cache()
129
+
130
+ # Format the prompt
131
+ prompt = format_prompt(description)
132
+
133
+ # Tokenize
134
+ inputs = tokenizer(prompt, return_tensors="pt")
135
+
136
+ # Move inputs to model device
 
 
 
 
137
  if hasattr(model, 'device'):
138
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
139
+ elif torch.cuda.is_available():
140
+ inputs = {k: v.cuda() for k, v in inputs.items()}
141
+
142
+ # Generate
143
+ with torch.no_grad():
144
  outputs = model.generate(
145
  **inputs,
146
+ max_new_tokens=max_tokens,
147
+ temperature=temperature,
148
  do_sample=True,
149
+ top_p=0.95,
150
  pad_token_id=tokenizer.eos_token_id,
151
+ eos_token_id=tokenizer.eos_token_id,
152
  )
153
+
154
+ # Decode
155
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
156
+
157
+ # Extract SVG from the response
158
+ # The response should be after the assistant marker
159
+ if "<|im_start|>assistant" in generated_text:
160
+ svg_code = generated_text.split("<|im_start|>assistant")[-1]
161
+ svg_code = svg_code.replace("<|im_end|>", "").strip()
162
  else:
163
+ # Fallback: extract everything after the prompt
164
+ svg_code = generated_text[len(prompt):].strip()
165
+
166
+ # Extract and clean SVG
167
+ svg_code = extract_svg(svg_code)
168
+
169
  # Validate SVG
170
  is_valid, result = validate_svg(svg_code)
171
+
172
  if is_valid:
 
173
  validated_svg = result
174
+ # Ensure reasonable display size
175
+ if 'width=' not in validated_svg.lower():
176
+ validated_svg = validated_svg.replace('<svg', '<svg width="400" height="400"', 1)
177
  svg_display = validated_svg
178
  else:
 
179
  svg_display = f"""
180
+ <div style="width: 400px; height: 300px; border: 2px dashed #ff6b6b;
181
+ display: flex; align-items: center; justify-content: center;
182
+ background-color: #fff5f5; border-radius: 8px; padding: 15px;
183
  text-align: center; color: #e03131; font-family: Arial, sans-serif;">
184
  <div>
185
+ <h4 style="margin: 0 0 8px 0; color: #e03131;">Preview Not Available</h4>
186
  <p style="margin: 0; font-size: 12px;">Generated SVG contains errors:<br>
187
  <em style="font-size: 11px;">{result}</em></p>
188
  </div>
189
  </div>
190
  """
191
+
192
  # Clear cache after generation
193
  torch.cuda.empty_cache()
194
+
195
  return svg_code, svg_display
196
 
197
+
198
+ # Custom CSS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  custom_css = """
200
  div[data-testid="HTML"] {
201
+ min-height: 420px !important;
202
+ }
203
+ .gradio-container {
204
+ max-width: 900px !important;
205
  }
206
  """
207
 
208
+ # Create Gradio interface
209
+ with gr.Blocks(css=custom_css, title="SVG Code Generator") as gradio_app:
210
+ gr.Markdown("""
211
+ # SVG Code Generator
212
+
213
+ Generate SVG code from natural language descriptions using a fine-tuned Qwen3-1.7B model.
214
+
215
+ **Model:** `Qwen/Qwen3-1.7B` + LoRA adapter (`vinoku89/svg-lora-sft-1.7b`)
216
+ """)
217
+
218
+ with gr.Row():
219
+ with gr.Column(scale=1):
220
+ description_input = gr.Textbox(
221
+ label="Description",
222
+ placeholder="Describe the SVG you want (e.g., 'a red circle with blue border on white background')...",
223
+ lines=3
224
+ )
225
+
226
+ with gr.Row():
227
+ temperature = gr.Slider(
228
+ minimum=0.1, maximum=1.5, value=0.7, step=0.1,
229
+ label="Temperature"
230
+ )
231
+ max_tokens = gr.Slider(
232
+ minimum=256, maximum=4096, value=2048, step=256,
233
+ label="Max Tokens"
234
+ )
235
+
236
+ generate_btn = gr.Button("Generate SVG", variant="primary")
237
+
238
+ with gr.Column(scale=1):
239
+ svg_preview = gr.HTML(label="SVG Preview")
240
+
241
+ svg_code_output = gr.Code(label="Generated SVG Code", language="html", lines=15)
242
+
243
+ # Examples
244
+ gr.Examples(
245
+ examples=[
246
+ ["a simple red circle centered on a white background"],
247
+ ["a blue rectangle with rounded corners and a green border"],
248
+ ["a yellow star with 5 points"],
249
+ ["a gradient sunset with orange and purple colors"],
250
+ ["a simple house with a triangular roof and square windows"],
251
+ ],
252
+ inputs=description_input
253
+ )
254
+
255
+ # Connect the generate button
256
+ generate_btn.click(
257
+ fn=generate_svg,
258
+ inputs=[description_input, temperature, max_tokens],
259
+ outputs=[svg_code_output, svg_preview]
260
+ )
261
+
262
 
263
  if __name__ == "__main__":
264
+ # Authentication using HF Space secrets
265
+ auth = None
266
+ user = os.getenv("user")
267
+ password = os.getenv("password")
268
+ if user and password:
269
+ auth = (user, password)
270
+
271
+ gradio_app.launch(auth=auth, share=True, ssr_mode=False)