Spaces:
Runtime error
Runtime error
| from huggingface_hub import list_models | |
| import streamlit as st | |
| from model import ReplicateModel | |
| import os | |
| import pandas as pd | |
# Directory containing the topic CSV datasets offered by the app.
DATASETS_PATH = 'datasets'

# Replicate model reference for the Mistral 7B Instruct option; presumably in
# Replicate's "owner/name:version" form — confirm against ReplicateModel.
_MISTRAL_INSTRUCT_REF = (
    'mistralai/mistral-7b-instruct-v0.1:'
    '83b6a56e7c828e667f21fd596c338fd4f0039b46bcfa18d973e8e70e455fda70'
)

# Models selectable in the UI, keyed by the name shown in the model selector.
models = {'mistral_instruct': ReplicateModel(_MISTRAL_INSTRUCT_REF)}
# Prompt templates offered in the UI, keyed by the name shown in the prompt
# selector. Every template must contain the literal placeholder "[KEYWORDS]";
# the app replaces it with the selected topic's comma-separated keywords.
prompts = {
    'simple_prompt':
    '''
I have a topic that is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
''',
    # NOTE(review): the worked examples answer as "Topic label: ..." while the
    # final instruction asks for "topic: <topic label>", so the model may copy
    # either format — align them if the output is ever parsed.
    'few_shot_examples':
    '''
I have a topic that is described by the following keywords: [KEYWORDS]
Example 1:
Keywords: apple,fruit,healthy,snack,red,orchard
Topic label: Healthy Fruit Snacks
Example 2:
Keywords: computer,technology,silicon,programming,internet,hardware
Topic label: Computer Technology
Example 3:
Keywords: democracy,government,elections,vote,political,representation
Topic label: Democratic Governance
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
''',
    # Uncomment to expose a free-text prompt editor in the UI.
    # 'custom_prompt': '',
}
# Selectable topic datasets: name shown in the UI -> path of the CSV to load.
topicsets = dict(
    example_topics=os.path.join(DATASETS_PATH, 'topics.csv'),
)
def get_available_models():
    """Return the names of the models this app can serve.

    The result is a view of the keys of the module-level ``models`` dict.
    An earlier variant listed models published by the 'textminr' author on the
    Hugging Face Hub (``huggingface_hub.list_models``) instead.
    """
    available = models.keys()
    return available
def load_model(model_name: str):
    """Return the loaded model registered under ``model_name``.

    Looks the name up in the module-level ``models`` dict and returns whatever
    that entry's ``load()`` method returns.

    Raises:
        KeyError: if ``model_name`` is not a key of ``models``.
    """
    entry = models[model_name]
    return entry.load()
# --- Page setup -------------------------------------------------------------
st.set_page_config(page_title='TL playground', page_icon='🚀', layout='wide')
st.title('🚀 Topic Labelling playground')

# On screens at least 1500px wide, let the main content block grow to
# percentage_width_main percent of its container.
percentage_width_main = 70
st.markdown(
    f'''<style>
    @media only screen and (min-width: 1500px) {{
        .appview-container .main .block-container{{
            max-width: {percentage_width_main}%;
        }}
    }}
</style>
''',
    unsafe_allow_html=True,
)

col1, col2 = st.columns(2, gap='medium')

# Selection state; each stays None until the user has made that choice.
# Initialising them up front keeps the "ready" check below free of NameErrors.
model = None
sel_dataset = None
sel_row_index = None

# --- Left column: model, dataset and topic selection ----------------------
sel_model_name = col1.selectbox('Select a model', models, index=None, placeholder='Select a model')
if sel_model_name:
    model = load_model(sel_model_name)

sel_dataset_name = col1.selectbox('Select a dataset', topicsets.keys(), index=None)
if sel_dataset_name:
    sel_dataset = pd.read_csv(topicsets[sel_dataset_name])
    # Hide metadata columns; errors='ignore' keeps datasets that lack them usable.
    sel_dataset = sel_dataset.drop(columns=['topic_id', 'domain'], errors='ignore')
    col1.dataframe(sel_dataset)
    # Returns an index *label* (None if the dataset has no rows).
    sel_row_index = col1.selectbox('Select a topic', sel_dataset.index)

# --- Right column: prompt selection and generation -------------------------
sel_prompt = col2.selectbox('Select a prompt', prompts.keys())
if sel_prompt != 'custom_prompt':
    col2.code(prompts[sel_prompt], language='text')
    sel_prompt_text = prompts[sel_prompt]
else:
    # Rendered in col2 so the editor sits with the other prompt widgets.
    sel_prompt_text = col2.text_area('Custom prompt', height=200)
col2.caption('Make sure to use "[KEYWORDS]" to indicate where the keywords should be inserted.')

ready = (
    bool(sel_model_name)
    and sel_row_index is not None
    and bool(sel_prompt_text.strip())
)
btn_generate = col2.button('Generate', disabled=not ready)

if btn_generate:
    # sel_row_index is a label from sel_dataset.index, so look it up with .loc.
    row = sel_dataset.loc[sel_row_index]
    # Skip the first remaining column (assumed to be a non-keyword field —
    # confirm against the CSV schema). Empty cells are read as NaN and
    # non-string cells would make str.join raise TypeError, so drop/convert them.
    keywords = ','.join(row.iloc[1:].dropna().astype(str))

    placeholder = col2.empty()
    with placeholder, st.spinner('Generating...'):
        prompt = sel_prompt_text.replace('[KEYWORDS]', keywords)
        result = model.generate(prompt)

    message = col2.chat_message("ai")
    message.write(result)
    message.caption('Keywords: ' + keywords)