Spaces:
Runtime error
Runtime error
from __future__ import annotations

import colorsys
import io
import os
from typing import Iterable, List, Dict, Tuple
from typing import Tuple

import gradio as gr
from gradio.themes.base import Base
from gradio.themes.default import Default
from gradio.themes.monochrome import Monochrome
from gradio.themes.soft import Soft
from gradio.themes.utils import colors, fonts, sizes
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, pipeline
from wordcloud import WordCloud
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Parse a CSS-style hex colour into an ``(r, g, b)`` tuple of 0-255 ints.

    Accepts an optional leading ``#`` and the forms ``rgb``, ``rgba``,
    ``rrggbb`` and ``rrggbbaa``. Shorthand digits are doubled (``f`` -> ``ff``).
    Any alpha component is ignored.

    Raises:
        ValueError: if the digit count is not 3, 4, 6 or 8, or a character is
            not a hex digit.
    """
    digits = hex_color.lstrip('#')
    if len(digits) in (3, 4):
        # CSS shorthand: each digit stands for a doubled pair.
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
def rgb_to_hex(rgb_color: tuple[int, int, int]) -> str:
    """Return the lowercase ``#rrggbb`` string for an RGB triple (two hex digits per channel)."""
    red, green, blue = rgb_color[0], rgb_color[1], rgb_color[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
def adjust_brightness(rgb_color: tuple[int, int, int], factor: float) -> tuple[int, int, int]:
    """Scale the HSV value (brightness) of an RGB colour by *factor*.

    Args:
        rgb_color: ``(r, g, b)`` with channels in 0-255.
        factor: Multiplier for the HSV value; the result is clamped to [0, 1].

    Returns:
        The adjusted ``(r, g, b)`` tuple. Channels are rounded rather than
        truncated back to 0-255, so ``factor=1.0`` returns the input unchanged
        despite float error in the HSV round-trip.
    """
    hue, saturation, value = colorsys.rgb_to_hsv(*(channel / 255.0 for channel in rgb_color))
    new_value = max(0.0, min(value * factor, 1.0))
    red, green, blue = colorsys.hsv_to_rgb(hue, saturation, new_value)
    return tuple(round(channel * 255) for channel in (red, green, blue))
monochrome = Monochrome()

# ``None`` (variable unset) lets ``from_pretrained`` fall back to a locally cached
# Hub login instead of failing here with a bare KeyError.
auth_token = os.environ.get('HF_TOKEN')


def _load_ner_pipeline(repo_id: str):
    """Load a token-classification checkpoint and wrap it in an ``"ner"`` pipeline.

    The tokenizer's ``model_max_length`` is pinned to 512 before the pipeline is
    built. Returns ``(tokenizer, model, pipe)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=auth_token)
    model = AutoModelForTokenClassification.from_pretrained(repo_id, token=auth_token)
    tokenizer.model_max_length = 512
    return tokenizer, model, pipeline("ner", model=model, tokenizer=tokenizer)


# Token classifier feeding the "Binary Sequence Classification" output.
tokenizer_bin, model_bin, pipe_bin = _load_ner_pipeline("AlGe/deberta-v3-large_token")
# Token classifier feeding the "Extended Sequence Classification" output.
tokenizer_ext, model_ext, pipe_ext = _load_ner_pipeline("AlGe/deberta-v3-large_AIS-token")

# Single-output (num_labels=1) sequence models; model1 feeds the internal-detail
# count and model2 the external-detail count.
model1 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_Int_segment", num_labels=1, token=auth_token)
tokenizer1 = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_Int_segment", token=auth_token)
# NOTE(review): no tokenizer is loaded for model2; process_classification feeds it
# tokenizer1's encoding, which assumes both checkpoints share a vocabulary — confirm.
model2 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_seq_ext", num_labels=1, token=auth_token)
def process_ner(text: str, pipeline) -> dict:
    """Merge token-level NER predictions into contiguous entity spans.

    A new span starts when the label type changes or a token carries a ``B``
    prefix; otherwise the token extends the current span.

    Args:
        text: The input string; it is passed to *pipeline* and echoed back.
        pipeline: Callable returning a list of token dicts with ``'entity'``
            (a prefixed label such as ``'B-Internal'``), ``'start'``, ``'end'``
            and ``'score'`` keys.

    Returns:
        ``{"text": text, "entities": [...]}``, the value shape
        ``gr.HighlightedText`` requires (it rejects a dict without ``"text"``).
        Each entity has ``'entity'`` (label without its two-character prefix),
        ``'start'``, ``'end'`` and ``'score'`` (the max token score in the span).
    """
    entities = []
    current_entity = None
    for token in pipeline(text):
        label = token['entity']
        entity_prefix = label[:1]  # 'B' (begin) or 'I' (inside)
        entity_type = label[2:]    # strip the 'B-' / 'I-' prefix
        starts_new_span = (
            current_entity is None
            or entity_type != current_entity['entity']
            or entity_prefix == 'B'
        )
        if starts_new_span:
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {
                "entity": entity_type,
                "start": token['start'],
                "end": token['end'],
                "score": token['score'],
            }
        else:
            current_entity['end'] = token['end']
            current_entity['score'] = max(current_entity['score'], token['score'])
    if current_entity is not None:
        entities.append(current_entity)
    return {"text": text, "entities": entities}
def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str, str, str]:
    """Score *text* with both sequence models and derive the internal-detail ratio.

    Both models receive the same *tokenizer1* encoding (truncated to 512 tokens)
    and are run under ``torch.no_grad()``. Each model's first output is reduced
    with ``.item()``, so it must hold exactly one element.

    Returns:
        ``(prediction1, prediction2, ratio)`` as display strings. The predictions
        are rounded to 1 decimal; ``ratio = prediction1 / (prediction1 + prediction2)``
        is rounded to 2 decimals and reported as 0.0 when the sum is exactly 0,
        instead of raising ZeroDivisionError.
    """
    inputs1 = tokenizer1(text, max_length=512, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs1 = model1(**inputs1)
        outputs2 = model2(**inputs1)
    prediction1 = outputs1[0].item()
    prediction2 = outputs2[0].item()
    total = prediction1 + prediction2
    score = prediction1 / total if total != 0 else 0.0
    return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
| import plotly.graph_objects as go | |
| from typing import Tuple | |
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure, np.ndarray | None]:
    """Build the summary visualisations for one analysed narrative.

    Args:
        ner_output_bin: ``process_ner`` output of the binary model; only its
            ``"entities"`` list is read.
        ner_output_ext: ``process_ner`` output of the extended model.

    Returns:
        ``(pie, bar, wordcloud)``:
        - ``pie``: a donut chart of extended-label counts.
        - ``bar``: a bar chart of binary-label counts.
        - ``wordcloud``: whatever ``generate_wordcloud`` returns for the extended
          entities.
    """
    from collections import Counter

    bin_color_map = {
        "External": "#6ad5bc",
        "Internal": "#ee8bac",
    }
    ext_color_map = {
        "INTemothou": "#FF7F50",     # Coral
        "INTpercept": "#FF4500",     # OrangeRed
        "INTtime": "#FF6347",        # Tomato
        "INTplace": "#FFD700",       # Gold
        "INTevent": "#FFA500",       # Orange
        "EXTsemantic": "#4682B4",    # SteelBlue
        "EXTrepetition": "#5F9EA0",  # CadetBlue
        "EXTother": "#00CED1",       # DarkTurquoise
    }
    # Used for any label the models emit that is missing from the maps above,
    # instead of raising KeyError.
    fallback_color = "#888888"

    # Counter keeps first-occurrence order, so chart order follows the text. A set
    # of strings would iterate in a per-process random order (str hash randomization).
    entity_counts_bin = Counter(entity['entity'] for entity in ner_output_bin['entities'])
    entity_counts_ext = Counter(entity['entity'] for entity in ner_output_ext['entities'])
    bin_labels = list(entity_counts_bin)
    bin_sizes = list(entity_counts_bin.values())
    ext_labels = list(entity_counts_ext)
    ext_sizes = list(entity_counts_ext.values())

    bin_colors = [bin_color_map.get(label, fallback_color) for label in bin_labels]
    ext_colors = [ext_color_map.get(label, fallback_color) for label in ext_labels]

    # Donut chart of the extended-classification subclasses.
    fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
    fig1.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )

    # Bar chart of the binary-classification classes.
    fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
    fig2.update_layout(
        xaxis_title='Entity Type',
        yaxis_title='Count',
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
    )

    wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map)
    return fig1, fig2, wordcloud_image
def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray | None:
    """Render the entity labels as a word cloud image.

    Each distinct label is weighted by the sum of its spans' confidence scores,
    so labels that occur more often (or more confidently) are drawn larger.

    Args:
        entities: Entity dicts as produced by ``process_ner``; only ``'entity'``
            (the label) and ``'score'`` are read.
        color_map: Label -> CSS colour; labels missing from it are drawn white.

    Returns:
        The word cloud as an RGB numpy array (800x400 canvas, black background).
        Returns ``None`` when there are no entities: ``WordCloud`` refuses an
        empty frequency table, and ``gr.Image`` shows ``None`` as an empty image.
    """
    word_freq: Dict[str, float] = {}
    for entity in entities:
        label = entity['entity']
        word_freq[label] = word_freq.get(label, 0.0) + entity['score']
    if not word_freq:
        return None

    def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
        # The words drawn are the labels themselves, so look the colour up directly.
        return color_map.get(word, "#FFFFFF")

    wordcloud = WordCloud(width=800, height=400, background_color='black', color_func=color_func)
    wordcloud.generate_from_frequencies(word_freq)
    # Render straight to an array; this avoids pyplot's global figure state and
    # the deprecated ``canvas.tostring_rgb``.
    return wordcloud.to_array()
def all(text: str):
    """Gradio callback: run every model on *text* and return outputs in Interface order.

    NOTE: this shadows the builtin ``all`` at module scope; the name is kept
    because the Interface below passes it as ``fn``.

    Returns an 8-tuple:
    - the binary and extended highlighted-entity dicts;
    - the internal count, external count and internal ratio strings;
    - the pie chart, bar chart and word-cloud image.
    """
    binary_entities = process_ner(text, pipe_bin)
    extended_entities = process_ner(text, pipe_ext)
    internal_count, external_count, internal_ratio = process_classification(text, model1, model2, tokenizer1)
    pie, bar, cloud = generate_charts(binary_entities, extended_entities)
    return (
        binary_entities,
        extended_entities,
        internal_count,
        external_count,
        internal_ratio,
        pie,
        bar,
        cloud,
    )
# Clickable sample narratives shown under the input box. The Interface has one
# input component, so each row holds a single string. Previously ``examples`` was
# referenced without ever being defined, which crashed the app at import.
examples = [
    ["Yesterday morning I had breakfast with my brother at the little café near the station. The coffee was bitter, it was raining outside, and I felt relieved that the week was finally over. I usually think cafés like that are overpriced."],
    ["Next Saturday I will probably drive to the lake with two friends. We plan to leave around eight, swim before lunch and grill in the afternoon. Weekends like that always help me switch off."],
]

iface = gr.Interface(
    fn=all,
    inputs=gr.Textbox(lines=5, label="Input Text", placeholder="Write about how your breakfast went or anything else that happened or might happen to you ..."),
    # Order must match the tuple returned by ``all``.
    outputs=[
        gr.HighlightedText(label="Binary Sequence Classification",
                           color_map={
                               "External": "#6ad5bcff",
                               "Internal": "#ee8bacff"}
                           ),
        gr.HighlightedText(label="Extended Sequence Classification",
                           color_map={
                               "INTemothou": "#FF7F50",     # Coral
                               "INTpercept": "#FF4500",     # OrangeRed
                               "INTtime": "#FF6347",        # Tomato
                               "INTplace": "#FFD700",       # Gold
                               "INTevent": "#FFA500",       # Orange
                               "EXTsemantic": "#4682B4",    # SteelBlue
                               "EXTrepetition": "#5F9EA0",  # CadetBlue
                               "EXTother": "#00CED1",       # DarkTurquoise
                           }
                           ),
        gr.Label(label="Internal Detail Count"),
        gr.Label(label="External Detail Count"),
        gr.Label(label="Approximated Internal Detail Ratio"),
        gr.Plot(label="Extended SeqClass Entity Distribution Pie Chart"),
        gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
        gr.Image(label="Entity Word Cloud"),
    ],
    title="Scoring Demo",
    description="Autobiographical Memory Analysis: This demo combines two text - and two sequence classification models to showcase our automated Autobiographical Interview scoring method. Submit a narrative to see the results.",
    examples=examples,
    theme=monochrome,
)
iface.launch()