Spaces:
Runtime error
Runtime error
| import subprocess | |
| import jinja2 | |
| import gradio | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import base64 | |
| from io import BytesIO | |
| subprocess.run( | |
| ["curl", "--output", "checkpoint.pkl", "https://storage.googleapis.com/ithaca-resources/models/checkpoint_v1.pkl"]) | |
| #@article{asssome2022restoring, | |
| # title = {Restoring and attributing ancient texts using deep neural networks}, | |
| # author = {Assael*, Yannis and Sommerschield*, Thea and Shillingford, Brendan and Bordbar, Mahyar and Pavlopoulos, John and Chatzipanagiotou, Marita and Androutsopoulos, Ion and Prag, Jonathan and de Freitas, Nando}, | |
| # doi = {10.1038/s41586-022-04448-z}, | |
| # journal = {Nature}, | |
| # year = {2022} | |
| #} | |
| # Copyright 2021 the Ithaca Authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Example for running inference. See also colab.""" | |
| import functools | |
| import pickle | |
| from ithaca.eval import inference | |
| from ithaca.models.model import Model | |
| from ithaca.util.alphabet import GreekAlphabet | |
| import jax | |
| def create_time_plot(attribution): | |
| class dataset_config: | |
| date_interval = 10 | |
| date_max = 800 | |
| date_min = -800 | |
| def bce_ad(d): | |
| if d < 0: | |
| return f'{abs(d)} BCE' | |
| elif d > 0: | |
| return f'{abs(d)} AD' | |
| return 0 | |
| #compute scores | |
| date_pred_y = np.array(attribution.year_scores) | |
| date_pred_x = np.arange( | |
| dataset_config.date_min + dataset_config.date_interval / 2, | |
| dataset_config.date_max + dataset_config.date_interval / 2, | |
| dataset_config.date_interval) | |
| date_pred_argmax = date_pred_y.argmax( | |
| ) * dataset_config.date_interval + dataset_config.date_min + dataset_config.date_interval // 2 | |
| date_pred_avg = np.dot(date_pred_y, date_pred_x) | |
| # Plot figure | |
| fig = plt.figure(figsize=(10, 5), dpi=100) | |
| plt.bar(date_pred_x, date_pred_y, color='#f2c852', width=10., label='Ithaca distribution') | |
| plt.axvline(x=date_pred_avg, color='#67ac5b', linewidth=2., label='Ithaca average') | |
| plt.ylabel('Probability', fontsize=14) | |
| yticks = np.arange(0, 1.1, 0.1) | |
| yticks_str = list(map(lambda x: f'{int(x*100)}%', yticks)) | |
| plt.yticks(yticks, yticks_str, fontsize=12, rotation=0) | |
| plt.ylim(0, int((date_pred_y.max()+0.1)*10)/10) | |
| plt.xlabel('Date', fontsize=14) | |
| xticks = list(range(dataset_config.date_min, dataset_config.date_max + 1, 25)) | |
| xticks_str = list(map(bce_ad, xticks)) | |
| plt.xticks(xticks, xticks_str, fontsize=12, rotation=0) | |
| plt.xlim(int(date_pred_avg - 100), int(date_pred_avg + 100)) | |
| plt.legend(loc='upper right', fontsize=12) | |
| #encode to base64 for html parsing | |
| tmpfile = BytesIO() | |
| fig.savefig(tmpfile, format='png') | |
| encoded = base64.b64encode(tmpfile.getvalue()).decode('utf-8') | |
| html = '<div>' + '<img src="data:image/png;charset=utf-8;base64,{}">'.format(encoded) + '</div>' | |
| return html | |
| def get_subregion_name(id, region_map): | |
| return region_map['sub']['names_inv'][region_map['sub']['ids_inv'][id]] | |
| def load_checkpoint(path): | |
| """Loads a checkpoint pickle. | |
| Args: | |
| path: path to checkpoint pickle | |
| Returns: | |
| a model config dictionary (arguments to the model's constructor), a dict of | |
| dicts containing region mapping information, a GreekAlphabet instance with | |
| indices and words populated from the checkpoint, a dict of Jax arrays | |
| `params`, and a `forward` function. | |
| """ | |
| # Pickled checkpoint dict containing params and various config: | |
| with open(path, 'rb') as f: | |
| checkpoint = pickle.load(f) | |
| # We reconstruct the model using the same arguments as during training, which | |
| # are saved as a dict in the "model_config" key, and construct a `forward` | |
| # function of the form required by attribute() and restore(). | |
| params = jax.device_put(checkpoint['params']) | |
| model = Model(**checkpoint['model_config']) | |
| forward = functools.partial(model.apply, params) | |
| # Contains the mapping between region IDs and names: | |
| region_map = checkpoint['region_map'] | |
| # Use vocabulary mapping from the checkpoint, the rest of the values in the | |
| # class are fixed and constant e.g. the padding symbol | |
| alphabet = GreekAlphabet() | |
| alphabet.idx2word = checkpoint['alphabet']['idx2word'] | |
| alphabet.word2idx = checkpoint['alphabet']['word2idx'] | |
| return checkpoint['model_config'], region_map, alphabet, params, forward | |
| def main(text): | |
| restore_template = jinja2.Template("""<!DOCTYPE html> | |
| <html> | |
| <head> | |
| <link rel="preconnect" href="https://fonts.googleapis.com"> | |
| <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> | |
| <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400&family=Roboto:wght@400&display=swap" rel="stylesheet"> | |
| <style> | |
| body { | |
| font-family: 'Roboto Mono', monospace; | |
| font-weight: 400; | |
| } | |
| .container { | |
| overflow-x: scroll; | |
| scroll-behavior: smooth; | |
| } | |
| table { | |
| table-layout: fixed; | |
| font-size: 16px; | |
| padding: 0; | |
| white-space: nowrap; | |
| } | |
| table tr:first-child { | |
| font-weight: bold; | |
| } | |
| table td { | |
| border-bottom: 1px solid #ccc; | |
| padding: 3px 0; | |
| } | |
| table td.header { | |
| font-family: Roboto, Helvetica, sans-serif; | |
| text-align: right; | |
| position: -webkit-sticky; | |
| position: sticky; | |
| background-color: white; | |
| } | |
| .header-1 { | |
| background-color: white; | |
| width: 120px; | |
| min-width: 120px; | |
| max-width: 120px; | |
| left: 0; | |
| } | |
| .header-2 { | |
| left: 120px; | |
| width: 50px; | |
| max-width: 50px; | |
| min-width: 50px; | |
| padding-right: 5px; | |
| } | |
| table td:not(.header) { | |
| border-left: 1px solid black; | |
| padding-left: 5px; | |
| } | |
| .header-2col { | |
| width: 170px; | |
| min-width: 170px; | |
| max-width: 170px; | |
| left: 0; | |
| padding-right: 5px; | |
| } | |
| .pred { | |
| background: #ddd; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div class="container"> | |
| <table cellspacing="0"> | |
| <tr> | |
| <td colspan="2" class="header header-2col">Input text:</td> | |
| <td> | |
| {% for char in restoration_results.input_text -%} | |
| {%- if loop.index0 in prediction_idx -%} | |
| <span class="pred">{{char}}</span> | |
| {%- else -%} | |
| {{char}} | |
| {%- endif -%} | |
| {%- endfor %} | |
| </td> | |
| </tr> | |
| <!-- Predictions: --> | |
| {% for pred in restoration_results.predictions[:3] %} | |
| <tr> | |
| <td class="header header-1">Hypothesis {{ loop.index }}:</td> | |
| <td class="header header-2">{{ "%.1f%%"|format(100 * pred.score) }}</td> | |
| <td> | |
| {% for char in pred.text -%} | |
| {%- if loop.index0 in prediction_idx -%} | |
| <span class="pred">{{char}}</span> | |
| {%- else -%} | |
| {{char}} | |
| {%- endif -%} | |
| {%- endfor %} | |
| </td> | |
| </tr> | |
| {% endfor %} | |
| </table> | |
| </div> | |
| <script> | |
| document.querySelector('#btn').addEventListener('click', () => { | |
| const pred = document.querySelector(".pred"); | |
| pred.scrollIntoViewIfNeeded(); | |
| }); | |
| </script> | |
| </body> | |
| </html> | |
| """) | |
| if not 50 <= len(text) <= 750: | |
| raise app.UsageError( | |
| f'Text should be between 50 and 750 chars long, but the input was ' | |
| f'{len(input_text)} characters') | |
| # Load the checkpoint pickle and extract from it the pieces needed for calling | |
| # the attribute() and restore() functions: | |
| (model_config, region_map, alphabet, params, | |
| forward) = load_checkpoint('checkpoint.pkl') | |
| vocab_char_size = model_config['vocab_char_size'] | |
| vocab_word_size = model_config['vocab_word_size'] | |
| attribution = inference.attribute( | |
| text, | |
| forward=forward, | |
| params=params, | |
| alphabet=alphabet, | |
| region_map=region_map, | |
| vocab_char_size=vocab_char_size, | |
| vocab_word_size=vocab_word_size) | |
| restoration = inference.restore( | |
| text, | |
| forward=forward, | |
| params=params, | |
| alphabet=alphabet, | |
| vocab_char_size=vocab_char_size, | |
| vocab_word_size=vocab_word_size) | |
| prediction_idx = set(i for i, c in enumerate(restoration.input_text) if c == '?') | |
| attrib_dict = {get_subregion_name(l.location_id, region_map): l.score for l in attribution.locations[:3]} | |
| return restore_template.render( | |
| restoration_results=restoration, | |
| prediction_idx=prediction_idx), attrib_dict, create_time_plot(attribution) | |
| with open('example_input.txt', encoding='utf8') as f: | |
| examples = [line for line in f] | |
| gradio.Interface( | |
| main, | |
| inputs=gradio.inputs.Textbox(lines=3), | |
| outputs=[gradio.outputs.HTML(label='Restoration'), gradio.outputs.Label(label='Geographical Attribution'), gradio.outputs.HTML(label='Chronological Attribution')], | |
| examples=examples, | |
| title='Spaces Demo for Ithaca', | |
| description='Restoration and Attribution of ancient Greek texts made by DeepMind. Represent missing characters as "-", and characters to be predicted as "?" (up to 10, does not need to be consecutive)<br> <br><a href="https://ithaca.deepmind.com/" target="_blank">blogpost</a>').launch(enable_queue=True) | |