viboognesh committed on
Commit fefc214 · verified · 1 Parent(s): 84015e9

Update main.py

Files changed (1)
  1. main.py +30 -21
main.py CHANGED
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, File, UploadFile, Depends
+from fastapi import FastAPI, File, UploadFile, Depends, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 
 from typing import List, Dict, Any
@@ -202,30 +202,39 @@ app.state.conversation_chain = None
 @app.post("/upload_files/")
 async def upload_files(files: List[UploadFile] = File(...)):
     file_details = []
-    for file in files:
-        content = await file.read()
-        name = f"{file.filename}"
-        details = {"content": content, "name": name}
-        file_details.append(details)
-
-    app.state.conversational_chain = Conversational_Chain(
-        file_details
-    ).create_conversational_chain()
-    print("conversational_chain_manager created")
+    try:
+        for file in files:
+            content = await file.read()
+            name = f"{file.filename}"
+            details = {"content": content, "name": name}
+            file_details.append(details)
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
+    try:
+        app.state.conversational_chain = Conversational_Chain(
+            file_details
+        ).create_conversational_chain()
+        print("conversational_chain_manager created")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
     return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
 
 
 @app.get("/predict/")
-async def predict(
-    query: str,
-):
-    if app.state.conversation_chain is None:
-        system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
-        response = app.state.llm_model.invoke(system_prompt + query)
-        answer = response.content
-    else:
-        response = app.state.conversation_chain.invoke(query)
-        answer = response["answer"]
+async def predict(query: str):
+    try:
+        if app.state.conversation_chain is None:
+            system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
+            llm_model = ChatOpenAI()
+            response = llm_model.invoke(system_prompt + query)
+            answer = response.content
+        else:
+            response = app.state.conversation_chain.invoke(query)
+            answer = response["answer"]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
     print("predict called")
     return {"answer": answer}
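For context, a minimal client sketch of how the two endpoints touched by this commit might be exercised after the change. The base URL, file name, and the use of the requests library are illustrative assumptions and are not part of the commit; only the paths, field names, and response keys come from the diff above.

import requests

BASE_URL = "http://localhost:8000"  # assumption: the app is served locally, e.g. via uvicorn

# POST /upload_files/ takes multipart form data under the "files" field;
# with this commit, a failure surfaces as a 400/500 response whose
# "detail" field carries the exception text instead of an unhandled error.
with open("example.pdf", "rb") as f:  # placeholder file name
    resp = requests.post(f"{BASE_URL}/upload_files/", files=[("files", ("example.pdf", f))])
resp.raise_for_status()
print(resp.json()["message"])

# GET /predict/ takes the question as the "query" string parameter and
# returns the answer under the "answer" key.
resp = requests.get(f"{BASE_URL}/predict/", params={"query": "What is in the uploaded file?"})
resp.raise_for_status()
print(resp.json()["answer"])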