Spaces:
Runtime error
Runtime error
| import spaces | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
| import gradio as gr | |
| import torch | |
| from threading import Thread | |
| import re | |
| import io | |
| import zipfile | |
| import tempfile | |
| import os | |
| from fontTools.ttLib import TTFont | |
| from fontTools.fontBuilder import FontBuilder | |
| from fontTools.pens.t2CharStringPen import T2CharStringPen | |
| from fontTools.cffLib import PrivateDict | |
| import traceback | |
def load_model():
    """Return the ``(tokenizer, model)`` pair for the typtop4 checkpoint.

    The pair is built on the first call and memoised on the function
    object, so later calls reuse the already-loaded objects instead of
    loading the checkpoint again for every request.

    Returns:
        A 2-tuple ``(tokenizer, model)``.
    """
    cached_pair = getattr(load_model, "_cached_pair", None)
    if cached_pair is None:
        checkpoint = "ChevalierJoseph/typtop4"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint)
        cached_pair = (tokenizer, model)
        load_model._cached_pair = cached_pair
    return cached_pair
def generate_svg(path_data, width=50, height=50):
    """Wrap *path_data* in a minimal standalone SVG document.

    Args:
        path_data: Value placed verbatim in the ``d`` attribute of the path.
        width: Used for the ``width`` attribute and the viewBox width.
        height: Used for the ``height`` attribute and the viewBox height.

    Returns:
        The SVG markup as a string (with a leading and trailing newline).
    """
    markup_lines = (
        "",
        f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}" xmlns="http://www.w3.org/2000/svg">',
        f'<path d="{path_data}" fill="black"/>',
        "</svg>",
        "",
    )
    return "\n".join(markup_lines)
def extract_glyphs(text):
    """Pull ``Glyph <LETTER> <path data>`` entries out of *text*.

    A path-data run ends right before the next ``Glyph <LETTER>`` marker or
    at the end of the text (surrounding whitespace excluded).

    Returns:
        A list of ``(letter, path_data)`` tuples in order of appearance.
    """
    glyph_re = re.compile(
        r"Glyph\s+([A-Z])\s+([MmZzLlHhVvCcSsQqTtAa0-9,\s\.\-]+?)(?=\s*Glyph\s+[A-Z]|\s*$)"
    )
    return [(hit.group(1), hit.group(2)) for hit in glyph_re.finditer(text)]
def generate_glyphs_html(glyphs, cols=5, width=100, height=100):
    """Render glyphs as an HTML grid of inline SVG previews.

    Args:
        glyphs: Iterable of ``(letter, path_data)`` pairs.
        cols: Number of columns of the CSS grid.
        width: ``width`` attribute given to every SVG element.
        height: ``height`` attribute given to every SVG element.

    Returns:
        One ``<div>`` grid container holding a titled cell per glyph.
    """
    cells = []
    for letter, outline in glyphs:
        svg_markup = "\n".join((
            "",
            f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" width="{width}" height="{height}">',
            '<g transform="translate(0, 0)">',
            f'<path d="{outline.strip()}" fill="black"/>',
            "</g>",
            "</svg>",
            "",
        ))
        cells.append(
            "<div style='display: inline-block; margin: 10px; text-align: center;'>"
            f"<h3>{letter}</h3>{svg_markup}</div>"
        )
    container_style = f"display: grid; grid-template-columns: repeat({cols}, 1fr); gap: 20px;"
    return f'<div style="{container_style}">' + "".join(cells) + "</div>"
def generate_svg_files(glyphs, width=100, height=100):
    """Build one SVG document per glyph.

    Args:
        glyphs: Iterable of ``(letter, path_data)`` pairs.
        width: ``width`` attribute given to every SVG element.
        height: ``height`` attribute given to every SVG element.

    Returns:
        A dict mapping ``"<letter>.svg"`` to the SVG markup; when a letter
        occurs more than once, the last occurrence wins.
    """
    opening_tag = (
        '<svg xmlns="http://www.w3.org/2000/svg" viewBox="-100 -800 900 900" '
        f'width="{width}" height="{height}">'
    )

    def render(outline):
        # Same document skeleton for every glyph; only the path data varies.
        return "\n".join((
            "",
            opening_tag,
            '<g transform="translate(0, 0)">',
            f'<path d="{outline.strip()}" fill="black"/>',
            "</g>",
            "</svg>",
            "",
        ))

    return {f"{letter}.svg": render(outline) for letter, outline in glyphs}
# One command letter followed by its (possibly empty) argument text.
_SVG_SEGMENT_RE = re.compile(r'([MLHVCSQTAZmlhvcsqtaz])([^MLHVCSQTAZmlhvcsqtaz]*)')
# SVG number: optional sign, digits with optional fraction (".5" and "5." are
# both valid), optional exponent.
_SVG_NUMBER_RE = re.compile(r'[-+]?(?:\d*\.\d+|\d+\.?)(?:[eE][-+]?\d+)?')
# How many numbers one segment of each command consumes.
_SVG_ARGS_PER_SEGMENT = {
    'M': 2, 'L': 2, 'H': 1, 'V': 1, 'C': 6, 'S': 4, 'Q': 4, 'T': 2, 'A': 7,
}


def _svg_to_font_point(x, y, origin, relative):
    """Map an SVG coordinate pair to font space (y axis flipped)."""
    if relative:
        return origin[0] + x, origin[1] - y
    return x, -y


def _reflect(point, center):
    """Mirror *point* through *center*."""
    return 2 * center[0] - point[0], 2 * center[1] - point[1]


def draw_svg_path(pen, path_string):
    """Replay SVG path data on a segment pen, flipping the y axis.

    SVG y grows downwards while the pen is fed y-up coordinates, so every
    y value is negated.

    Supported commands: M, L, H, V, C, S, Q, T, Z and their relative forms.
    Elliptical arcs (A/a) are approximated by a straight line to the arc's
    end point; arc flags written without separators are not recognised.
    Trailing numbers that do not form a complete segment are ignored.

    Args:
        pen: Object providing ``moveTo``, ``lineTo``, ``curveTo``,
            ``qCurveTo``, ``closePath`` and ``endPath``.
        path_string: Contents of an SVG ``d`` attribute.
    """
    current = (0.0, 0.0)
    subpath_start = current
    last_ctrl = None        # final control point of the previous curve segment
    last_ctrl_kind = None   # 'cubic' or 'quad'; decides whether S/T may reflect it
    contour_open = False

    for letter, arg_text in _SVG_SEGMENT_RE.findall(path_string):
        op = letter.upper()
        relative = letter.islower()

        if op == 'Z':
            if contour_open:
                pen.closePath()
                contour_open = False
            # Closing a subpath moves the current point back to its start;
            # relative commands that follow are measured from there.
            current = subpath_start
            last_ctrl = last_ctrl_kind = None
            continue

        values = [float(token) for token in _SVG_NUMBER_RE.findall(arg_text)]
        arity = _SVG_ARGS_PER_SEGMENT[op]

        for offset in range(0, len(values) - arity + 1, arity):
            args = values[offset:offset + arity]
            prev_ctrl, prev_kind = last_ctrl, last_ctrl_kind
            last_ctrl = last_ctrl_kind = None

            if op == 'M' and offset == 0:
                # A new subpath implicitly ends an unclosed one.
                if contour_open:
                    pen.endPath()
                current = _svg_to_font_point(args[0], args[1], current, relative)
                subpath_start = current
                pen.moveTo(current)
                contour_open = True
                continue

            if not contour_open:
                # Drawing right after Z (or with no M at all) starts a new
                # subpath at the current point.
                pen.moveTo(current)
                subpath_start = current
                contour_open = True

            if op in ('M', 'L'):
                # Extra coordinate pairs after a moveto are implicit linetos.
                current = _svg_to_font_point(args[0], args[1], current, relative)
                pen.lineTo(current)
            elif op == 'H':
                new_x = current[0] + args[0] if relative else args[0]
                current = (new_x, current[1])
                pen.lineTo(current)
            elif op == 'V':
                new_y = current[1] - args[0] if relative else -args[0]
                current = (current[0], new_y)
                pen.lineTo(current)
            elif op == 'C':
                ctrl1 = _svg_to_font_point(args[0], args[1], current, relative)
                ctrl2 = _svg_to_font_point(args[2], args[3], current, relative)
                end = _svg_to_font_point(args[4], args[5], current, relative)
                pen.curveTo(ctrl1, ctrl2, end)
                last_ctrl, last_ctrl_kind = ctrl2, 'cubic'
                current = end
            elif op == 'S':
                # First control point mirrors the previous cubic's second
                # one; without a preceding cubic it is the current point.
                ctrl1 = _reflect(prev_ctrl, current) if prev_kind == 'cubic' else current
                ctrl2 = _svg_to_font_point(args[0], args[1], current, relative)
                end = _svg_to_font_point(args[2], args[3], current, relative)
                pen.curveTo(ctrl1, ctrl2, end)
                last_ctrl, last_ctrl_kind = ctrl2, 'cubic'
                current = end
            elif op == 'Q':
                ctrl = _svg_to_font_point(args[0], args[1], current, relative)
                end = _svg_to_font_point(args[2], args[3], current, relative)
                pen.qCurveTo(ctrl, end)
                last_ctrl, last_ctrl_kind = ctrl, 'quad'
                current = end
            elif op == 'T':
                ctrl = _reflect(prev_ctrl, current) if prev_kind == 'quad' else current
                end = _svg_to_font_point(args[0], args[1], current, relative)
                pen.qCurveTo(ctrl, end)
                last_ctrl, last_ctrl_kind = ctrl, 'quad'
                current = end
            else:  # 'A': keep the outline connected with the arc's chord
                current = _svg_to_font_point(args[5], args[6], current, relative)
                pen.lineTo(current)

    # Only an unclosed trailing subpath still needs to be ended.
    if contour_open:
        pen.endPath()
def create_otf_font(glyphs, output_path="font.otf"):
    """Build a CFF-flavoured OpenType font from SVG glyph outlines.

    Every glyph gets a fixed advance width of 600 units on a 1000 unit em.
    Single alphabetic glyph names are mapped to the code point of their
    upper-case form. A ``.notdef`` glyph is added when the input has none
    and is always placed first in the glyph order.

    Args:
        glyphs: Iterable of ``(glyph_name, svg_path_data)`` pairs.
        output_path: Destination file for the font.

    Returns:
        ``output_path`` on success, ``None`` when building or saving failed
        (the error is printed together with its traceback).
    """
    advance_width = 600
    glyph_paths = {lettre: path for lettre, path in glyphs}
    # draw_svg_path negates y, so the fallback box is authored in SVG
    # (y-down) space: negative y values end up above the baseline.
    glyph_paths.setdefault(
        '.notdef',
        "M50 -700H450V0H50Z M100 -600L400 -100 M400 -600L100 -100",
    )

    # '.notdef' must be glyph 0 and must appear exactly once.
    glyph_names = ['.notdef'] + [name for name in glyph_paths if name != '.notdef']

    unicode_map = {}
    for glyph_name in glyph_paths:
        if len(glyph_name) == 1 and glyph_name.isalpha():
            unicode_map[ord(glyph_name.upper())] = glyph_name

    char_strings = {}
    for glyph_name in glyph_names:
        try:
            pen = T2CharStringPen(advance_width, None)
            draw_svg_path(pen, glyph_paths[glyph_name])
            charstring = pen.getCharString()
        except Exception as e:
            print(f"Error converting glyph {glyph_name}: {e}")
            traceback.print_exc()
            # Every name in the glyph order needs a charstring; fall back
            # to an empty outline instead of leaving a hole.
            charstring = T2CharStringPen(advance_width, None).getCharString()
        char_strings[glyph_name] = charstring

    family_name = "Custom Glyph Font"
    style_name = "Regular"
    full_name = f"{family_name} {style_name}"
    ps_name = f"{family_name.replace(' ', '')}-{style_name}"
    name_strings = {
        "familyName": family_name,
        "styleName": style_name,
        "fullName": full_name,
        "psName": ps_name,
    }

    try:
        fb = FontBuilder(1000, isTTF=False)
        fb.setupGlyphOrder(glyph_names)
        fb.setupCharacterMap(unicode_map)
        # The last argument is a plain mapping of Private DICT entries; an
        # empty mapping keeps the defaults, so the width operand stored in
        # each charstring stays equal to the hmtx advance.
        fb.setupCFF(ps_name, {"FullName": full_name}, char_strings, {})

        metrics = {}
        max_ascent = 0
        min_descent = 0
        for gn, cs in char_strings.items():
            try:
                bounds = cs.calcBounds(None)
                lsb = int(round(bounds[0])) if bounds else 0
                if bounds:
                    max_ascent = max(max_ascent, int(round(bounds[3])))
                    min_descent = min(min_descent, int(round(bounds[1])))
                metrics[gn] = (advance_width, lsb)
            except Exception as e:
                print(f"Error calculating bounds for {gn}: {e}")
                traceback.print_exc()
                metrics[gn] = (advance_width, 0)
        fb.setupHorizontalMetrics(metrics)

        # Never report less than an 800 / -200 vertical extent.
        ascent = max(800, max_ascent)
        descent = min(-200, min_descent)
        fb.setupHorizontalHeader(ascent=ascent, descent=descent)
        fb.setupNameTable(name_strings)
        fb.setupOS2()
        fb.setupPost()
    except Exception as e:
        print(f"Error building font: {e}")
        traceback.print_exc()
        return None

    try:
        fb.save(output_path)
        # Re-open the written file as a sanity check that it parses.
        test_font = TTFont(output_path)
        test_font.close()
        return output_path
    except Exception as e:
        print(f"Error saving font: {e}")
        traceback.print_exc()
        return None
def create_zip(svg_files):
    """Write the given files into a new temporary ZIP archive.

    Args:
        svg_files: Mapping of archive member name to its content.

    Returns:
        Path of the created ``.zip`` file; the file is not removed
        automatically.
    """
    handle, archive_path = tempfile.mkstemp(suffix=".zip")
    # Only the reserved name is needed; the archive is written by path below.
    os.close(handle)
    with zipfile.ZipFile(archive_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name, payload in svg_files.items():
            archive.writestr(member_name, payload)
    return archive_path
# NOTE(review): this file imports `spaces` but never uses it. If the app is
# deployed on Hugging Face ZeroGPU hardware, `respond` presumably needs to be
# decorated with `@spaces.GPU` - confirm against the deployment settings.
def respond(message: str, system_message: str, max_tokens: int, temperature: float, top_p: float):
    """Stream the model's answer together with the glyphs parsed so far.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer.

    Args:
        message: User prompt.
        system_message: System prompt placed before the user message.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values <= 0.01 switch to greedy
            decoding.
        top_p: Nucleus sampling threshold; only forwarded when sampling and
            below 1.0.

    Yields:
        ``(partial_response, glyphs)`` after every streamed chunk, where
        ``glyphs`` is ``extract_glyphs(partial_response)``.
    """
    tokenizer, model = load_model()
    if torch.cuda.is_available():
        model = model.to('cuda')
    model_device = next(model.parameters()).device

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
    ).to(model_device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "use_cache": True,
    }
    if temperature > 0.01:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
        if top_p < 1.0:
            generation_kwargs["top_p"] = float(top_p)
    else:
        # Greedy decoding: sampling parameters are left out entirely rather
        # than being passed as None.
        generation_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_response = ""
    for new_text in streamer:
        partial_response += new_text
        yield partial_response, extract_glyphs(partial_response)
    thread.join()
def create_demo():
    """Build the Gradio Blocks interface and wire up its events.

    The left column holds the prompt box, the generation and preview
    settings and the export buttons; the right column shows the glyph
    preview and the download slot.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """
    custom_css = """
.gr-box {
border-radius: 10px;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
}
.gr-button {
border-radius: 5px;
}
"""
    # NOTE(review): the nesting of the layout below is reconstructed; confirm
    # that the export buttons belong to the column and not to the accordion.
    with gr.Blocks(theme=gr.themes.Default(), css=custom_css) as demo:
        gr.Markdown("# TypTopType")
        gr.Markdown("## Générateur de glyphes à partir de texte")
        with gr.Row():
            with gr.Column(scale=1):
                msg = gr.Textbox(label="Entrée", placeholder="Tapez votre texte ici...")
                # `icon` expects an image path or URL, so the emoji is part
                # of the label instead.
                send_btn = gr.Button("📤 Envoyer", variant="primary")
                with gr.Accordion("Paramètres avancés", open=False):
                    system_message = gr.Textbox(
                        value="Based on the following text, give me the svgpath of the glyphs from A to Z.",
                        label="Message système"
                    )
                    max_tokens = gr.Slider(
                        minimum=1, maximum=9048, value=9048, step=1, label="Max Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P"
                    )
                    cols = gr.Slider(
                        minimum=1, maximum=10, value=5, step=1, label="Colonnes"
                    )
                    width = gr.Slider(
                        minimum=50, maximum=200, value=100, step=10, label="Largeur"
                    )
                    height = gr.Slider(
                        minimum=50, maximum=200, value=100, step=10, label="Hauteur"
                    )
                download_btn = gr.Button("📁 Télécharger SVG", variant="primary")
                otf_export_btn = gr.Button("📱 Exporter OTF", variant="primary")
            with gr.Column(scale=3):
                gr.Markdown("## Aperçu")
                svg_preview = gr.HTML(label="Aperçu SVG")
                download_output = gr.File(label="Télécharger ZIP")

        glyphs_state = gr.State([])
        message_history = gr.State([])

        def user(user_message, history):
            """Clear the input box and append the message to the history."""
            return "", history + [[user_message, None]]

        def bot(history, system_message, max_tokens, temperature, top_p, cols, width, height):
            """Stream the preview HTML and glyph list for the latest message."""
            message = history[-1][0]
            glyphs_list = []
            for _partial, glyphs in respond(message, system_message, max_tokens, temperature, top_p):
                if glyphs:
                    glyphs_list = glyphs
                # Render the last known glyphs so the preview always matches
                # what is stored in the state.
                if glyphs_list:
                    svg_html = generate_glyphs_html(glyphs_list, cols=cols, width=width, height=height)
                else:
                    svg_html = "No glyphs found."
                yield svg_html, glyphs_list

        def download_svg(glyphs, width, height):
            """Return the path of a ZIP with one SVG per glyph, or None."""
            if not glyphs:
                return None
            return create_zip(generate_svg_files(glyphs, width=width, height=height))

        def export_to_otf(glyphs):
            """Return the path of a generated OTF font, or None on failure."""
            if not glyphs:
                return None
            with tempfile.NamedTemporaryFile(delete=False, suffix=".otf") as tmp_file:
                otf_path = tmp_file.name
            created_path = create_otf_font(glyphs, otf_path)
            if created_path:
                return created_path
            # Do not leave the unused placeholder file behind.
            try:
                os.remove(otf_path)
            except OSError:
                pass
            return None

        bot_inputs = [message_history, system_message, max_tokens, temperature, top_p, cols, width, height]
        # Clicking the button and pressing Enter trigger the same pipeline.
        for register in (send_btn.click, msg.submit):
            register(
                user, [msg, message_history], [msg, message_history], queue=False
            ).then(
                bot, bot_inputs, [svg_preview, glyphs_state]
            )

        download_btn.click(
            download_svg,
            inputs=[glyphs_state, width, height],
            outputs=download_output
        )
        otf_export_btn.click(
            export_to_otf,
            inputs=[glyphs_state],
            outputs=download_output
        )
    return demo
# The interface is built at import time and bound to the module-level name
# `demo`.
demo = create_demo()

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()