Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from hub_name import LORA_HUB_NAMES | |
| from random import shuffle | |
| import pandas as pd | |
| import streamlit as st | |
| import contextlib | |
| from functools import wraps | |
| from io import StringIO | |
| import contextlib | |
| import redirect as rd | |
| import torch | |
| import shutil | |
| import os | |
# Page-wide style override: force Streamlit dataframe / data-editor widgets
# (CSS class ``stDataFrame``) to span the full width of their container.
css = (
    "\n"
    "<style>\n"
    ".stDataFrame { width: 100% !important; }\n"
    "</style>\n"
)
st.markdown(css, unsafe_allow_html=True)
def _set_lucky_modules():
    """Select up to 20 random module indices in the sidebar multiselect.

    Registered as the ``on_click`` callback of the "Give 20 Lucky Modules"
    button. It overwrites ``st.session_state["select_names"]``, the state key
    of the module multiselect.
    """
    indices = list(range(len(LORA_HUB_NAMES)))
    shuffle(indices)
    # Slicing copes with pools smaller than 20 (everything is then selected).
    st.session_state["select_names"] = indices[:20]


def _render_module_pool():
    """Render the sidebar: module table, module selector and "lucky" button."""
    with st.sidebar:
        st.title("LoRA Module Pool")
        st.markdown(
            "The following modules are available for you to compose for your new task. Every module name is a peft repository in Huggingface Hub, and you can find them [here](https://huggingface.co/models?search=lorahub).")
        df = pd.DataFrame({
            "Index": list(range(len(LORA_HUB_NAMES))),
            "Module Name": LORA_HUB_NAMES,
        })
        # The table is display-only. Both entries must match the DataFrame's
        # actual column names, or that column stays editable.
        st.data_editor(df,
                       disabled=["Index", "Module Name"],
                       hide_index=True)
        st.multiselect(
            'Select your favorite modules as the candidate for LoRA composition',
            list(range(len(LORA_HUB_NAMES))),
            [],
            key="select_names")
        st.button(":game_die: Give 20 Lucky Modules",
                  on_click=_set_lucky_modules)
        st.write('We will use the following modules', [
            LORA_HUB_NAMES[i] for i in st.session_state["select_names"]])


def _package_lora_module(final_lora):
    """Save the learned LoRA weights into ``lora/`` and zip that directory.

    Args:
        final_lora: the second value returned by ``lorahub_learning``. It is
            serialized with ``torch.save`` as ``lora/adapter_model.bin``.

    Returns:
        Path of the created ``lora_module.zip`` archive.
    """
    # torch.save does not create missing parent directories.
    os.makedirs("lora", exist_ok=True)
    torch.save(final_lora, os.path.join("lora", "adapter_model.bin"))
    return shutil.make_archive("lora_module", "zip", "lora")


def _run_lorahub(selected_modules, txt_input, txt_output, max_step, buffer):
    """Run LoraHub learning, show the module weights and offer a ZIP download.

    Args:
        selected_modules: Hub names of the LoRA modules to compose.
        txt_input: raw example-input text, passed through unchanged.
        txt_output: raw example-output text, passed through unchanged.
        max_step: maximum number of inference steps for the optimizer.
        buffer: the "Learning Logs" expander that receives log text.
    """
    buffer.text("* begin to perform lorahub learning *")
    # Deferred import: ``util`` is only needed once learning actually starts.
    from util import lorahub_learning
    # NOTE(review): the project-local ``redirect`` module presumably routes
    # stderr output into the log expander. Confirm against redirect.py.
    with rd.stderr(to=buffer):
        recommendation, final_lora = lorahub_learning(
            selected_modules, txt_input, txt_output,
            max_inference_step=max_step)
    st.success("Lorahub learning finished! You got the following recommendation:")
    st.table({
        "modules": selected_modules,
        "weights": recommendation.value,
    })
    zip_path = _package_lora_module(final_lora)
    # NOTE(review): Streamlit reruns the script on widget interaction, so the
    # results above likely vanish after the download button is clicked.
    # Consider keeping them in st.session_state.
    with open(zip_path, "rb") as fp:
        st.download_button(
            label="Download ZIP",
            data=fp,
            file_name="lora_module.zip",
            mime="application/zip",
        )


def main():
    """Render the LoraHub demo page and run LoraHub learning when requested.

    The user picks LoRA modules in the sidebar, enters few-shot examples
    (one input and one output per line) and a maximum iteration step.
    Pressing "Start!" validates these choices and then runs
    ``util.lorahub_learning`` on them.
    """
    st.title("LoraHub")
    st.markdown("Low-rank adaptations (LoRA) are techniques for fine-tuning large language models on new tasks. We propose LoraHub, a framework that allows composing multiple LoRA modules trained on different tasks. The goal is to achieve good performance on unseen tasks using just a few examples, without needing extra parameters or training. And we want to build a marketplace where users can share their trained LoRA modules, thereby facilitating the application of these modules to new tasks.")
    st.markdown("In this demo, you will use available lora modules selected in the left sidebar to tackle your few-shot examples. When the LoraHub learning is done, you can download the final LoRA module and use it for your new task. You can check out more details in our [paper](https://huggingface.co/papers/2307.13269).")

    _render_module_pool()

    st.subheader("Prepare your few-shot examples")
    txt_input = st.text_area('Examples Inputs (One Line One Input)', "\n".join([
        "Infer the date from context. Q: Today, 8/3/1997, is a day that we will never forget. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 03/27/1998 (B) 09/02/1997 (C) 07/27/1997 (D) 06/29/1997 (E) 07/27/1973 (F) 12/27/1997 A:",
        "Infer the date from context. Q: May 6, 1992 is like yesterday to Jane, but that is actually ten years ago. What is the date tomorrow in MM/DD/YYYY? Options: (A) 04/16/2002 (B) 04/07/2003 (C) 05/07/2036 (D) 05/28/2002 (E) 05/07/2002 A:",
        "Infer the date from context. Q: Today is the second day of the third month of 1966. What is the date one week ago from today in MM/DD/YYYY? Options: (A) 02/26/1966 (B) 01/13/1966 (C) 02/02/1966 (D) 10/23/1966 (E) 02/23/1968 (F) 02/23/1966 A:",
    ]))
    txt_output = st.text_area('Examples Outputs (One Line One Output)',
                              "\n".join(["(C)", "(E)", "(F)"]))
    max_step = st.slider('Maximum iteration step', 10, 1000, step=10)
    buffer = st.expander("Learning Logs")

    if not st.button(':rocket: Start!'):
        return

    selected_modules = [LORA_HUB_NAMES[i]
                        for i in st.session_state["select_names"]]
    # The text areas promise one example per line, so the counts must agree.
    n_inputs = len(txt_input.strip().splitlines())
    n_outputs = len(txt_output.strip().splitlines())
    if not selected_modules:
        st.error("Please select at least 1 module!")
    elif max_step < len(selected_modules):
        st.error(
            "Please specify a larger maximum iteration step than the number of selected modules!")
    elif n_inputs == 0 or n_inputs != n_outputs:
        st.error(
            "Please provide the same, non-zero number of example inputs and outputs (one per line)!")
    else:
        _run_lorahub(selected_modules, txt_input, txt_output, max_step, buffer)
if __name__ == "__main__":
    # Build the page when this file is executed as the main script.
    main()