Diwakar Basnet committed on
Commit
166f42d
·
1 Parent(s): a98cefd

Updated the file structure

Browse files
app.py CHANGED
@@ -61,8 +61,8 @@ def _get_engine():
61
 
62
  try:
63
  print("Initialising FinRAG engine (first request)...")
64
- from app.llm.fin_rag_engine import FinRAGEngine
65
- from app.data.retrieval.filing_resolver import FilingResolver
66
  _engine = FinRAGEngine()
67
  _resolver = FilingResolver()
68
  print("Engine ready.")
@@ -107,12 +107,12 @@ def extract_filing_id(choice: str) -> Optional[str]:
107
 
108
  def chat(
109
  message: str,
110
- history: List[dict],
111
  mode: str,
112
  filing_choice: str,
113
  top_k: int,
114
  request: gr.Request,
115
- ) -> Tuple[List[dict], str, str]:
116
  """Returns (updated_history, sources_text, cleared_input)."""
117
 
118
  if not message.strip():
@@ -127,16 +127,13 @@ def chat(
127
  f"You have reached the daily limit of **{MAX_QUERIES_PER_DAY} queries**. "
128
  f"Resets at {reset_time.strftime('%H:%M UTC')}."
129
  )
130
- return history + [{"role": "user", "content": message},
131
- {"role": "assistant", "content": msg}], "", ""
132
 
133
  # Load engine
134
  try:
135
  engine, resolver = _get_engine()
136
  except Exception as e:
137
- msg = f"Service unavailable: {e}"
138
- return history + [{"role": "user", "content": message},
139
- {"role": "assistant", "content": msg}], "", ""
140
 
141
  filing_id = extract_filing_id(filing_choice)
142
 
@@ -154,31 +151,24 @@ def chat(
154
  ref = arg or filing_id
155
  if not ref:
156
  reply = "[X] Specify a company: `/risks NVIDIA` or select a filing from the dropdown."
157
- return history + [{"role": "user", "content": message},
158
- {"role": "assistant", "content": reply}], "", ""
159
  resp = engine.summarise_risks(ref)
160
  sources = ", ".join(resp.source_filings) or "N/A"
161
- return (history + [{"role": "user", "content": message},
162
- {"role": "assistant", "content": resp.answer}],
163
- f"Sources: {sources}", "")
164
 
165
  elif cmd == "/financials":
166
  ref = arg or filing_id
167
  if not ref:
168
  reply = "[X] Specify a company: `/financials MSFT` or select a filing."
169
- return history + [{"role": "user", "content": message},
170
- {"role": "assistant", "content": reply}], "", ""
171
  resp = engine.extract_financials(ref)
172
  sources = ", ".join(resp.source_filings) or "N/A"
173
- return (history + [{"role": "user", "content": message},
174
- {"role": "assistant", "content": resp.answer}],
175
- f"Sources: {sources}", "")
176
 
177
  elif cmd == "/compare":
178
  if not arg:
179
  reply = "[X] Usage: `/compare NVIDIA, Microsoft | AI revenue`"
180
- return history + [{"role": "user", "content": message},
181
- {"role": "assistant", "content": reply}], "", ""
182
  if "|" in arg:
183
  companies_raw, topic = arg.split("|", 1)
184
  else:
@@ -187,9 +177,7 @@ def chat(
187
  companies = [c.strip() for c in companies_raw.split(",") if c.strip()]
188
  resp = engine.compare_companies(companies, topic.strip())
189
  sources = ", ".join(resp.source_filings) or "N/A"
190
- return (history + [{"role": "user", "content": message},
191
- {"role": "assistant", "content": resp.answer}],
192
- f"Sources: {sources}", "")
193
 
194
  elif cmd == "/help":
195
  help_text = (
@@ -200,8 +188,7 @@ def chat(
200
  "- `/compare <co1>, <co2> | <topic>` β€” compare companies\n\n"
201
  "Or just ask naturally: *What are Google's main revenue segments?*"
202
  )
203
- return (history + [{"role": "user", "content": message},
204
- {"role": "assistant", "content": help_text}], "", "")
205
 
206
  # ── Normal question ──────────────────────────────────────────────── #
207
  resp = engine.ask(
@@ -213,8 +200,7 @@ def chat(
213
  )
214
  sources = ", ".join(resp.source_filings) if resp.source_filings else "N/A"
215
  return (
216
- history + [{"role": "user", "content": message},
217
- {"role": "assistant", "content": resp.answer}],
218
  f"Sources: {sources} | Mode: {resp.retrieval_mode} | {_usage_text(request)}",
219
  "",
220
  )
@@ -232,8 +218,7 @@ def quick_risks(history, filing, request: gr.Request):
232
  fid = extract_filing_id(filing)
233
  if not fid:
234
  msg = "Please select a filing from the dropdown first."
235
- return (history + [{"role": "user", "content": "/risks"},
236
- {"role": "assistant", "content": msg}], "")
237
  new_h, src, _ = chat(f"/risks {fid}", history, "Local", filing, 10, request)
238
  return new_h, src
239
 
@@ -242,8 +227,7 @@ def quick_financials(history, filing, request: gr.Request):
242
  fid = extract_filing_id(filing)
243
  if not fid:
244
  msg = "Please select a filing from the dropdown first."
245
- return (history + [{"role": "user", "content": "/financials"},
246
- {"role": "assistant", "content": msg}], "")
247
  new_h, src, _ = chat(f"/financials {fid}", history, "Local", filing, 10, request)
248
  return new_h, src
249
 
@@ -395,7 +379,7 @@ COMMANDS_MD = """
395
  def build_ui():
396
  filing_choices = get_filing_choices()
397
 
398
- with gr.Blocks(title="FinSight", css=CSS) as demo:
399
 
400
  gr.HTML("""
401
  <div class="fin-header">
@@ -414,7 +398,6 @@ def build_ui():
414
  show_label=False,
415
  elem_classes=["chatbot-wrap"],
416
  render_markdown=True,
417
- type="messages",
418
  placeholder=(
419
  "<div style='text-align:center;color:#1e2530;"
420
  "font-family:DM Mono,monospace;font-size:12px;padding:60px 20px'>"
@@ -524,4 +507,5 @@ if __name__ == "__main__":
524
  server_name="0.0.0.0",
525
  server_port=7860,
526
  show_error=True,
 
527
  )
 
61
 
62
  try:
63
  print("Initialising FinRAG engine (first request)...")
64
+ from llm.fin_rag_engine import FinRAGEngine
65
+ from data.retrieval.filing_resolver import FilingResolver
66
  _engine = FinRAGEngine()
67
  _resolver = FilingResolver()
68
  print("Engine ready.")
 
107
 
108
  def chat(
109
  message: str,
110
+ history: List[Tuple[str, str]],
111
  mode: str,
112
  filing_choice: str,
113
  top_k: int,
114
  request: gr.Request,
115
+ ) -> Tuple[List[Tuple[str, str]], str, str]:
116
  """Returns (updated_history, sources_text, cleared_input)."""
117
 
118
  if not message.strip():
 
127
  f"You have reached the daily limit of **{MAX_QUERIES_PER_DAY} queries**. "
128
  f"Resets at {reset_time.strftime('%H:%M UTC')}."
129
  )
130
+ return history + [(message, msg)], "", ""
 
131
 
132
  # Load engine
133
  try:
134
  engine, resolver = _get_engine()
135
  except Exception as e:
136
+ return history + [(message, f"Service unavailable: {e}")], "", ""
 
 
137
 
138
  filing_id = extract_filing_id(filing_choice)
139
 
 
151
  ref = arg or filing_id
152
  if not ref:
153
  reply = "[X] Specify a company: `/risks NVIDIA` or select a filing from the dropdown."
154
+ return history + [(message, reply)], "", ""
 
155
  resp = engine.summarise_risks(ref)
156
  sources = ", ".join(resp.source_filings) or "N/A"
157
+ return history + [(message, resp.answer)], f"Sources: {sources}", ""
 
 
158
 
159
  elif cmd == "/financials":
160
  ref = arg or filing_id
161
  if not ref:
162
  reply = "[X] Specify a company: `/financials MSFT` or select a filing."
163
+ return history + [(message, reply)], "", ""
 
164
  resp = engine.extract_financials(ref)
165
  sources = ", ".join(resp.source_filings) or "N/A"
166
+ return history + [(message, resp.answer)], f"Sources: {sources}", ""
 
 
167
 
168
  elif cmd == "/compare":
169
  if not arg:
170
  reply = "[X] Usage: `/compare NVIDIA, Microsoft | AI revenue`"
171
+ return history + [(message, reply)], "", ""
 
172
  if "|" in arg:
173
  companies_raw, topic = arg.split("|", 1)
174
  else:
 
177
  companies = [c.strip() for c in companies_raw.split(",") if c.strip()]
178
  resp = engine.compare_companies(companies, topic.strip())
179
  sources = ", ".join(resp.source_filings) or "N/A"
180
+ return history + [(message, resp.answer)], f"Sources: {sources}", ""
 
 
181
 
182
  elif cmd == "/help":
183
  help_text = (
 
188
  "- `/compare <co1>, <co2> | <topic>` β€” compare companies\n\n"
189
  "Or just ask naturally: *What are Google's main revenue segments?*"
190
  )
191
+ return history + [(message, help_text)], "", ""
 
192
 
193
  # ── Normal question ──────────────────────────────────────────────── #
194
  resp = engine.ask(
 
200
  )
201
  sources = ", ".join(resp.source_filings) if resp.source_filings else "N/A"
202
  return (
203
+ history + [(message, resp.answer)],
 
204
  f"Sources: {sources} | Mode: {resp.retrieval_mode} | {_usage_text(request)}",
205
  "",
206
  )
 
218
  fid = extract_filing_id(filing)
219
  if not fid:
220
  msg = "Please select a filing from the dropdown first."
221
+ return history + [("/risks", msg)], ""
 
222
  new_h, src, _ = chat(f"/risks {fid}", history, "Local", filing, 10, request)
223
  return new_h, src
224
 
 
227
  fid = extract_filing_id(filing)
228
  if not fid:
229
  msg = "Please select a filing from the dropdown first."
230
+ return history + [("/financials", msg)], ""
 
231
  new_h, src, _ = chat(f"/financials {fid}", history, "Local", filing, 10, request)
232
  return new_h, src
233
 
 
379
  def build_ui():
380
  filing_choices = get_filing_choices()
381
 
382
+ with gr.Blocks(title="FinSight") as demo:
383
 
384
  gr.HTML("""
385
  <div class="fin-header">
 
398
  show_label=False,
399
  elem_classes=["chatbot-wrap"],
400
  render_markdown=True,
 
401
  placeholder=(
402
  "<div style='text-align:center;color:#1e2530;"
403
  "font-family:DM Mono,monospace;font-size:12px;padding:60px 20px'>"
 
507
  server_name="0.0.0.0",
508
  server_port=7860,
509
  show_error=True,
510
+ css=CSS,
511
  )
app/config.py β†’ config.py RENAMED
File without changes
{app/data β†’ data}/retrieval/filing_resolver.py RENAMED
@@ -1,5 +1,5 @@
1
  from typing import Optional, List, Dict
2
- from app.data.retrieval.graph_retriever import GraphRetriever
3
 
4
 
5
  class FilingResolver:
 
1
  from typing import Optional, List, Dict
2
+ from data.retrieval.graph_retriever import GraphRetriever
3
 
4
 
5
  class FilingResolver:
{app/data β†’ data}/retrieval/graph_retriever.py RENAMED
@@ -1,7 +1,7 @@
1
  from neo4j import GraphDatabase
2
  from typing import List, Dict, Any, Optional
3
 
4
- from app.config import settings
5
 
6
 
7
  class GraphRetriever:
 
1
  from neo4j import GraphDatabase
2
  from typing import List, Dict, Any, Optional
3
 
4
+ from config import settings
5
 
6
 
7
  class GraphRetriever:
{app/data β†’ data}/retrieval/hybridrag_retriever.py RENAMED
@@ -1,10 +1,10 @@
1
  from dataclasses import dataclass, field
2
  from typing import List, Dict, Any, Optional
3
 
4
- from app.config import settings
5
- from app.data.retrieval.reranker import Reranker
6
- from app.data.retrieval.graph_retriever import GraphRetriever
7
- from app.data.retrieval.weaviate_retriever import WeaviateRetriever
8
 
9
 
10
  @dataclass
 
1
  from dataclasses import dataclass, field
2
  from typing import List, Dict, Any, Optional
3
 
4
+ from config import settings
5
+ from data.retrieval.reranker import Reranker
6
+ from data.retrieval.graph_retriever import GraphRetriever
7
+ from data.retrieval.weaviate_retriever import WeaviateRetriever
8
 
9
 
10
  @dataclass
{app/data β†’ data}/retrieval/reranker.py RENAMED
@@ -17,7 +17,7 @@ class Reranker:
17
 
18
  if self.use_cross_encoder:
19
  try:
20
- from app.utils.reranker_utils import NimReranker
21
  self._ce_client = NimReranker()
22
  except Exception:
23
  self.use_cross_encoder = False
 
17
 
18
  if self.use_cross_encoder:
19
  try:
20
+ from utils.reranker_utils import NimReranker
21
  self._ce_client = NimReranker()
22
  except Exception:
23
  self.use_cross_encoder = False
{app/data β†’ data}/retrieval/weaviate_retriever.py RENAMED
@@ -3,8 +3,8 @@ from weaviate.classes.init import Auth
3
  from weaviate.classes.query import Filter, MetadataQuery
4
  from typing import List, Dict, Any, Optional
5
 
6
- from app.config import settings
7
- from app.utils.embedding_utils import BGEM3Embedder
8
 
9
 
10
  class WeaviateRetriever:
 
3
  from weaviate.classes.query import Filter, MetadataQuery
4
  from typing import List, Dict, Any, Optional
5
 
6
+ from config import settings
7
+ from utils.embedding_utils import BGEM3Embedder
8
 
9
 
10
  class WeaviateRetriever:
{app/llm β†’ llm}/__init__.py RENAMED
File without changes
{app/llm β†’ llm}/fin_rag_engine.py RENAMED
@@ -1,9 +1,9 @@
1
  from dataclasses import dataclass, field
2
  from typing import List, Dict, Any, Optional, Iterator
3
 
4
- from app.llm.groq_client import GroqClient
5
- from app.data.retrieval.filing_resolver import FilingResolver
6
- from app.data.retrieval.hybridrag_retriever import HybridRAGRetriever, RetrievedContext
7
 
8
 
9
  @dataclass
 
1
  from dataclasses import dataclass, field
2
  from typing import List, Dict, Any, Optional, Iterator
3
 
4
+ from llm.groq_client import GroqClient
5
+ from data.retrieval.filing_resolver import FilingResolver
6
+ from data.retrieval.hybridrag_retriever import HybridRAGRetriever, RetrievedContext
7
 
8
 
9
  @dataclass
{app/llm β†’ llm}/groq_client.py RENAMED
@@ -2,7 +2,7 @@ from openai import OpenAI
2
  from typing import List, Dict, Any, Optional, Iterator
3
 
4
  from groq import Groq
5
- from app.config import settings
6
 
7
 
8
  class GroqClient:
 
2
  from typing import List, Dict, Any, Optional, Iterator
3
 
4
  from groq import Groq
5
+ from config import settings
6
 
7
 
8
  class GroqClient:
utils/embedding_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List

from openai import OpenAI

from config import settings


class BGEM3Embedder:
    """Wraps the NVIDIA serverless BGE-M3 embedding endpoint.

    NVIDIA NIM exposes an OpenAI-compatible embeddings API, so the stock
    ``openai`` client is pointed at the NVIDIA base URL.
    """

    MODEL = "baai/bge-m3"

    def __init__(self):
        # Credentials come from project settings; the base_url redirects the
        # OpenAI client to NVIDIA's integrate endpoint.
        self.client = OpenAI(
            api_key=settings.NVIDIA_NIM_API,
            base_url="https://integrate.api.nvidia.com/v1",
        )

    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for a single string."""
        response = self.client.embeddings.create(
            input=[text],
            model=self.MODEL,
            encoding_format="float",
            extra_body={"truncate": "END"},  # truncate instead of error on long text
        )
        return response.data[0].embedding

    def embed_many(self, texts: List[str]) -> List[List[float]]:
        """Return embedding vectors for a batch of strings in one API call."""
        response = self.client.embeddings.create(
            input=texts,
            model=self.MODEL,
            encoding_format="float",
            extra_body={"truncate": "END"},
        )
        return [d.embedding for d in response.data]
utils/reranker_utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from typing import List, Dict, Any
3
+ from config import settings
4
+
5
+
6
class NimReranker:
    """Nvidia NIM Reranker.

    Sends a query plus candidate passages to the NIM reranking endpoint and
    returns the passages ordered by relevance score.
    """

    MODEL = "nv-rerank-qa-mistral-4b:1"
    INVOKE_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"
    # Upper bound (seconds) on the remote call; without it, requests waits
    # forever on a stalled connection and hangs the whole retrieval path.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        self.session = requests.Session()
        self.headers = {
            "Authorization": f"Bearer {settings.NVIDIA_NIM_API}",
            "Accept": "application/json",
        }

    def rerank_run(self, query: str, passages: List[str]) -> List[Dict[str, Any]]:
        """
        Rerank a list of passages for a given query using Nvidia NIM.

        Args:
            query: The question or query string.
            passages: A list of chunk strings to rerank.

        Returns:
            A list of dictionaries containing the text, the ranking score (logit),
            and the original index, sorted by score in descending order. Empty
            input yields an empty list without any network call.

        Raises:
            requests.HTTPError: If the NIM endpoint returns an error status.
            requests.Timeout: If the request exceeds REQUEST_TIMEOUT seconds.
        """
        if not passages:
            return []

        payload = {
            "model": self.MODEL,
            "query": {"text": query},
            "passages": [{"text": p} for p in passages],
        }

        response = self.session.post(
            self.INVOKE_URL,
            headers=self.headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        response.raise_for_status()

        rankings = response.json().get("rankings", [])

        # Map each ranking entry back to its source passage text.
        results = [
            {
                "text": passages[item["index"]],
                "score": item["logit"],
                "index": item["index"],
            }
            for item in rankings
        ]

        # Highest-scoring passage first.
        results.sort(key=lambda x: x["score"], reverse=True)
        return results