Spaces:
Runtime error
Runtime error
import traceback
from io import StringIO
from typing import Optional, Tuple

import gradio as gr
import pandas as pd
from loguru import logger

from utils import pipeline
from utils.models import list_models
def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    The loader is chosen from the file extension, matched case-insensitively
    so that e.g. ``DATA.CSV`` or ``Book.XLSX`` are accepted as well.

    Args:
        filepath: Path to a ``.csv`` or ``.xlsx`` file.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the extension is neither ``.csv`` nor ``.xlsx``.
    """
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    # ValueError is still an Exception, so existing broad handlers keep working.
    raise ValueError('File type not supported')
def _load_input(text: str, file=None) -> pd.DataFrame:
    """Build the input DataFrame from an uploaded file or pasted CSV text.

    An uploaded file takes precedence over ``text``.

    Raises:
        Exception: If neither input is provided or the parsed table has no rows.
    """
    if file:
        # Accept either an object exposing ``.name`` (the original contract)
        # or a plain path string.
        filepath = getattr(file, 'name', file)
        df = read_data(filepath)
    elif text:
        df = pd.read_csv(StringIO(text))
    else:
        raise Exception('No input data')
    # Explicit check rather than ``assert`` (which ``python -O`` strips), and
    # applied to both the file and the text path.
    if len(df) < 1:
        raise Exception('No input data')
    return df


def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Tuple[Optional[dict], Optional[pd.DataFrame], Optional[str]]:
    """Run the selected scoring task and return values for the three outputs.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``.
        model_name: Model identifier forwarded to the pipeline.
        pooling: Pooling strategy forwarded to the pipeline.
        text: CSV-formatted text, used when no file is uploaded.
        file: Optional uploaded file (object with ``.name`` or a path string).

    Returns:
        On success ``(None, first 10 result rows, 'output.csv')``; on any
        error ``({'Info': ..., 'Error': traceback}, None, None)``.
    """
    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')
        # load file
        df = _load_input(text, file)
        # check
        if len(df) > 10000:
            raise Exception('Data exceeds 10,000 rows')
        # process
        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise Exception('Task not supported')
        # save; utf-8-sig adds a BOM so Excel detects the encoding.
        path = 'output.csv'
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path
    except Exception:
        # ``Exception`` (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate; the traceback is formatted once and reused.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None
# input
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility'],
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models,
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls'],
)
# Close the handle after reading and decode explicitly as UTF-8 so the
# example text (possibly non-ASCII) does not depend on the platform locale.
with open('data/example_xlm.csv', 'r', encoding='utf-8') as _example_file:
    _example_text = _example_file.read()
text_input = gr.components.Textbox(
    value=_example_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])
# output
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
# NOTE(review): `process` only ever writes a .csv; '.xlsx' here is unused.
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])
# Close the handle after reading and decode explicitly as UTF-8 rather than
# relying on the platform default encoding.
with open('data/description.txt', 'r', encoding='utf-8') as _description_file:
    _description = _description_file.read()
app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=_description,
    title='TransDis-CreativityAutoAssessment',
    # One request at a time; presumably this also keeps concurrent runs from
    # overwriting the fixed 'output.csv' written by `process`.
    concurrency_limit=1,
)
app.launch(max_threads=1)