Spaces:
Runtime error
Runtime error
| import spaces | |
| import torch | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import gradio as gr | |
| import os | |
| import spacy | |
| from spacy import displacy | |
# Markdown shown at the top of the page (passed to gr.Markdown in the UI).
title = """
# 🙋🏻♂️Welcome to 🌟Tonic's 🎅🏻⌚OCRonos Vintage Text Gen
This app generates historical-style text using the OCRonos-Vintage model. You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse. You can see a tokenized visualisation of the output and your input, and learn english using the visualization for the output text!
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

# Checkpoint identifier used for both the model weights and the tokenizer.
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the spaCy English pipeline; download it only when it is not installed.
# This avoids shelling out to whatever `python` is on PATH on every start and
# silently ignoring a failed download.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download as download_spacy_model

    download_spacy_model("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Sample a continuation of ``prompt`` and return it with a per-token view.

    The prompt is placed under a ``### Text ###`` header, tokenized, and passed
    to ``model.generate`` with sampling enabled. If the decoded output contains
    a ``### Correction ###`` marker, only the segment right after the first
    marker is kept.

    Args:
        prompt: User text to continue.
        max_new_tokens: Upper bound on generated tokens. It is reduced when the
            prompt leaves less room than that in the model's context.
        top_k: Number of highest-probability tokens kept for sampling.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty forwarded to ``model.generate``.

    Returns:
        ``(highlighted_text, generated_text)``. ``highlighted_text`` is a list
        of ``(token, label)`` pairs where both entries are the token text with
        the ``Ġ`` marker removed. ``generated_text`` is the decoded string.
    """
    correction_marker = "### Correction ###"

    # UI sliders may deliver numbers as floats; these two must be integers.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # Prompt and continuation share the model's position table. Overrunning it
    # fails inside generate(), so the budget is computed up front.
    context_limit = getattr(model.config, "n_positions", 1024)

    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        # Leave at least one free position so something can be generated.
        max_length=min(1024, context_limit - 1),
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Never ask for more tokens than the remaining context can hold.
    remaining_positions = context_limit - input_ids.shape[1]
    max_new_tokens = max(1, min(max_new_tokens, remaining_positions))

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if correction_marker in generated_text:
        generated_text = generated_text.split(correction_marker)[1].strip()

    # Each token is labelled with its own cleaned text. Converting a token to
    # its id and back yields the same token, so no lookup is needed.
    tokens = tokenizer.tokenize(generated_text)
    highlighted_text = []
    for token in tokens:
        clean_token = token.replace("Ġ", "")
        highlighted_text.append((clean_token, clean_token))

    # Drop tensor references before asking CUDA to release cached blocks.
    del inputs, input_ids, attention_mask, output, tokens
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return highlighted_text, generated_text
def text_analysis(text):
    """Run the module-level spaCy pipeline on ``text`` and summarise the result.

    Returns:
        A 3-tuple ``(pos_tokens, pos_count, html)``:
        ``pos_tokens`` is a list of ``(token text, part-of-speech tag)`` pairs,
        ``pos_count`` is a dict with ``char_count`` and ``token_count``, and
        ``html`` is the displaCy dependency rendering wrapped in a scrollable
        ``<div>``.
    """
    doc = nlp(text)

    # Dependency-arc markup, boxed so a wide parse scrolls instead of overflowing.
    dep_markup = displacy.render(doc, style="dep", page=True)
    opening = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    closing = "</div>"
    html = opening + dep_markup + closing

    pos_tokens = []
    for token in doc:
        pos_tokens.append((token.text, token.pos_))

    pos_count = {"char_count": len(text), "token_count": len(doc)}
    return pos_tokens, pos_count, html
def generate_dependency_parse(generated_text):
    """Return only the dependency-parse HTML that ``text_analysis`` produces."""
    # text_analysis returns (pos_tokens, pos_count, html); only the last is needed.
    analysis = text_analysis(generated_text)
    _pos_tokens, _pos_count, dependency_html = analysis
    return dependency_html
def display_dependency_parse(generated_text):
    """Click handler: build the dependency-parse HTML for ``generated_text``."""
    dependency_html = generate_dependency_parse(generated_text)
    return dependency_html
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Generate text for ``prompt`` and assemble every value the UI displays.

    Returns:
        An 8-tuple, in the order of the generate button's ``outputs`` list:
        generated text, token highlights, input-text statistics, input
        dependency-parse HTML, send-button update, generated-text
        dependency-parse HTML, generate-button update, reset-button update.
    """
    # Generate historical-style text and tokenized output
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )

    # Analyze input text (dependency parse visualization)
    _tokens_input, pos_count_input, html_input = text_analysis(prompt)

    # Generate dependency parse for the generated text
    dependency_parse_generated_html = generate_dependency_parse(generated_text)

    return (
        generated_text,
        generated_highlight,
        pos_count_input,
        html_input,
        gr.update(visible=True),   # send_button: results exist, so reveal it
        dependency_parse_generated_html,
        gr.update(visible=True),   # generate_button: stays available
        # reset_button: the button starts hidden, so it must be revealed here
        # or "Start Again" can never be clicked.
        gr.update(visible=True),
    )


def reset_interface():
    """Return the UI to its initial state.

    Returns:
        An 8-tuple, one entry per component in the reset button's ``outputs``
        list: generate button, send button, reset button, generated text, token
        highlights, tokenizer info, input dependency parse, generated
        dependency parse. The handler must return exactly as many values as
        there are outputs.
    """
    # Output components are cleared but left visible. Hiding them would leave
    # them hidden, because generation only sets their values.
    cleared = gr.update(value=None)
    return (
        gr.update(visible=True),   # generate_button
        gr.update(visible=False),  # send_button
        gr.update(visible=False),  # reset_button
        cleared,                   # generated_text_output
        cleared,                   # highlighted_text
        cleared,                   # tokenizer_info
        cleared,                   # dependency_parse_input
        cleared,                   # dependency_parse_generated
    )
# Page layout: components render in the order they are created below.
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown(title)

    # Prompt and sampling controls.
    prompt = gr.Textbox(
        label="Add a passage in the style of historical texts",
        placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine' he said",
        lines=2,
    )
    max_new_tokens = gr.Slider(label="📏Length", minimum=50, maximum=1000, step=5, value=320)
    top_k = gr.Slider(label="🧪Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="🎨Creativity", minimum=0.1, maximum=1, step=0.05, value=0.3)
    top_p = gr.Slider(label="👌🏻Quality", minimum=0.1, maximum=0.99, step=0.01, value=0.97)
    repetition_penalty = gr.Slider(label="🔴Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.3)

    # Result panels.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    dependency_parse_generated = gr.HTML(label="🎅🏻⌚Dependency Parse Visualization (Generated Text)")

    # Action buttons; two of the three start hidden.
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")

    # Each handler's return values are matched to its outputs list by position.
    generation_inputs = [prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty]
    generation_outputs = [
        generated_text_output,
        highlighted_text,
        tokenizer_info,
        dependency_parse_input,
        send_button,
        dependency_parse_generated,
        generate_button,
        reset_button,
    ]
    reset_outputs = [
        generate_button,
        send_button,
        reset_button,
        generated_text_output,
        highlighted_text,
        tokenizer_info,
        dependency_parse_input,
        dependency_parse_generated,
    ]

    generate_button.click(
        fn=full_interface,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
    send_button.click(
        fn=display_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated],
    )
    reset_button.click(
        fn=reset_interface,
        inputs=None,
        outputs=reset_outputs,
    )

iface.launch()