Update app.py
Browse files
app.py
CHANGED
|
@@ -11,55 +11,12 @@ from datasets import load_dataset
|
|
| 11 |
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
|
| 12 |
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
|
| 13 |
from interpret import InterpretationPrompt
|
|
|
|
|
|
|
| 14 |
|
| 15 |
MAX_PROMPT_TOKENS = 60
|
| 16 |
MAX_NUM_LAYERS = 50
|
| 17 |
|
| 18 |
-
|
| 19 |
-
## info
|
| 20 |
-
dataset_info = [
|
| 21 |
-
{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
|
| 22 |
-
{'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
|
| 23 |
-
'filter': lambda x: x['label'] == 1},
|
| 24 |
-
# {'name': 'Physical Understanding', 'hf_repo': 'piqa', 'text_col': 'goal'},
|
| 25 |
-
{'name': 'Social Reasoning', 'hf_repo': 'ProlificAI/social-reasoning-rlhf', 'text_col': 'question'}
|
| 26 |
-
]
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
model_info = {
|
| 30 |
-
'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ['hf_token'],
|
| 31 |
-
original_prompt_template='<s>{prompt}',
|
| 32 |
-
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
|
| 33 |
-
), # , load_in_8bit=True
|
| 34 |
-
|
| 35 |
-
# 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
|
| 36 |
-
# original_prompt_template='<bos>{prompt}',
|
| 37 |
-
# interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
|
| 38 |
-
# ),
|
| 39 |
-
|
| 40 |
-
'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
|
| 41 |
-
original_prompt_template='<s>{prompt}',
|
| 42 |
-
interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
|
| 43 |
-
),
|
| 44 |
-
|
| 45 |
-
# 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
|
| 46 |
-
# tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
|
| 47 |
-
# model_type='llama', hf=True, ctransformers=True,
|
| 48 |
-
# original_prompt_template='<s>[INST] {prompt} [/INST]',
|
| 49 |
-
# interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
|
| 50 |
-
# )
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
suggested_interpretation_prompts = [
|
| 55 |
-
"Sure, here's a bullet list of the key words in your message:",
|
| 56 |
-
"Sure, I'll summarize your message:",
|
| 57 |
-
"Sure, here are the words in your message:",
|
| 58 |
-
"Before responding, let me repeat the message you wrote:",
|
| 59 |
-
"Let me repeat the message:"
|
| 60 |
-
]
|
| 61 |
-
|
| 62 |
-
|
| 63 |
@dataclass
|
| 64 |
class GlobalState:
|
| 65 |
tokenizer : Optional[PreTrainedTokenizer] = None
|
|
|
|
| 11 |
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
|
| 12 |
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
|
| 13 |
from interpret import InterpretationPrompt
|
| 14 |
+
from configs import model_info, dataset_info
|
| 15 |
+
|
| 16 |
|
| 17 |
MAX_PROMPT_TOKENS = 60
|
| 18 |
MAX_NUM_LAYERS = 50
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@dataclass
|
| 21 |
class GlobalState:
|
| 22 |
tokenizer : Optional[PreTrainedTokenizer] = None
|