Update app.py
Browse files
app.py
CHANGED
|
@@ -74,7 +74,7 @@ def initialize_gpu():
|
|
| 74 |
pass
|
| 75 |
|
| 76 |
|
| 77 |
-
def reset_model(model_name
|
| 78 |
# extract model info
|
| 79 |
model_args = deepcopy(model_info[model_name])
|
| 80 |
model_path = model_args.pop('model_path')
|
|
@@ -90,8 +90,6 @@ def reset_model(model_name, return_state=False):
|
|
| 90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
| 91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 92 |
gc.collect()
|
| 93 |
-
if return_state:
|
| 94 |
-
return global_state
|
| 95 |
|
| 96 |
|
| 97 |
def get_hidden_states(raw_original_prompt):
|
|
@@ -145,11 +143,13 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
| 145 |
|
| 146 |
## main
|
| 147 |
torch.set_grad_enabled(False)
|
|
|
|
| 148 |
|
| 149 |
model_name = 'LLAMA2-7B'
|
| 150 |
-
|
| 151 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
| 152 |
tokens_container = []
|
|
|
|
| 153 |
for i in range(MAX_PROMPT_TOKENS):
|
| 154 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
| 155 |
tokens_container.append(btn)
|
|
|
|
| 74 |
pass
|
| 75 |
|
| 76 |
|
| 77 |
+
def reset_model(model_name):
|
| 78 |
# extract model info
|
| 79 |
model_args = deepcopy(model_info[model_name])
|
| 80 |
model_path = model_args.pop('model_path')
|
|
|
|
| 90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
| 91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
| 92 |
gc.collect()
|
|
|
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def get_hidden_states(raw_original_prompt):
|
|
|
|
| 143 |
|
| 144 |
## main
|
| 145 |
torch.set_grad_enabled(False)
|
| 146 |
+
global_state = GlobalState()
|
| 147 |
|
| 148 |
model_name = 'LLAMA2-7B'
|
| 149 |
+
reset_model(model_name)
|
| 150 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
| 151 |
tokens_container = []
|
| 152 |
+
|
| 153 |
for i in range(MAX_PROMPT_TOKENS):
|
| 154 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
| 155 |
tokens_container.append(btn)
|