hskwon7 commited on
Commit
15444a5
Β·
verified Β·
1 Parent(s): 0ff0878

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -76,16 +76,20 @@ def semantic_search(q: str, top_k: int=100):
76
  D, I = faiss_index.search(emb, top_k)
77
  return [(etf_list[i], float(D[0][j])) for j,i in enumerate(I[0])]
78
 
79
- def ensemble_ticker_extraction(q: str):
 
 
80
  preds = set()
 
81
  for tok, mdl in ((tok1,m1),(tok2,m2)):
82
- enc = tok(q, return_tensors="pt")
83
  with torch.no_grad():
84
  logits = mdl(**enc).logits
85
- ids = logits.argmax(dim=-1)[0].tolist()
86
- toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
87
- labs = [mdl.config.id2label[i] for i in ids]
88
- preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
 
89
  return preds
90
 
91
  # ─── UI HELPERS ─────────────────────────────────────────────────────────────
@@ -169,12 +173,16 @@ def display_chat_history(task: str):
169
  def process_query(task: str, query: str):
170
  top_k, top_n = 100, 30
171
  if task=="search_etf":
 
 
 
 
172
  with st.spinner("Searching ETFs..."):
173
  fetched = semantic_search(query, top_k)
174
  df_out = modules.get_etf_recommendations_from_list(
175
  fetched, df_etf, top_n
176
  )
177
- msg = f"{len(df_out)} ETFs found."
178
  st.session_state[f"all_chat_history_{task}"].append(
179
  modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
180
  )
 
76
  D, I = faiss_index.search(emb, top_k)
77
  return [(etf_list[i], float(D[0][j])) for j,i in enumerate(I[0])]
78
 
79
+
80
+ # Ensemble function: union of both models' predictions
81
+ def ensemble_ticker_extraction(query):
82
  preds = set()
83
+
84
  for tok, mdl in ((tok1,m1),(tok2,m2)):
85
+ enc = tok(query, return_tensors="pt")
86
  with torch.no_grad():
87
  logits = mdl(**enc).logits
88
+ pred_ids = logits.argmax(dim=-1)[0].tolist()
89
+ tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
90
+ labels = [mdl.config.id2label[i] for i in pred_ids]
91
+ preds.update(modules.extract_valid_tickers(tokens, labels, tok, valid_ticker_set))
92
+
93
  return preds
94
 
95
  # ─── UI HELPERS ─────────────────────────────────────────────────────────────
 
173
  def process_query(task: str, query: str):
174
  top_k, top_n = 100, 30
175
  if task=="search_etf":
176
+ """Process a query by calling your LLM and appending results."""
177
+ # st.session_state["all_chat_history"].append(app_helper_utilities.form_d_chat_history(result_id=str(uuid.uuid4()), response=query_text, task='user_query'))
178
+ # st.session_state["explore_conversation"].append({"role": "user", "content": query_text})
179
+ st.chat_message("user").write(query)
180
  with st.spinner("Searching ETFs..."):
181
  fetched = semantic_search(query, top_k)
182
  df_out = modules.get_etf_recommendations_from_list(
183
  fetched, df_etf, top_n
184
  )
185
+ msg = f"{len(df_out)} {fetched} ETFs found."
186
  st.session_state[f"all_chat_history_{task}"].append(
187
  modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
188
  )