# app.py — Gradio Space by ChevalierJoseph
# Last commit: "Update app.py" (d4f4f37, verified)
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr
import torch
from threading import Thread
import re
import io
import zipfile
import tempfile
import os
from fontTools.ttLib import TTFont
from fontTools.fontBuilder import FontBuilder
from fontTools.pens.t2CharStringPen import T2CharStringPen
from fontTools.cffLib import PrivateDict
import traceback
def load_model():
    """Fetch the TypTop tokenizer and causal language model from the Hub.

    Returns:
        A ``(tokenizer, model)`` tuple for the ``ChevalierJoseph/typtop4``
        checkpoint.

    NOTE(review): ``respond`` calls this on every request, so the weights are
    re-instantiated each time; consider caching if load latency matters.
    """
    checkpoint = "ChevalierJoseph/typtop4"
    # Tuple elements evaluate left to right: tokenizer first, then the model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )
def generate_svg(path_data, width=50, height=50):
    """Wrap one SVG path in a standalone ``<svg>`` document.

    The viewBox is ``0 0 width height``, so path coordinates map 1:1 onto
    the rendered pixel size.

    Args:
        path_data: Contents of the ``d`` attribute, inserted verbatim.
        width: Pixel width; also the viewBox width.
        height: Pixel height; also the viewBox height.

    Returns:
        The SVG markup, with a leading and a trailing newline.
    """
    lines = [
        "",
        f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" xmlns="http://www.w3.org/2000/svg">',
        f'<path d="{path_data}" fill="black"/>',
        "</svg>",
        "",
    ]
    return "\n".join(lines)
# "Glyph <LETTER> <path data>" records; a record ends where the next one
# starts or at end of text (trailing whitespace allowed).
_GLYPH_RECORD_RE = re.compile(
    r"Glyph\s+([A-Z])\s+([MmZzLlHhVvCcSsQqTtAa0-9,\s\.\-]+?)(?=\s*Glyph\s+[A-Z]|\s*$)"
)


def extract_glyphs(text):
    """Pull ``(letter, svg_path)`` pairs out of the model's text output.

    Only upper-case ``A``-``Z`` labels are recognised, and path data may only
    contain SVG command letters, digits, commas, whitespace, dots and minus
    signs. The last record is captured only if nothing but whitespace
    follows it.

    Args:
        text: Raw (possibly partial) model response.

    Returns:
        A list of ``(letter, path)`` string tuples in order of appearance.
    """
    return [match.groups() for match in _GLYPH_RECORD_RE.finditer(text)]
def generate_glyphs_html(glyphs, cols=5, width=100, height=100):
    """Render glyph outlines as an HTML grid of inline SVG previews.

    Each glyph is drawn in a fixed ``-100 -800 900 900`` viewBox and captioned
    with its letter. (Assumption carried over from the original: model paths
    live roughly in that coordinate range.)

    Args:
        glyphs: Iterable of ``(letter, svg_path_data)`` pairs, e.g. the output
            of ``extract_glyphs``.
        cols: Number of grid columns. Coerced to an int >= 1 because CSS
            ``repeat()`` only accepts a positive integer count.
        width: Rendered width of each SVG preview.
        height: Rendered height of each SVG preview.

    Returns:
        An HTML string: a CSS-grid ``<div>`` containing one cell per glyph.
    """
    from html import escape

    html_parts = []
    for lettre, path in glyphs:
        # Letter and path come from model output that is rendered as raw HTML,
        # so escape them to keep markup/script injection out of the page.
        svg_content = f"""
<svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">
<g transform="translate(0, 0)">
<path d="{escape(path.strip())}" fill="black"/>
</g>
</svg>
"""
        html_parts.append(
            "<div style='display: inline-block; margin: 10px; text-align: center;'>"
            f"<h3>{escape(str(lettre))}</h3>{svg_content}</div>"
        )
    column_count = max(1, int(cols))
    grid_style = f"display: grid; grid-template-columns: repeat({column_count}, 1fr); gap: 20px;"
    return f'<div style="{grid_style}">{"".join(html_parts)}</div>'
def generate_svg_files(glyphs, width=100, height=100):
    """Build one standalone SVG document per glyph, keyed by file name.

    Uses the same ``-100 -800 900 900`` viewBox as the HTML preview.

    Args:
        glyphs: Iterable of ``(letter, svg_path_data)`` pairs.
        width: ``width`` attribute of each SVG.
        height: ``height`` attribute of each SVG.

    Returns:
        Dict mapping ``"<letter>.svg"`` to its SVG markup. If a letter appears
        more than once, its last path wins.
    """
    template = (
        "\n"
        '<svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{w}" height="{h}">\n'
        '<g transform="translate(0, 0)">\n'
        '<path d="{d}" fill="black"/>\n'
        "</g>\n"
        "</svg>\n"
    )
    return {
        f"{letter}.svg": template.format(w=width, h=height, d=outline.strip())
        for letter, outline in glyphs
    }
# One chunk per command letter, including that command's argument text.
_SVG_PATH_COMMAND_RE = re.compile(r"[MLHVCSQTAZmlhvcsqtaz][^MLHVCSQTAZmlhvcsqtaz]*")
# SVG numbers: optional sign, "1", "1.", "1.5" or ".5", optional exponent.
_SVG_PATH_NUMBER_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
# Numbers consumed by one repetition of each (upper-cased) command.
_SVG_PATH_ARITY = {"M": 2, "L": 2, "H": 1, "V": 1, "C": 6, "S": 4, "Q": 4, "T": 2, "A": 7}


def _svg_arc_to_cubics(start, rx, ry, rotation, large_arc, sweep, end):
    """Approximate a non-degenerate SVG elliptical arc with cubic Béziers.

    Follows the endpoint-to-center conversion of the SVG 1.1 implementation
    notes (appendix F.6.5) and splits the arc into pieces of at most 90°.
    All points are in SVG (y-down) coordinates. Callers must handle
    ``start == end`` and zero radii themselves.

    Returns:
        A list of ``(control1, control2, end_point)`` tuples; the last end
        point is exactly ``end``.
    """
    import math

    x1, y1 = start
    x2, y2 = end
    phi = math.radians(rotation)
    cos_phi, sin_phi = math.cos(phi), math.sin(phi)
    rx, ry = abs(rx), abs(ry)

    # Move the chord midpoint to the origin and undo the ellipse rotation.
    dx2 = (x1 - x2) / 2.0
    dy2 = (y1 - y2) / 2.0
    x1p = cos_phi * dx2 + sin_phi * dy2
    y1p = -sin_phi * dx2 + cos_phi * dy2

    # Radii too small to span the endpoints are scaled up (per the spec).
    lam = (x1p * x1p) / (rx * rx) + (y1p * y1p) / (ry * ry)
    if lam > 1.0:
        scale = math.sqrt(lam)
        rx *= scale
        ry *= scale

    # Ellipse centre in the rotated frame; flags pick one of two solutions.
    rx2, ry2 = rx * rx, ry * ry
    num = rx2 * ry2 - rx2 * y1p * y1p - ry2 * x1p * x1p
    den = rx2 * y1p * y1p + ry2 * x1p * x1p
    coef = math.sqrt(max(0.0, num / den)) if den else 0.0
    if large_arc == sweep:
        coef = -coef
    cxp = coef * rx * y1p / ry
    cyp = -coef * ry * x1p / rx

    # Centre back in the original frame.
    cx = cos_phi * cxp - sin_phi * cyp + (x1 + x2) / 2.0
    cy = sin_phi * cxp + cos_phi * cyp + (y1 + y2) / 2.0

    # Start angle and signed sweep on the unit circle.
    theta1 = math.atan2((y1p - cyp) / ry, (x1p - cxp) / rx)
    theta2 = math.atan2((-y1p - cyp) / ry, (-x1p - cxp) / rx)
    delta = theta2 - theta1
    if sweep and delta < 0:
        delta += 2 * math.pi
    elif not sweep and delta > 0:
        delta -= 2 * math.pi

    segments = max(1, int(math.ceil(abs(delta) / (math.pi / 2) - 1e-9)))
    step = delta / segments
    k = 4.0 / 3.0 * math.tan(step / 4.0)  # standard circular-arc handle length

    def to_svg(ux, uy):
        # Map a unit-circle point back onto the rotated, scaled ellipse.
        return (
            cx + rx * cos_phi * ux - ry * sin_phi * uy,
            cy + rx * sin_phi * ux + ry * cos_phi * uy,
        )

    cubics = []
    t = theta1
    for _ in range(segments):
        t_next = t + step
        cos_a, sin_a = math.cos(t), math.sin(t)
        cos_b, sin_b = math.cos(t_next), math.sin(t_next)
        cubics.append((
            to_svg(cos_a - k * sin_a, sin_a + k * cos_a),
            to_svg(cos_b + k * sin_b, sin_b - k * cos_b),
            to_svg(cos_b, sin_b),
        ))
        t = t_next
    # Snap the final point onto the requested endpoint to avoid float drift.
    c1, c2, _ = cubics[-1]
    cubics[-1] = (c1, c2, (x2, y2))
    return cubics


def draw_svg_path(pen, path_string):
    """Replay an SVG path ``d`` string onto a fontTools-style segment pen.

    SVG is y-down while font outlines are y-up, so every emitted point has its
    y coordinate negated.

    Supported: M, L, H, V, C, S, Q, T, A and Z in absolute and relative form;
    implicit command repetition (extra pairs after M/m are linetos); Z/z
    returning the current point to the subpath start; numbers such as ``.5``,
    ``1e-3`` and ``10-5``. An incomplete trailing argument group (e.g. from a
    truncated model response) is ignored rather than raising. Arc flags
    written without separators (``a1 1 0 01 5 5``) are not split.

    Args:
        pen: Object with ``moveTo``, ``lineTo``, ``curveTo``, ``qCurveTo``,
            ``closePath`` and ``endPath`` (e.g. ``T2CharStringPen``).
        path_string: SVG path data.
    """
    current = (0.0, 0.0)        # current point, SVG coordinates
    subpath_start = (0.0, 0.0)  # where Z/z returns to
    subpath_open = False        # True between a moveTo and its close/end
    prev_cubic_ctrl = None      # 2nd control of the previous C/S (for S)
    prev_quad_ctrl = None       # control of the previous Q/T (for T)

    def flip(pt):
        return (pt[0], -pt[1])

    def reflect(ctrl):
        # S/T shorthand mirrors the previous control through the current point;
        # without a preceding curve of the same family it collapses onto it.
        if ctrl is None:
            return current
        return (2 * current[0] - ctrl[0], 2 * current[1] - ctrl[1])

    def start_if_needed():
        # Drawing right after Z (or with no M at all) starts a new subpath at
        # the current point, as the SVG path grammar specifies.
        nonlocal subpath_open
        if not subpath_open:
            pen.moveTo(flip(current))
            subpath_open = True

    for chunk in _SVG_PATH_COMMAND_RE.findall(path_string):
        cmd = chunk[0]
        kind = cmd.upper()
        relative = cmd != kind

        if kind == "Z":
            if subpath_open:
                pen.closePath()
                subpath_open = False
            current = subpath_start
            prev_cubic_ctrl = prev_quad_ctrl = None
            continue

        numbers = [float(n) for n in _SVG_PATH_NUMBER_RE.findall(chunk[1:])]
        arity = _SVG_PATH_ARITY[kind]
        # Each complete argument group is one (implicit) repetition.
        for offset in range(0, len(numbers) - arity + 1, arity):
            args = numbers[offset:offset + arity]
            ox, oy = current if relative else (0.0, 0.0)
            cubic_ctrl = quad_ctrl = None

            if kind == "M":
                target = (ox + args[0], oy + args[1])
                if offset == 0:
                    if subpath_open:
                        pen.endPath()
                    pen.moveTo(flip(target))
                    subpath_open = True
                    subpath_start = target
                else:
                    # Extra coordinate pairs after a moveto are implicit linetos.
                    pen.lineTo(flip(target))
            elif kind == "L":
                target = (ox + args[0], oy + args[1])
                start_if_needed()
                pen.lineTo(flip(target))
            elif kind == "H":
                target = (ox + args[0], current[1])
                start_if_needed()
                pen.lineTo(flip(target))
            elif kind == "V":
                target = (current[0], oy + args[0])
                start_if_needed()
                pen.lineTo(flip(target))
            elif kind in ("C", "S"):
                if kind == "C":
                    first = (ox + args[0], oy + args[1])
                    rest = args[2:]
                else:
                    first = reflect(prev_cubic_ctrl)
                    rest = args
                cubic_ctrl = (ox + rest[0], oy + rest[1])
                target = (ox + rest[2], oy + rest[3])
                start_if_needed()
                pen.curveTo(flip(first), flip(cubic_ctrl), flip(target))
            elif kind in ("Q", "T"):
                if kind == "Q":
                    quad_ctrl = (ox + args[0], oy + args[1])
                    target = (ox + args[2], oy + args[3])
                else:
                    quad_ctrl = reflect(prev_quad_ctrl)
                    target = (ox + args[0], oy + args[1])
                start_if_needed()
                pen.qCurveTo(flip(quad_ctrl), flip(target))
            else:  # "A"
                rx, ry, rotation, large_arc, sweep, x, y = args
                target = (ox + x, oy + y)
                if target != current:  # a zero-length arc is omitted (SVG spec)
                    start_if_needed()
                    if rx == 0 or ry == 0:
                        # Zero radius: the spec treats the arc as a straight line.
                        pen.lineTo(flip(target))
                    else:
                        for c1, c2, end in _svg_arc_to_cubics(
                            current, rx, ry, rotation, large_arc != 0, sweep != 0, target
                        ):
                            pen.curveTo(flip(c1), flip(c2), flip(end))

            current = target
            prev_cubic_ctrl, prev_quad_ctrl = cubic_ctrl, quad_ctrl

    if subpath_open:
        pen.endPath()
def create_otf_font(glyphs, output_path="font.otf"):
    """Compile ``(glyph_name, svg_path)`` pairs into a CFF-flavoured OpenType font.

    Every glyph gets a fixed 600-unit advance on a 1000-unit em. Paths are
    replayed through ``draw_svg_path``, which negates y, so a point at SVG
    y = -700 ends up 700 units above the baseline.

    Args:
        glyphs: Iterable of ``(glyph_name, svg_path_data)`` pairs. Single-letter
            names are mapped to the code point of their upper-case form.
        output_path: Destination path of the ``.otf`` file.

    Returns:
        ``output_path`` on success, ``None`` if building or saving failed (the
        error is printed with a traceback). A glyph whose path cannot be
        converted is left out of the font (and the cmap) instead of
        aborting the whole build.
    """
    advance_width = 600
    glyph_paths = dict(glyphs)
    # Hollow box written in SVG (y-down) coordinates, like the model's paths,
    # so it sits on the baseline once draw_svg_path flips the y axis. The inner
    # rectangle runs in the opposite direction so it renders as a counter.
    default_notdef_path = "M50 -700H450V0H50Z M100 -650V-50H400V-650Z"
    glyph_paths.setdefault(".notdef", default_notdef_path)

    char_strings = {}
    for glyph_name, svg_path in glyph_paths.items():
        try:
            pen = T2CharStringPen(advance_width, None)
            draw_svg_path(pen, svg_path)
            char_strings[glyph_name] = pen.getCharString()
        except Exception as e:
            print(f"Error converting glyph {glyph_name}: {e}")
            traceback.print_exc()
            if glyph_name == ".notdef":
                # .notdef is mandatory; fall back to an empty outline.
                char_strings[glyph_name] = T2CharStringPen(advance_width, None).getCharString()

    # .notdef must be glyph 0 and every name must appear exactly once; only
    # glyphs that actually have a charstring go into the font.
    glyph_names = [".notdef"] + [gn for gn in char_strings if gn != ".notdef"]
    unicode_map = {
        ord(gn.upper()): gn for gn in glyph_names if len(gn) == 1 and gn.isalpha()
    }

    # FontBuilder.setupCFF copies these key/value pairs onto the CFF Private
    # DICT, so keys must be cffLib operator names. nominalWidthX is left at its
    # default of 0 — assumption: T2CharStringPen encodes the 600 advance as an
    # offset from nominalWidthX (CFF semantics); confirm before changing it.
    private_dict = {
        "BlueScale": 0.039625,
        "BlueShift": 7,
        "BlueFuzz": 1,
        "StdHW": 100,
        "StdVW": 100,
    }

    family_name = "Custom Glyph Font"
    style_name = "Regular"
    full_name = f"{family_name} {style_name}"
    ps_name = f"{family_name.replace(' ', '')}-{style_name}"
    name_strings = {
        "familyName": family_name,
        "styleName": style_name,
        "fullName": full_name,
        "psName": ps_name,
    }

    try:
        fb = FontBuilder(1000, isTTF=False)
        fb.setupGlyphOrder(glyph_names)
        fb.setupCharacterMap(unicode_map)
        fb.setupCFF(ps_name, {"FullName": full_name}, char_strings, private_dict)

        metrics = {}
        max_ascent = 0
        min_descent = 0
        for gn in glyph_names:
            try:
                bounds = char_strings[gn].calcBounds(None)
            except Exception as e:
                print(f"Error calculating bounds for {gn}: {e}")
                traceback.print_exc()
                bounds = None
            lsb = 0
            if bounds:  # None for empty outlines
                lsb = int(round(bounds[0]))
                min_descent = min(min_descent, int(round(bounds[1])))
                max_ascent = max(max_ascent, int(round(bounds[3])))
            metrics[gn] = (advance_width, lsb)
        fb.setupHorizontalMetrics(metrics)

        ascent = max(800, max_ascent)
        descent = min(-200, min_descent)
        fb.setupHorizontalHeader(ascent=ascent, descent=descent)
        fb.setupNameTable(name_strings)
        # Keep OS/2 vertical metrics in line with hhea; left at zero,
        # usWinAscent/usWinDescent make Windows clip the glyphs.
        fb.setupOS2(
            sTypoAscender=ascent,
            sTypoDescender=descent,
            usWinAscent=ascent,
            usWinDescent=-descent,
        )
        fb.setupPost()
    except Exception as e:
        print(f"Error building font: {e}")
        traceback.print_exc()
        return None

    try:
        fb.save(output_path)
        # Re-open the file and parse its CFF table as a cheap sanity check.
        test_font = TTFont(output_path)
        try:
            test_font["CFF "]
        finally:
            test_font.close()
        return output_path
    except Exception as e:
        print(f"Error saving font: {e}")
        traceback.print_exc()
        return None
def create_zip(svg_files):
    """Write text files into a new temporary, deflate-compressed ``.zip`` archive.

    Args:
        svg_files: Mapping of archive member name to its text content.

    Returns:
        Path of the created archive. It is not deleted automatically;
        cleanup is up to the caller.
    """
    handle, archive_path = tempfile.mkstemp(suffix=".zip")
    os.close(handle)  # only the reserved name is needed; ZipFile reopens it
    with zipfile.ZipFile(archive_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name in svg_files:
            archive.writestr(member_name, svg_files[member_name])
    return archive_path
@spaces.GPU(duration=120)
def respond(message: str, system_message: str, max_tokens: int, temperature: float, top_p: float):
    """Stream a chat completion and the glyphs parsed from it so far.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator drains the streamer.

    Args:
        message: User prompt.
        system_message: System prompt placed before the user turn.
        max_tokens: Maximum number of new tokens (coerced to int).
        temperature: Sampling temperature; ``<= 0.01`` switches to greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Yields:
        ``(partial_response, glyphs)`` after every streamed chunk, where
        ``glyphs`` is ``extract_glyphs(partial_response)``.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        once streaming has stopped.
    """
    tokenizer, model = load_model()
    if torch.cuda.is_available():
        model = model.to('cuda')
    model_device = next(model.parameters()).device

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
    ).to(model_device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        # UI sliders may deliver floats; the token budget must be an int.
        "max_new_tokens": int(max_tokens),
        "use_cache": True,
    }
    if temperature <= 0.01:
        # Near-zero temperature: greedy decoding, no sampling parameters.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
        # Passing top_p explicitly (even 1.0 = no nucleus filtering) keeps the
        # user's choice from being overridden by checkpoint defaults.
        generation_kwargs["top_p"] = float(top_p)

    generation_errors = []

    def _generate():
        # If generate() raises, end the streamer so the loop below stops
        # instead of blocking forever, and keep the error for the caller.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response, extract_glyphs(partial_response)
    thread.join()
    if generation_errors:
        raise generation_errors[0]
def create_demo():
    """Build the TypTopType Gradio interface.

    Layout: a left column with the prompt box, an "advanced settings"
    accordion (system prompt, sampling and preview sliders) and the export
    buttons; a right column with the live SVG preview and a download slot.
    Submitting a prompt streams glyph previews while the model writes them.
    The parsed glyphs are kept in ``glyphs_state`` for the SVG/OTF exports.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown("# TypTopType")
        gr.Markdown("## Générateur de glyphes à partir de texte")
        with gr.Row():
            with gr.Column(scale=1):
                msg = gr.Textbox(label="Entrée", placeholder="Tapez votre texte ici...")
                send_btn = gr.Button("Envoyer", variant="primary", icon="📤")
                with gr.Accordion("Paramètres avancés", open=False):
                    system_message = gr.Textbox(
                        value="Based on the following text, give me the svgpath of the glyphs from A to Z.",
                        label="Message système"
                    )
                    max_tokens = gr.Slider(
                        minimum=1, maximum=9048, value=9048, step=1, label="Max Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P"
                    )
                    cols = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1, label="Colonnes"
                    )
                    width = gr.Slider(
                        minimum=50, maximum=200, value=100, step=10, label="Largeur"
                    )
                    height = gr.Slider(
                        minimum=50, maximum=200, value=100, step=10, label="Hauteur"
                    )
                download_btn = gr.Button("Télécharger SVG", variant="primary", icon="📁")
                otf_export_btn = gr.Button("Exporter OTF", variant="primary", icon="📱")
            with gr.Column(scale=3):
                gr.Markdown("## Aperçu")
                svg_preview = gr.HTML(label="Aperçu SVG")
                download_output = gr.File(label="Télécharger ZIP")
        demo.css = """
.gr-box {
border-radius: 10px;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
}
.gr-button {
border-radius: 5px;
}
"""
        # Glyphs parsed from the latest response, and the list of
        # [user_message, reply] turns (only the last message is used).
        glyphs_state = gr.State([])
        message_history = gr.State([])

        def user(user_message, history):
            # Clear the textbox and record the new turn (reply not known yet).
            return "", history + [[user_message, None]]

        def bot(history, system_message, max_tokens, temperature, top_p, cols, width, height):
            # Stream the latest message through the model and re-render the
            # preview. Keep showing the last successfully parsed glyphs so the
            # preview stays consistent with glyphs_state (used by the exports).
            message = history[-1][0]
            glyphs_list = []
            for _partial_response, glyphs in respond(message, system_message, max_tokens, temperature, top_p):
                if glyphs:
                    glyphs_list = glyphs
                if glyphs_list:
                    svg_html = generate_glyphs_html(glyphs_list, cols=cols, width=width, height=height)
                else:
                    svg_html = "No glyphs found."
                yield svg_html, glyphs_list

        def download_svg(glyphs, width, height):
            # Zip one SVG per glyph; None leaves the file slot empty.
            if not glyphs:
                return None
            svg_files = generate_svg_files(glyphs, width=width, height=height)
            return create_zip(svg_files)

        def export_to_otf(glyphs):
            # Compile the glyphs into an .otf; None leaves the file slot empty.
            if not glyphs:
                return None
            with tempfile.NamedTemporaryFile(delete=False, suffix=".otf") as tmp_file:
                otf_path = tmp_file.name
            created_path = create_otf_font(glyphs, otf_path)
            if created_path is None:
                # Don't leave the placeholder file behind when the build fails.
                try:
                    os.remove(otf_path)
                except OSError:
                    pass
            return created_path

        # Button click and Enter in the textbox trigger the same pipeline.
        for trigger in (send_btn.click, msg.submit):
            trigger(
                user, [msg, message_history], [msg, message_history], queue=False
            ).then(
                bot,
                [message_history, system_message, max_tokens, temperature, top_p, cols, width, height],
                [svg_preview, glyphs_state]
            )
        download_btn.click(
            download_svg,
            inputs=[glyphs_state, width, height],
            outputs=download_output
        )
        otf_export_btn.click(
            export_to_otf,
            inputs=[glyphs_state],
            outputs=download_output
        )
    return demo
# Build the app at import time as a module-level global named `demo`.
# NOTE(review): presumably the Spaces/Gradio runner looks this global up by
# name — confirm before renaming or moving it under the main guard.
demo = create_demo()
if __name__ == "__main__":
    # Running the file directly starts the Gradio server.
    demo.launch()