Commit fd97c8c
Parent(s): 2b27faa

add 3B model

Files changed:
- app.py (+20 -9)
- utils/llama_utils.py (+5 -5)
app.py
CHANGED
@@ -35,6 +35,7 @@ st.markdown("""
 # ---------------------------------------
 base_path = "data/"
 base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+base_model_path_3B = "meta-llama/Llama-3.2-3B-Instruct"
 adapter_path = "./LLaMA-TOMMI-1.0/"
 
 st.title(":red[AI University] :gray[/] FEM")
@@ -115,12 +116,12 @@ with st.sidebar:
     # Choose the LLM model
     st.session_state.synthesis_model = st.selectbox(
         "Choose the LLM model",
-        ["LLaMA-3.2-11B", "gpt-4o-mini"],
+        ["LLaMA-3.2-3B","gpt-4o-mini"], # "LLaMA-3.2-11B",
         index=1,
         key='a2model'
     )
 
-    if st.session_state.synthesis_model == "LLaMA-3.2-11B":
+    if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
         synthesis_do_sample = st.toggle("Enable Sampling", value=False, key='synthesis_sample')
 
         if synthesis_do_sample:
@@ -169,6 +170,14 @@ with col2:
         help=question_help
     )
 
+with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."):
+    if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B":
+        if 'tommi_model' not in st.session_state:
+            tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
+            st.session_state.tommi_model = tommi_model
+            st.session_state.tommi_tokenizer = tommi_tokenizer
+
+
 with st.spinner("Loading LLaMA-3.2-11B..."):
     if "LLaMA-3.2-11B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
         if 'llama_model' not in st.session_state:
@@ -176,12 +185,12 @@ with st.spinner("Loading LLaMA-3.2-11B..."):
            st.session_state.llama_model = llama_model
            st.session_state.llama_tokenizer = llama_tokenizer
 
-with st.spinner("Loading LLaMA-TOMMI-1.0-11B..."):
-    if st.session_state.expert_model == "LLaMA-TOMMI-1.0-11B":
-        if 'tommi_model' not in st.session_state:
-            tommi_model, tommi_tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
-            st.session_state.tommi_model = tommi_model
-            st.session_state.tommi_tokenizer = tommi_tokenizer
+with st.spinner("Loading LLaMA-3.2-3B..."):
+    if "LLaMA-3.2-3B" in [st.session_state.expert_model, st.session_state.synthesis_model]:
+        if 'llama_model_3B' not in st.session_state:
+            llama_model_3B, llama_tokenizer_3B = load_base_model(base_model_path_3B)
+            st.session_state.llama_model_3B = llama_model_3B
+            st.session_state.llama_tokenizer_3B = llama_tokenizer_3B
 
 # Load YouTube and LaTeX data
 text_data_YT, context_embeddings_YT = load_youtube_data(base_path, model_name, yt_chunk_tokens, yt_overlap_tokens)
@@ -264,6 +273,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
                model=model_,
                tokenizer=tokenizer_,
                messages=messages,
+               tokenizer_max_length=500,
                do_sample=expert_do_sample,
                temperature=expert_temperature if expert_do_sample else None,
                top_k=expert_top_k if expert_do_sample else None,
@@ -289,7 +299,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
        #-------------------------
        # synthesis responses
        #-------------------------
-       if st.session_state.synthesis_model == "LLaMA-3.2-11B":
+       if st.session_state.synthesis_model in ["LLaMA-3.2-3B", "LLaMA-3.2-11B"]:
            synthesis_prompt = f"""
            Question:
            {st.session_state.question}
@@ -311,6 +321,7 @@ if submit_button_placeholder.button("AI Answer", type="primary"):
                model=st.session_state.llama_model,
                tokenizer=st.session_state.llama_tokenizer,
                messages=messages,
+               tokenizer_max_length=30000,
                do_sample=synthesis_do_sample,
                temperature=synthesis_temperature if synthesis_do_sample else None,
                top_k=synthesis_top_k if synthesis_do_sample else None,
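Note: the new 3B branch caches a second model/tokenizer pair under llama_model_3B / llama_tokenizer_3B via load_base_model(base_model_path_3B). load_base_model lives in utils/llama_utils.py and is not shown in this diff; the sketch below is only a plausible shape of such a loader, assuming a standard transformers causal-LM load (the repo's actual dtype, device, or quantization handling may differ).

# Plausible shape of load_base_model (defined in utils/llama_utils.py, not part
# of this diff). Assumes a standard Hugging Face transformers load.
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_base_model(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",   # keep the checkpoint's native precision
        device_map="auto",    # place weights on available GPU(s)/CPU
    )
    model.eval()              # inference only
    return model, tokenizer

# Usage mirrors the new call in app.py:
# llama_model_3B, llama_tokenizer_3B = load_base_model("meta-llama/Llama-3.2-3B-Instruct")

Keeping each pair in st.session_state means the weights load once per Streamlit session instead of on every rerun.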
utils/llama_utils.py
CHANGED
@@ -93,16 +93,16 @@ def generate_response(
     model: AutoModelForCausalLM,
     tokenizer: PreTrainedTokenizerFast,
     messages: list,
+    tokenizer_max_length: int = 500,
     do_sample: bool = False,
-    temperature: float = 0.
+    temperature: float = 0.1,
     top_k: int = 50,
     top_p: float = 0.95,
     num_beams: int = 1,
-    max_new_tokens: int =
+    max_new_tokens: int = 700
 ) -> str:
     """
     Runs inference on an LLM model.
-
     Args:
         model (AutoModelForCausalLM)
         tokenizer (PreTrainedTokenizerFast)
@@ -124,7 +124,7 @@ def generate_response(
     # Tokenize input
     inputs = tokenizer(
         input_text,
-        max_length=
+        max_length=tokenizer_max_length,
         truncation=True,
         return_tensors="pt"
     ).to(model.device)
@@ -158,4 +158,4 @@ def generate_response(
 
     response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
 
-    return response
+    return response
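With the new tokenizer_max_length parameter, the prompt-truncation limit is set per call instead of being hard-coded in the tokenizer call: app.py passes 500 tokens for the expert pass and 30000 for the synthesis pass. Below is a minimal call sketch against the updated signature; the chat-message format and the import path are assumptions (the diff only shows the parameter list), and nothing else about generation is implied beyond what the diff shows.

# Sketch of calling generate_response with the new tokenizer_max_length knob.
# Assumes utils/ is importable from the repo root, that load_base_model is
# exported from utils.llama_utils, and that `messages` uses the usual chat
# format app.py builds (not shown in this diff).
from utils.llama_utils import generate_response, load_base_model

model, tokenizer = load_base_model("meta-llama/Llama-3.2-3B-Instruct")

messages = [
    {"role": "system", "content": "You are a concise FEM teaching assistant."},
    {"role": "user", "content": "What does the global stiffness matrix represent?"},
]

answer = generate_response(
    model=model,
    tokenizer=tokenizer,
    messages=messages,
    tokenizer_max_length=500,   # inputs longer than this are truncated (truncation=True)
    do_sample=False,            # greedy decoding; temperature/top_k/top_p then unused
    max_new_tokens=700,         # matches the new default in this commit
)
print(answer)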