Spaces:
Runtime error
Runtime error
| """Visualizer for TAPAS | |
| Implementation heavily based on | |
| `EncodingVisualizer` from `tokenizers.tools`. | |
| """ | |
import html
import os
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import pandas as pd
from transformers import TapasTokenizer
# Load the stylesheet that ships next to this module once, at import time.
dirname = os.path.dirname(__file__)
css_filename = os.path.join(dirname, "tapas-styles.css")
# Pin the encoding: without it open() falls back to the locale's preferred
# encoding, so the same CSS file could decode differently across platforms.
with open(css_filename, encoding="utf-8") as f:
    css = f.read()
def HTMLBody(table_html: str, css_styles: str = css) -> str:
    """
    Generates the full html document, with css embedded, around the table html

    The stylesheet is placed in a ``<style>`` element inside ``<head>`` and the
    table markup is placed inside a ``<div class="tokenized-text">`` in ``<body>``.
    Neither argument is escaped or validated here; both are inserted verbatim.

    Args:
        table_html (str):
            The html string of the table
        css_styles (str):
            CSS styling to be embedded inline. Defaults to the module-level
            ``css`` string (the default is bound once, when this function is
            defined)

    Returns:
        :obj:`str`: An HTML string with style markup
    """
    return f"""
    <html>
        <head>
            <style>
                {css_styles}
            </style>
        </head>
        <body>
            <div class="tokenized-text" dir=auto>
            {table_html}
            </div>
        </body>
    </html>
    """
class TapasVisualizer:
    """Render a table as HTML that highlights how a tokenizer split each cell.

    Calling an instance with a table tokenizes it with ``self.tokenizer``,
    groups the resulting tokens per cell, and returns a full HTML document in
    which every token is wrapped in a styled ``<span>`` and every row is
    prefixed with running / per-row token counts.
    """

    # The annotation is quoted so it is not evaluated when the class is
    # defined; the class only stores the tokenizer and calls it later.
    def __init__(self, tokenizer: "TapasTokenizer") -> None:
        self.tokenizer = tokenizer

    def normalize_token_str(self, token_str: str) -> str:
        """Strip the ``##`` sub-word continuation prefix from a token.

        Only a leading ``##`` is removed, so any ``##`` occurring later in the
        token text is preserved.

        Args:
            token_str (str): Token text as produced by the tokenizer

        Returns:
            str: The token text as it appears in the original string
        """
        if token_str.startswith("##"):
            return token_str[2:]
        return token_str

    def style_span(self, span_text: str, css_classes: List[str]) -> str:
        """Wrap plain text in a ``<span>`` carrying the given CSS classes.

        Args:
            span_text (str): Plain text to display; it is HTML-escaped here
            css_classes (List[str]): Class names for the ``class`` attribute

        Returns:
            str: The html of the span
        """
        # Cell contents come straight from the user's table, so they must be
        # escaped before being interpolated into markup.
        class_attr = html.escape(" ".join(css_classes), quote=True)
        return f'<span class="{class_attr}" >{html.escape(span_text)}</span>'

    def text_to_html(self, org_text: str, tokens: List[str]) -> List[str]:
        """Create html spans based on the original text and its tokens.

        Note: The tokens need to be in same order as in the original text.
        A token that cannot be found in the remaining text ends the matching;
        everything from the last matched token onwards is then emitted as a
        single ``non-token`` span.

        Args:
            org_text (str): Original string before tokenization
            tokens (List[str]): The tokens of org_text

        Returns:
            List[str]: html spans (token and non-token) covering org_text,
            in order; empty when org_text is empty and there are no tokens
        """
        if not tokens:
            print(f"Empty tokens for: {org_text}")
            # Keep the cell content visible even though nothing was tokenized.
            return [self.style_span(org_text, ["non-token"])] if org_text else []

        spans: List[str] = []
        text_len = len(org_text)
        token_idx = 0
        current = self.normalize_token_str(tokens[token_idx])
        scan_pos = 0  # position at which the current token is being looked for
        covered_end = 0  # end of the text already emitted as spans

        while scan_pos < text_len:
            candidate = org_text[scan_pos : scan_pos + len(current)]
            # The tokenizer performs lowercasing; so check against lowercase
            if candidate.lower() != current:
                scan_pos += 1
                continue

            if covered_end != scan_pos:
                # Token-less text (probably whitespace) before this token
                spans.append(
                    self.style_span(org_text[covered_end:scan_pos], ["non-token"])
                )
            parity = "even-token" if token_idx % 2 == 0 else "odd-token"
            spans.append(self.style_span(candidate, ["token", parity]))

            scan_pos += len(current)
            covered_end = scan_pos
            token_idx += 1
            if token_idx >= len(tokens):
                break
            current = self.normalize_token_str(tokens[token_idx])

        if covered_end != text_len:
            # Slice to the end of the text: when the loop stops right after
            # the last token, scan_pos equals covered_end, so slicing up to
            # scan_pos would silently drop the trailing text.
            spans.append(self.style_span(org_text[covered_end:], ["non-token"]))
        return spans

    def cells_to_html(
        self,
        cell_vals: List[List[str]],
        cell_tokens: Dict[Tuple[int, int], List[str]],
        row_id_start: int = 0,
        cell_element: str = "td",
        cumulative_cnt: int = 0,
        table_html: str = "",
    ) -> Tuple[str, int]:
        """Append one ``<tr>`` per row of cells to ``table_html``.

        Each row starts with two count cells (running total, then the number
        of tokens added by this row) followed by one ``cell_element`` per cell.

        Args:
            cell_vals (List[List[str]]): Rows of cell texts
            cell_tokens (Dict): Tokens keyed by ``(row_id, col_id)``; cells
                without an entry are treated as having no tokens
            row_id_start (int): Row id of the first row in cell_vals
            cell_element (str): Tag used for the cells (e.g. "td" or "th")
            cumulative_cnt (int): Token count accumulated before these rows
            table_html (str): Html of the rows rendered so far

        Returns:
            Tuple[str, int]: The extended table html and the updated
            cumulative token count
        """
        count_classes = ["non-token", "count"]
        count_cell_open = '<td style="border: none;" align="right">'
        rendered_rows = [table_html]

        for row_id, row in enumerate(cell_vals, start=row_id_start):
            cells = []
            row_token_cnt = 0
            # Column ids are 1-based in the token lookup keys.
            for col_id, cell in enumerate(row, start=1):
                cur_cell_tokens = cell_tokens.get((row_id, col_id), [])
                cell_html = "".join(self.text_to_html(cell, cur_cell_tokens))
                cells.append(f"<{cell_element}>{cell_html}</{cell_element}>")
                row_token_cnt += len(cur_cell_tokens)
            cumulative_cnt += row_token_cnt

            total_span = self.style_span(str(cumulative_cnt), count_classes)
            delta_span = self.style_span(f"<+{row_token_cnt}", count_classes)
            cnt_html = (
                f"{count_cell_open}{total_span}</td>"
                f"{count_cell_open}{delta_span}</td>"
            )
            rendered_rows.append(f"<tr>{cnt_html}{''.join(cells)}</tr>")

        return "".join(rendered_rows), cumulative_cnt

    def __call__(self, table: pd.DataFrame) -> str:
        """Tokenize ``table`` and return an html document visualizing the tokens.

        Args:
            table (pd.DataFrame): The table to tokenize and display

        Returns:
            str: Full html document as produced by ``HTMLBody``
        """
        tokenized = self.tokenizer(table)

        cell_tokens = defaultdict(list)
        for input_id, type_ids in zip(
            tokenized["input_ids"], tokenized["token_type_ids"]
        ):
            # 'prev_label', 'column_rank', 'inv_column_rank', 'numeric_relation'
            # not required
            segment_id, col_id, row_id, *_ = type_ids
            # Only tokens with segment id 1 are attributed to cells.
            if int(segment_id) != 1:
                continue
            token_text = self.tokenizer._convert_id_to_token(int(input_id))
            # Coerce the ids so the keys are plain ints, matching the lookups
            # done in cells_to_html.
            cell_tokens[(int(row_id), int(col_id))].append(token_text)

        # Header cells use row id 0, data rows start at row id 1.
        table_html, cumulative_cnt = self.cells_to_html(
            cell_vals=[table.columns],
            cell_tokens=cell_tokens,
            row_id_start=0,
            cell_element="th",
            cumulative_cnt=0,
            table_html="",
        )
        table_html, cumulative_cnt = self.cells_to_html(
            cell_vals=table.values,
            cell_tokens=cell_tokens,
            row_id_start=1,
            cell_element="td",
            cumulative_cnt=cumulative_cnt,
            table_html=table_html,
        )

        top_label = self.style_span("#Tokens", ["count"])
        top_label_cnt = self.style_span(f"(Total: {cumulative_cnt})", ["count"])
        label_row = (
            '<tr style="line-height: 2rem">'
            f'<td style="border: none;" colspan="2" align="left">{top_label}</td>'
            f'<td style="border: none;" colspan="1" align="left">{top_label_cnt}</td>'
            "</tr>"
        )
        return HTMLBody(f"<table>{label_row}{table_html}</table>")