Spaces:
Running
on
Zero
Running
on
Zero
Florian
committed on
Commit
·
af6c532
1
Parent(s):
5b2e6a5
remove penalty alpha and just put the model
Browse files
app.py
CHANGED
|
@@ -25,23 +25,8 @@ def reset():
|
|
| 25 |
_, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
|
| 26 |
|
| 27 |
@st.cache_resource
|
| 28 |
-
def load_model(
|
| 29 |
-
|
| 30 |
-
0.5:"model_20240118-192548.bin",
|
| 31 |
-
2:"model_20240118-211943.bin",
|
| 32 |
-
5:"model_20240118-231333.bin",
|
| 33 |
-
10:"model_20240119-010725.bin",
|
| 34 |
-
20:"model_20240119-030115.bin",
|
| 35 |
-
0:"model_20240119-135506.bin",
|
| 36 |
-
1:"model_20240119-154900.bin",
|
| 37 |
-
-20: "model_20240208-072350.bin",
|
| 38 |
-
-10: "model_20240208-052958.bin",
|
| 39 |
-
-5: "model_20240208-033606.bin",
|
| 40 |
-
-2: "model_20240208-014211.bin",
|
| 41 |
-
-1: "model_20240207-234817.bin",
|
| 42 |
-
-0.5: "model_20240207-215423.bin",
|
| 43 |
-
-0.1: "model_20240207-200020.bin"}
|
| 44 |
-
|
| 45 |
model_str = "susnato/phi-1_5_dev"
|
| 46 |
model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
|
| 47 |
tokenizer = AutoTokenizer.from_pretrained(model_str)
|
|
@@ -49,19 +34,15 @@ def load_model(penalty_alpha):
|
|
| 49 |
branch_locations = list(range(0, 23, 5))
|
| 50 |
model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
|
| 51 |
|
| 52 |
-
# Load the specific model
|
| 53 |
-
model_path =
|
| 54 |
-
if model_path:
|
| 55 |
-
model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
|
| 56 |
-
else:
|
| 57 |
-
print("Invalid penalty_alpha. Using default model weights.")
|
| 58 |
|
| 59 |
return model, tokenizer
|
| 60 |
|
| 61 |
|
| 62 |
if "model" not in st.session_state or "tokenizer" not in st.session_state:
|
| 63 |
print("Loading model...")
|
| 64 |
-
st.session_state.model, st.session_state.tokenizer = load_model(
|
| 65 |
st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
|
| 66 |
print(f"Head number: {st.session_state['head_number']}")
|
| 67 |
# Session state to store the current sentence
|
|
|
|
| 25 |
_, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
|
| 26 |
|
| 27 |
@st.cache_resource
|
| 28 |
+
def load_model(model_path):
|
| 29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
model_str = "susnato/phi-1_5_dev"
|
| 31 |
model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
|
| 32 |
tokenizer = AutoTokenizer.from_pretrained(model_str)
|
|
|
|
| 34 |
branch_locations = list(range(0, 23, 5))
|
| 35 |
model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
|
| 36 |
|
| 37 |
+
# Load the specific model
|
| 38 |
+
model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
return model, tokenizer
|
| 41 |
|
| 42 |
|
| 43 |
if "model" not in st.session_state or "tokenizer" not in st.session_state:
|
| 44 |
print("Loading model...")
|
| 45 |
+
st.session_state.model, st.session_state.tokenizer = load_model("model/model.bin")
|
| 46 |
st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
|
| 47 |
print(f"Head number: {st.session_state['head_number']}")
|
| 48 |
# Session state to store the current sentence
|