ChevalierJoseph committed on
Commit
2ea46b3
·
verified ·
1 Parent(s): 8a1ae18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -104
app.py CHANGED
@@ -1,88 +1,78 @@
1
- import re
2
- import os
3
- import tempfile
4
- import zipfile
5
- from fontTools.ttLib import TTFont
6
- from svgpathtools import parse_path, Path, Line, CubicBezier, Arc
7
- from svgpathtools.path import Move # Importation correct
8
- import gradio as gr
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
10
  import torch
11
  from threading import Thread
 
 
 
 
 
 
12
 
13
  def load_model():
14
  tokenizer = AutoTokenizer.from_pretrained("ChevalierJoseph/typtop")
15
  model = AutoModelForCausalLM.from_pretrained("ChevalierJoseph/typtop")
16
  return tokenizer, model
17
 
 
 
 
 
 
 
 
 
18
  def extract_glyphs(text):
19
  pattern = r"Glyph\s+([A-Z])\s+([MmZzLlHhVvCcSsQqTtAa0-9,\s\.\-]+?)(?=\s*Glyph\s+[A-Z]|\s*$)"
20
  glyphs = re.findall(pattern, text)
21
  return glyphs
22
 
23
- def build_glyph_outlines(glyphs):
24
- outlines = {}
25
- for letter, path_data in glyphs:
26
- path = parse_path(path_data)
27
- outlines[letter] = path
28
- return outlines
29
-
30
- def generate_otf_directly(glyphs, output_path="output.otf"):
31
- font = TTFont()
32
-
33
- # Ajoutez une table cmap simple pour Unicode
34
- cmap = font["cmap"] = ttLib.tables.c_m_a_p.CmapTable()
35
- cmap.tables = [ttLib.tables.c_m_a_p.CmapSubtable.newSubtable(3, 1, 0)]
36
-
37
- # Ajoutez une table hmtx simple
38
- font["hmtx"] = ttLib.tables.h_m_t_x.hmtx()
39
- font["hmtx"].metrics = {}
40
-
41
- # Ajoutez une table maxp simple
42
- font["maxp"] = ttLib.tables.m_a_x_p.maxp()
43
- font["maxp"].numGlyphs = len(glyphs)
44
-
45
- # Ajoutez une table head simple
46
- font["head"] = ttLib.tables.h_e_a_d.head()
47
- font["head"].unitsPerEm = 1000
48
-
49
- # Ajoutez une table hhea simple
50
- font["hhea"] = ttLib.tables.h_h_e_a.hhea()
51
- font["hhea"].ascent = 800
52
- font["hhea"].descent = -300
53
-
54
- # Ajoutez une table loca simple
55
- font["loca"] = ttLib.tables.loca.loca()
56
-
57
- for letter, path_data in glyphs:
58
- glyph_name = f"glyph_{ord(letter)}"
59
- # Créez un glyphe vide
60
- glyph = ttLib.tables.g_l_y_f.Glyph()
61
- glyph.numberOfContours = 0 # Placeholder
62
- font["glyf"][glyph_name] = glyph
63
-
64
- # Ajoutez des métriques pour le glyphe
65
- font["hmtx"].metrics[glyph_name] = (1000, 0)
66
-
67
- # Ajoutez une entrée cmap pour le glyphe
68
- cmap.tables[0].cmap[ord(letter)] = glyph_name
69
 
70
- font.save(output_path)
71
- return output_path
72
 
73
  def generate_svg_files(glyphs, width=100, height=100):
74
  svg_files = {}
75
  for lettre, path in glyphs:
76
  svg_content = f"""
77
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">
78
- <g transform="translate(0, 0)">
79
- <path d="{path.strip()}" fill="black"/>
80
- </g>
81
  </svg>
82
  """
83
  svg_files[f"{lettre}.svg"] = svg_content
84
  return svg_files
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def create_zip(svg_files):
87
  with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
88
  zip_path = tmp_file.name
@@ -91,21 +81,8 @@ def create_zip(svg_files):
91
  zip_file.writestr(filename, content)
92
  return zip_path
93
 
94
- def generate_glyphs_html(glyphs, cols=5, width=100, height=100):
95
- html_parts = []
96
- for lettre, path in glyphs:
97
- svg_content = f"""
98
- <svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">
99
- <g transform="translate(0, 0)">
100
- <path d="{path.strip()}" fill="black"/>
101
- </g>
102
- </svg>
103
- """
104
- html_parts.append(f"<div style='display: inline-block; margin: 10px; text-align: center;'><h3>{lettre}</h3>{svg_content}</div>")
105
- grid_style = f"display: grid; grid-template-columns: repeat({cols}, 1fr); gap: 20px;"
106
- return f'<div style="{grid_style}">{"".join(html_parts)}</div>'
107
-
108
- def respond(message, system_message, max_tokens, temperature, top_p):
109
  tokenizer, model = load_model()
110
  if torch.cuda.is_available():
111
  model = model.to('cuda')
@@ -113,29 +90,30 @@ def respond(message, system_message, max_tokens, temperature, top_p):
113
 
114
  messages = [{"role": "system", "content": system_message}]
115
  messages.append({"role": "user", "content": message})
116
-
117
  inputs = tokenizer.apply_chat_template(
118
- messages,
119
- tokenize=True,
120
- add_generation_prompt=True,
121
- return_tensors="pt",
122
  ).to(model_device)
123
 
124
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
125
  generation_kwargs = {
126
  "input_ids": inputs,
127
  "streamer": streamer,
128
  "max_new_tokens": max_tokens,
129
  "temperature": float(temperature) if temperature > 0 else None,
130
  "top_p": float(top_p) if top_p < 1.0 else None,
131
- "do_sample": True,
132
  "use_cache": True,
133
  }
134
-
135
  if temperature <= 0.01:
136
  generation_kwargs["do_sample"] = False
137
- generation_kwargs.pop("temperature", None)
138
- generation_kwargs.pop("top_p", None)
 
 
139
 
140
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
141
  thread.start()
@@ -145,6 +123,7 @@ def respond(message, system_message, max_tokens, temperature, top_p):
145
  partial_response += new_text
146
  glyphs = extract_glyphs(partial_response)
147
  yield partial_response, glyphs
 
148
  thread.join()
149
 
150
  def create_demo():
@@ -152,26 +131,39 @@ def create_demo():
152
  gr.Markdown("# TypTopType")
153
  glyphs_state = gr.State([])
154
  message_history = gr.State([])
 
155
  with gr.Row():
156
  with gr.Column(scale=1):
157
- msg = gr.Textbox(label="Input description of the typography")
158
  system_message = gr.Textbox(
159
  value="Based on the following text, give me the svgpath of the glyphs from A to Z.",
160
  visible=False
161
  )
162
- max_tokens = gr.Slider(minimum=1, maximum=9048, value=9048, step=1, visible=False)
163
- temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, visible=False)
164
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, visible=False)
165
- cols = gr.Slider(minimum=1, maximum=10, value=5, step=1, visible=False)
166
- width = gr.Slider(minimum=50, maximum=200, value=100, step=10, visible=False)
167
- height = gr.Slider(minimum=50, maximum=200, value=100, step=10, visible=False)
168
- download_btn = gr.Button("Download SVG files")
169
- download_otf_btn = gr.Button("Download OTF font")
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  with gr.Column(scale=3):
171
- gr.Markdown("## Preview")
172
  svg_preview = gr.HTML(label="SVG Preview")
173
  download_output = gr.File(label="Download ZIP")
174
- download_otf_output = gr.File(label="Download OTF")
175
 
176
  def user(user_message, history):
177
  return "", history + [[user_message, None]]
@@ -184,7 +176,9 @@ def create_demo():
184
  full_response = partial_response
185
  if glyphs:
186
  glyphs_list = glyphs
187
- svg_html = generate_glyphs_html(glyphs_list, cols=cols, width=width, height=height) if glyphs_list else "No glyphs found."
 
 
188
  yield svg_html, glyphs_list
189
 
190
  def download_svg(glyphs, width, height):
@@ -194,24 +188,36 @@ def create_demo():
194
  zip_path = create_zip(svg_files)
195
  return zip_path
196
 
197
- def download_otf(glyphs):
198
  if not glyphs:
199
  return None
200
- return generate_otf_directly(glyphs)
201
-
202
- msg.submit(user, [msg, message_history], [msg, message_history], queue=False).then(
203
- bot, [message_history, system_message, max_tokens, temperature, top_p, cols, width, height], [svg_preview, glyphs_state]
 
 
 
 
 
 
204
  )
205
 
206
  download_btn.click(
207
- download_svg, inputs=[glyphs_state, width, height], outputs=download_output
 
 
208
  )
209
- download_otf_btn.click(
210
- download_otf, inputs=[glyphs_state], outputs=download_otf_output
 
 
 
211
  )
212
 
213
- return demo
214
 
215
  demo = create_demo()
 
216
  if __name__ == "__main__":
217
  demo.launch()
 
1
+ import spaces
 
 
 
 
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
3
+ import gradio as gr
4
  import torch
5
  from threading import Thread
6
+ import re
7
+ import io
8
+ import zipfile
9
+ import tempfile
10
+ import os
11
+ import subprocess
12
 
13
def load_model():
    """Load the fine-tuned tokenizer and causal-LM from the Hub cache.

    Downloads on first use; subsequent calls hit the local HF cache.
    Returns a ``(tokenizer, model)`` pair.
    """
    repo_id = "ChevalierJoseph/typtop"
    tok = AutoTokenizer.from_pretrained(repo_id)
    lm = AutoModelForCausalLM.from_pretrained(repo_id)
    return tok, lm
17
 
18
def generate_svg(path_data, width=50, height=50):
    """Wrap a single SVG path string in a minimal standalone <svg> document.

    ``path_data`` is an SVG path "d" attribute value; ``width``/``height``
    set both the rendered size and the viewBox extent.
    """
    return f"""
    <svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" xmlns="http://www.w3.org/2000/svg">
        <path d="{path_data}" fill="black"/>
    </svg>
    """
25
+
26
def extract_glyphs(text):
    """Pull ``(letter, svg_path)`` pairs out of the model's streamed text.

    Matches segments of the form ``Glyph A M0 0 L...``; the lazy path group
    stops at the next ``Glyph <letter>`` marker or at end of string.
    """
    glyph_re = re.compile(
        r"Glyph\s+([A-Z])\s+([MmZzLlHhVvCcSsQqTtAa0-9,\s\.\-]+?)(?=\s*Glyph\s+[A-Z]|\s*$)"
    )
    return glyph_re.findall(text)
30
 
31
def generate_glyphs_html(glyphs, cols=5, width=100, height=100):
    """Render the extracted glyphs as labelled inline SVGs in a CSS grid.

    ``glyphs`` is a list of ``(letter, path_data)`` pairs as produced by
    ``extract_glyphs``; returns a single HTML string for a gr.HTML preview.
    """
    cells = []
    for letter, path_data in glyphs:
        # viewBox flips/offsets into typical font coordinates (y grows up).
        svg = f"""
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">
            <g transform="translate(0, 0)">
                <path d="{path_data.strip()}" fill="black"/>
            </g>
        </svg>
        """
        cells.append(
            f"<div style='display: inline-block; margin: 10px; text-align: center;'><h3>{letter}</h3>{svg}</div>"
        )
    grid_style = f"display: grid; grid-template-columns: repeat({cols}, 1fr); gap: 20px;"
    return f'<div style="{grid_style}">{"".join(cells)}</div>'
45
 
46
def generate_svg_files(glyphs, width=100, height=100):
    """Build one standalone SVG document per glyph.

    Returns a dict mapping ``"<letter>.svg"`` filenames to SVG file
    contents, ready to be packed into a ZIP.
    """
    documents = {}
    for letter, path_data in glyphs:
        documents[f"{letter}.svg"] = f"""
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">
            <g transform="translate(0, 0)">
                <path d="{path_data.strip()}" fill="black"/>
            </g>
        </svg>
        """
    return documents
58
 
59
def create_font_svg(glyphs, output_path="font.svg"):
    """Write an SVG font file containing one <glyph> per extracted letter.

    ``glyphs`` is a list of ``(letter, path_data)`` pairs. The glyphs are
    wrapped in the ``<defs><font><font-face/>...</font></defs>`` skeleton the
    SVG font format requires — bare <glyph> elements directly under <svg>
    (as the previous version wrote) do not form a valid font, and converters
    such as svg2ttf reject them. Returns ``output_path``.
    """
    with open(output_path, 'w') as font_file:
        font_file.write('<svg xmlns="http://www.w3.org/2000/svg">\n')
        font_file.write('<defs>\n')
        # horiz-adv-x / units-per-em chosen to match the 1000-unit em square
        # the preview viewBox assumes — TODO confirm against model output.
        font_file.write('<font id="typtop" horiz-adv-x="1000">\n')
        font_file.write(
            '<font-face font-family="typtop" units-per-em="1000" ascent="800" descent="-200"/>\n'
        )
        font_file.write('<missing-glyph horiz-adv-x="1000"/>\n')
        for lettre, path in glyphs:
            unicode_val = ord(lettre)
            font_file.write(f'  <glyph unicode="&#{unicode_val};" d="{path.strip()}"/>\n')
        font_file.write('</font>\n')
        font_file.write('</defs>\n')
        font_file.write('</svg>\n')
    return output_path
67
+
68
def convert_svg_to_otf(svg_path, otf_path="font.otf"):
    """Convert an SVG font file to a binary font via the external `svg2ttf` CLI.

    Returns ``otf_path`` on success, or ``None`` on any failure.
    NOTE(review): svg2ttf emits TTF-flavoured data; the .otf name is kept for
    interface compatibility with existing callers.
    """
    try:
        subprocess.run(['svg2ttf', svg_path, otf_path], check=True)
        return otf_path
    except subprocess.CalledProcessError as e:
        # Tool ran but reported failure (bad input, etc.).
        print(f"Error in converting SVG to OTF: {e}")
        return None
    except OSError as e:
        # Fix: svg2ttf not installed / not executable raised FileNotFoundError
        # and crashed the UI callback; degrade to the same None error path.
        print(f"svg2ttf is unavailable: {e}")
        return None
75
+
76
  def create_zip(svg_files):
77
  with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
78
  zip_path = tmp_file.name
 
81
  zip_file.writestr(filename, content)
82
  return zip_path
83
 
84
+ @spaces.GPU(duration=180)
85
+ def respond(message: str, system_message: str, max_tokens: int, temperature: float, top_p: float):
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  tokenizer, model = load_model()
87
  if torch.cuda.is_available():
88
  model = model.to('cuda')
 
90
 
91
  messages = [{"role": "system", "content": system_message}]
92
  messages.append({"role": "user", "content": message})
 
93
  inputs = tokenizer.apply_chat_template(
94
+ messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
 
 
 
95
  ).to(model_device)
96
 
97
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
98
+
99
+ do_sample_effective = True
100
+ if temperature == 0.0:
101
+ pass
102
  generation_kwargs = {
103
  "input_ids": inputs,
104
  "streamer": streamer,
105
  "max_new_tokens": max_tokens,
106
  "temperature": float(temperature) if temperature > 0 else None,
107
  "top_p": float(top_p) if top_p < 1.0 else None,
108
+ "do_sample": do_sample_effective,
109
  "use_cache": True,
110
  }
 
111
  if temperature <= 0.01:
112
  generation_kwargs["do_sample"] = False
113
+ if "temperature" in generation_kwargs:
114
+ del generation_kwargs["temperature"]
115
+ if "top_p" in generation_kwargs:
116
+ del generation_kwargs["top_p"]
117
 
118
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
119
  thread.start()
 
123
  partial_response += new_text
124
  glyphs = extract_glyphs(partial_response)
125
  yield partial_response, glyphs
126
+
127
  thread.join()
128
 
129
  def create_demo():
 
131
  gr.Markdown("# TypTopType")
132
  glyphs_state = gr.State([])
133
  message_history = gr.State([])
134
+
135
  with gr.Row():
136
  with gr.Column(scale=1):
137
+ msg = gr.Textbox(label="input box, type here")
138
  system_message = gr.Textbox(
139
  value="Based on the following text, give me the svgpath of the glyphs from A to Z.",
140
  visible=False
141
  )
142
+ max_tokens = gr.Slider(
143
+ minimum=1, maximum=9048, value=9048, step=1, visible=False
144
+ )
145
+ temperature = gr.Slider(
146
+ minimum=0.1, maximum=4.0, value=0.7, step=0.1, visible=False
147
+ )
148
+ top_p = gr.Slider(
149
+ minimum=0.1, maximum=1.0, value=0.95, step=0.05, visible=False
150
+ )
151
+ cols = gr.Slider(
152
+ minimum=1, maximum=10, value=5, step=1, visible=False
153
+ )
154
+ width = gr.Slider(
155
+ minimum=50, maximum=200, value=100, step=10, visible=False
156
+ )
157
+ height = gr.Slider(
158
+ minimum=50, maximum=200, value=100, step=10, visible=False
159
+ )
160
+ download_btn = gr.Button("Download svg file")
161
+ otf_export_btn = gr.Button("Export OTF")
162
+
163
  with gr.Column(scale=3):
164
+ gr.Markdown("## preview")
165
  svg_preview = gr.HTML(label="SVG Preview")
166
  download_output = gr.File(label="Download ZIP")
 
167
 
168
  def user(user_message, history):
169
  return "", history + [[user_message, None]]
 
176
  full_response = partial_response
177
  if glyphs:
178
  glyphs_list = glyphs
179
+ svg_html = generate_glyphs_html(glyphs_list, cols=cols, width=width, height=height)
180
+ else:
181
+ svg_html = "No glyphs found."
182
  yield svg_html, glyphs_list
183
 
184
  def download_svg(glyphs, width, height):
 
188
  zip_path = create_zip(svg_files)
189
  return zip_path
190
 
191
+ def export_to_otf(glyphs):
192
  if not glyphs:
193
  return None
194
+ svg_path = create_font_svg(glyphs)
195
+ otf_path = convert_svg_to_otf(svg_path)
196
+ return otf_path
197
+
198
+ msg.submit(
199
+ user, [msg, message_history], [msg, message_history], queue=False
200
+ ).then(
201
+ bot,
202
+ [message_history, system_message, max_tokens, temperature, top_p, cols, width, height],
203
+ [svg_preview, glyphs_state]
204
  )
205
 
206
  download_btn.click(
207
+ download_svg,
208
+ inputs=[glyphs_state, width, height],
209
+ outputs=download_output
210
  )
211
+
212
+ otf_export_btn.click(
213
+ export_to_otf,
214
+ inputs=[glyphs_state],
215
+ outputs=download_output
216
  )
217
 
218
+ return demo
219
 
220
  demo = create_demo()
221
+
222
  if __name__ == "__main__":
223
  demo.launch()