rohannsinghal commited on
Commit
46ea3d7
·
1 Parent(s): 0967030

fix: correct .gitignore and removed db from tracking

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app/main_api.py +16 -16
.gitignore CHANGED
@@ -18,6 +18,8 @@ unstructuredenv/
18
  app/documents/
19
  app/chroma_db/
20
  _fast_parsed_output.json
 
 
21
 
22
  # IDE and OS files
23
  .vscode/
 
18
  app/documents/
19
  app/chroma_db/
20
  _fast_parsed_output.json
21
+ chroma_db/
22
+
23
 
24
  # IDE and OS files
25
  .vscode/
app/main_api.py CHANGED
@@ -80,15 +80,15 @@ class Answer(BaseModel):
80
  class SubmissionResponse(BaseModel):
81
  answers: List[Answer]
82
 
83
- # RAG Pipeline Class
84
  class RAGPipeline:
85
- def __init__(self, collection_name: str):
 
86
  self.collection_name = collection_name
87
- self.collection = chroma_client.get_or_create_collection(name=self.collection_name)
88
-
89
- # In app/main_api.py, inside the RAGPipeline class
90
-
91
- # In app/main_api.py, inside the RAGPipeline class
92
 
93
  def add_documents(self, chunks: List[Dict]):
94
  if not chunks:
@@ -97,13 +97,10 @@ class RAGPipeline:
97
 
98
  logger.info(f"Starting to add {len(chunks)} chunks...")
99
 
100
- # --- START OF FIX ---
101
- # The 'chunks' variable is a list of dictionaries. This code correctly
102
- # uses dictionary key access `c['key']` to get the data.
103
  contents = [c['content'] for c in chunks]
104
  metadatas = [c['metadata'] for c in chunks]
105
  ids = [c['chunk_id'] for c in chunks]
106
- # --- END OF FIX ---
107
 
108
  self.collection.add(
109
  embeddings=self.embedding_model.encode(contents, show_progress_bar=True).tolist(),
@@ -115,8 +112,10 @@ class RAGPipeline:
115
 
116
  def query_documents(self, query: str, n_results: int = 5) -> List[Dict]:
117
  if not self.collection.count(): return []
 
118
  results = self.collection.query(
119
- query_embeddings=embedding_model.encode([query]).tolist(),
 
120
  n_results=min(n_results, self.collection.count()),
121
  include=["documents", "metadatas"]
122
  )
@@ -128,12 +127,13 @@ class RAGPipeline:
128
  user_prompt = f"REFERENCE TEXT:\n{context}\n\nQUESTION: {query}"
129
 
130
  try:
131
- # --- API KEY ROTATION ---
132
- groq_client.api_key = get_next_api_key()
133
- logger.info(f"Using Groq API key ending in ...{groq_client.api_key[-4:]}")
134
 
135
  response = await asyncio.to_thread(
136
- groq_client.chat.completions.create,
 
137
  model="llama3-8b-8192",
138
  messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
139
  temperature=0.0,
 
80
  class SubmissionResponse(BaseModel):
81
  answers: List[Answer]
82
 
 
83
  class RAGPipeline:
84
+ def __init__(self, collection_name: str, request: Request):
85
+ # --- FIX: Get models and clients from the app state via the request ---
86
  self.collection_name = collection_name
87
+ self.request = request
88
+ self.chroma_client = request.app.state.chroma_client
89
+ self.embedding_model = request.app.state.embedding_model
90
+ self.groq_client = request.app.state.groq_client
91
+ self.collection = self.chroma_client.get_or_create_collection(name=self.collection_name)
92
 
93
  def add_documents(self, chunks: List[Dict]):
94
  if not chunks:
 
97
 
98
  logger.info(f"Starting to add {len(chunks)} chunks...")
99
 
100
+ # Use instance variables to access models
 
 
101
  contents = [c['content'] for c in chunks]
102
  metadatas = [c['metadata'] for c in chunks]
103
  ids = [c['chunk_id'] for c in chunks]
 
104
 
105
  self.collection.add(
106
  embeddings=self.embedding_model.encode(contents, show_progress_bar=True).tolist(),
 
112
 
113
  def query_documents(self, query: str, n_results: int = 5) -> List[Dict]:
114
  if not self.collection.count(): return []
115
+
116
  results = self.collection.query(
117
+ # Use instance variable for the model
118
+ query_embeddings=self.embedding_model.encode([query]).tolist(),
119
  n_results=min(n_results, self.collection.count()),
120
  include=["documents", "metadatas"]
121
  )
 
127
  user_prompt = f"REFERENCE TEXT:\n{context}\n\nQUESTION: {query}"
128
 
129
  try:
130
+ # --- FIX: Access the API key cycler from the app state ---
131
+ self.groq_client.api_key = next(self.request.app.state.api_key_cycler)
132
+ logger.info(f"Using Groq API key ending in ...{self.groq_client.api_key[-4:]}")
133
 
134
  response = await asyncio.to_thread(
135
+ # Use instance variable for the client
136
+ self.groq_client.chat.completions.create,
137
  model="llama3-8b-8192",
138
  messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
139
  temperature=0.0,