amaherovskyi committed on
Commit
a4bfffb
·
verified ·
1 Parent(s): 5b933aa

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +1 -12
pipeline.py CHANGED
@@ -1,4 +1,3 @@
1
- # pipeline.py
2
  import os
3
  import logging
4
  from typing import List, Dict
@@ -21,9 +20,7 @@ console_handler.setFormatter(formatter)
21
  logger.addHandler(console_handler)
22
 
23
 
24
- # ---------------------------
25
  # Initialization
26
- # ---------------------------
27
  def init_reranker(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> CrossEncoder:
28
  """Initialize CrossEncoder for document reordering."""
29
  logger.info(f"CrossEncoder initialization: {model_name}")
@@ -39,9 +36,7 @@ def init_groq(api_key: str = None) -> Groq:
39
  return client
40
 
41
 
42
- # ---------------------------
43
  # Reranking
44
- # ---------------------------
45
  def rerank(query: str, docs: List[Dict], reranker: CrossEncoder, top_k: int = 5) -> List[Dict]:
46
  """Reranking documents using CrossEncoder based on a query."""
47
  if not docs:
@@ -57,9 +52,7 @@ def rerank(query: str, docs: List[Dict], reranker: CrossEncoder, top_k: int = 5)
57
  return ranked[:top_k]
58
 
59
 
60
- # ---------------------------
61
  # LLM answering
62
- # ---------------------------
63
  def llm_answer(query: str, context: List[Dict], client: Groq) -> str:
64
  """Forming an LLM response based on the provided document context."""
65
  context_text = "\n\n---\n\n".join(f"[{d['id']}] {d['text']}" for d in context)
@@ -84,9 +77,7 @@ Answer only using information from the context. If answer not found, say "I don'
84
  return completion.choices[0].message.content
85
 
86
 
87
- # ---------------------------
88
- # Retrieve documents (fixed!)
89
- # ---------------------------
90
  def retrieve_documents(
91
  query: str,
92
  documents: list,
@@ -123,9 +114,7 @@ def retrieve_documents(
123
  return docs
124
 
125
 
126
- # ---------------------------
127
  # Full RAG Pipeline
128
- # ---------------------------
129
  def rag_pipeline(
130
  query: str,
131
  reranker_model: CrossEncoder,
 
 
1
  import os
2
  import logging
3
  from typing import List, Dict
 
20
  logger.addHandler(console_handler)
21
 
22
 
 
23
  # Initialization
 
24
  def init_reranker(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> CrossEncoder:
25
  """Initialize CrossEncoder for document reordering."""
26
  logger.info(f"CrossEncoder initialization: {model_name}")
 
36
  return client
37
 
38
 
 
39
  # Reranking
 
40
  def rerank(query: str, docs: List[Dict], reranker: CrossEncoder, top_k: int = 5) -> List[Dict]:
41
  """Reranking documents using CrossEncoder based on a query."""
42
  if not docs:
 
52
  return ranked[:top_k]
53
 
54
 
 
55
  # LLM answering
 
56
  def llm_answer(query: str, context: List[Dict], client: Groq) -> str:
57
  """Forming an LLM response based on the provided document context."""
58
  context_text = "\n\n---\n\n".join(f"[{d['id']}] {d['text']}" for d in context)
 
77
  return completion.choices[0].message.content
78
 
79
 
80
+ # Retrieve documents
 
 
81
  def retrieve_documents(
82
  query: str,
83
  documents: list,
 
114
  return docs
115
 
116
 
 
117
  # Full RAG Pipeline
 
118
  def rag_pipeline(
119
  query: str,
120
  reranker_model: CrossEncoder,