Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -76,16 +76,20 @@ def semantic_search(q: str, top_k: int=100):
|
|
| 76 |
D, I = faiss_index.search(emb, top_k)
|
| 77 |
return [(etf_list[i], float(D[0][j])) for j,i in enumerate(I[0])]
|
| 78 |
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
preds = set()
|
|
|
|
| 81 |
for tok, mdl in ((tok1,m1),(tok2,m2)):
|
| 82 |
-
enc = tok(
|
| 83 |
with torch.no_grad():
|
| 84 |
logits = mdl(**enc).logits
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
preds
|
|
|
|
| 89 |
return preds
|
| 90 |
|
| 91 |
# ─── UI HELPERS ──────────────────────────────────────────────────────────────
|
|
@@ -169,12 +173,16 @@ def display_chat_history(task: str):
|
|
| 169 |
def process_query(task: str, query: str):
|
| 170 |
top_k, top_n = 100, 30
|
| 171 |
if task=="search_etf":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
with st.spinner("Searching ETFs..."):
|
| 173 |
fetched = semantic_search(query, top_k)
|
| 174 |
df_out = modules.get_etf_recommendations_from_list(
|
| 175 |
fetched, df_etf, top_n
|
| 176 |
)
|
| 177 |
-
msg = f"{len(df_out)} ETFs found."
|
| 178 |
st.session_state[f"all_chat_history_{task}"].append(
|
| 179 |
modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
|
| 180 |
)
|
|
|
|
| 76 |
D, I = faiss_index.search(emb, top_k)
|
| 77 |
return [(etf_list[i], float(D[0][j])) for j,i in enumerate(I[0])]
|
| 78 |
|
| 79 |
+
|
| 80 |
+
# Ensemble function: union of both models' predictions
|
| 81 |
+
def ensemble_ticker_extraction(query):
|
| 82 |
preds = set()
|
| 83 |
+
|
| 84 |
for tok, mdl in ((tok1,m1),(tok2,m2)):
|
| 85 |
+
enc = tok(query, return_tensors="pt")
|
| 86 |
with torch.no_grad():
|
| 87 |
logits = mdl(**enc).logits
|
| 88 |
+
pred_ids = logits.argmax(dim=-1)[0].tolist()
|
| 89 |
+
tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
|
| 90 |
+
labels = [mdl.config.id2label[i] for i in pred_ids]
|
| 91 |
+
preds.update(modules.extract_valid_tickers(tokens, labels, tok, valid_ticker_set))
|
| 92 |
+
|
| 93 |
return preds
|
| 94 |
|
| 95 |
# ─── UI HELPERS ──────────────────────────────────────────────────────────────
|
|
|
|
| 173 |
def process_query(task: str, query: str):
|
| 174 |
top_k, top_n = 100, 30
|
| 175 |
if task=="search_etf":
|
| 176 |
+
"""Process a query by calling your LLM and appending results."""
|
| 177 |
+
# st.session_state["all_chat_history"].append(app_helper_utilities.form_d_chat_history(result_id=str(uuid.uuid4()), response=query_text, task='user_query'))
|
| 178 |
+
# st.session_state["explore_conversation"].append({"role": "user", "content": query_text})
|
| 179 |
+
st.chat_message("user").write(query)
|
| 180 |
with st.spinner("Searching ETFs..."):
|
| 181 |
fetched = semantic_search(query, top_k)
|
| 182 |
df_out = modules.get_etf_recommendations_from_list(
|
| 183 |
fetched, df_etf, top_n
|
| 184 |
)
|
| 185 |
+
msg = f"{len(df_out)} {fetched} ETFs found."
|
| 186 |
st.session_state[f"all_chat_history_{task}"].append(
|
| 187 |
modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
|
| 188 |
)
|