Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from fastapi import FastAPI, File, UploadFile, Depends
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
|
| 4 |
from typing import List, Dict, Any
|
|
@@ -202,30 +202,39 @@ app.state.conversation_chain = None
|
|
| 202 |
@app.post("/upload_files/")
|
| 203 |
async def upload_files(files: List[UploadFile] = File(...)):
|
| 204 |
file_details = []
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
|
| 216 |
|
| 217 |
|
| 218 |
@app.get("/predict/")
|
| 219 |
-
async def predict(
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
| 229 |
|
| 230 |
print("predict called")
|
| 231 |
return {"answer": answer}
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
|
| 4 |
from typing import List, Dict, Any
|
|
|
|
| 202 |
@app.post("/upload_files/")
async def upload_files(files: List[UploadFile] = File(...)):
    """Read the uploaded files and build the conversational retrieval chain.

    Stores the finished chain on ``app.state.conversation_chain`` so that the
    ``/predict/`` endpoint can answer questions against the uploaded documents.

    Args:
        files: One or more multipart file uploads.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 400 if any upload cannot be read; 500 if chain
            construction fails.
    """
    file_details = []
    try:
        for file in files:
            # Read each upload fully into memory — assumes files are small
            # enough to buffer; TODO confirm expected upload sizes.
            content = await file.read()
            file_details.append({"content": content, "name": file.filename})
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    try:
        # BUG FIX: this previously assigned app.state.conversational_chain,
        # but the module-level init and /predict/ both read
        # app.state.conversation_chain, so the freshly built chain was never
        # used. Use the same attribute name everywhere.
        app.state.conversation_chain = Conversational_Chain(
            file_details
        ).create_conversational_chain()
        print("conversational_chain_manager created")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
|
| 223 |
|
| 224 |
|
| 225 |
@app.get("/predict/")
async def predict(query: str):
    """Answer *query*, preferring the retrieval chain when one exists.

    If no chain has been created yet (no files uploaded), falls back to a
    plain ChatOpenAI call whose system prompt also asks the user to upload
    files.

    Args:
        query: The user's question.

    Returns:
        A dict with the generated ``answer``.

    Raises:
        HTTPException: 500 on any failure while generating the answer.
    """
    try:
        chain = app.state.conversation_chain
        if chain is not None:
            # Files were uploaded: answer from the retrieval chain.
            result = chain.invoke(query)
            answer = result["answer"]
        else:
            # No chain yet: plain LLM call with a hint to upload files.
            system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
            reply = ChatOpenAI().invoke(system_prompt + query)
            answer = reply.content
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    print("predict called")
    return {"answer": answer}
|