caarleexx committed
Commit 0775758 · verified · 1 Parent(s): abf48e1

Update backend/main.py

Files changed (1)
backend/main.py +14 -8
backend/main.py CHANGED
@@ -1,4 +1,4 @@
-#--- START OF FILE main.py ---
+#--- START OF FILE main (1).py ---
 
 import os
 import io
@@ -14,7 +14,7 @@ from fastapi.responses import StreamingResponse
 # RAG Imports
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_text_splitters import RecursiveCharacterTextSplitter  # FIXED: new import
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain_core.runnables import RunnablePassthrough, RunnableLambda
 from langchain_core.output_parsers import StrOutputParser
@@ -38,9 +38,10 @@ app.add_middleware(
 # Define the Hugging Face embedding model (lightweight, CPU-friendly)
 HF_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 
-# Initialize the Groq model and the embedding model
+# Initialize the Groq model
 model = ChatGroq(model=os.getenv("GROQ_MODEL", "mixtral-8x7b-32768"))
-# CHANGE: initialize HuggingFaceEmbeddings on the CPU
+
+# Initialize HuggingFaceEmbeddings on the CPU
 embeddings = HuggingFaceEmbeddings(
     model_name=HF_EMBEDDING_MODEL,
     model_kwargs={'device': 'cpu'}
@@ -108,9 +109,12 @@ async def upload_document(file: UploadFile = File(...)):
         vectorstore = FAISS.from_documents(documents=splits, embedding=embeddings)
         retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
 
-        # 5. Create the new RAG chain
+        # 5. Create the new RAG chain (FIXED)
+        # The lambda extracts only the question text ("input") from the incoming dict
         rag_chain = (
-            RunnablePassthrough.assign(context=retriever | format_docs)
+            RunnablePassthrough.assign(
+                context=(lambda x: x["input"]) | retriever | format_docs
+            )
             | rag_prompt
             | model
             | StrOutputParser()
@@ -121,7 +125,7 @@ async def upload_document(file: UploadFile = File(...)):
     except Exception as e:
         print(f"Erro no processamento do arquivo: {e}")
        # Return a 500 error to the frontend
-        raise HTTPException(status_code=500, detail=f"Falha ao processar o arquivo: {e}. Verifique se o modelo HuggingFace foi baixado corretamente.")
+        raise HTTPException(status_code=500, detail=f"Falha ao processar o arquivo: {e}")
     finally:
         # Cleanup: delete the temporary file
         if 'temp_path' in locals() and os.path.exists(temp_path):
@@ -146,6 +150,7 @@ async def chat(request: ChatRequest):
     async def stream_generator():
         try:
             # 'astream' is LangChain's asynchronous streaming method
+            # We pass {"input": ...}, which is intercepted by the lambda defined above
            async for chunk in current_chain.astream({"input": request.content}):
                if chunk:
                    yield chunk
@@ -155,4 +160,5 @@ async def chat(request: ChatRequest):
 
    # Return a streaming response
    return StreamingResponse(stream_generator(), media_type="text/plain")
-#--- END OF FILE main.py ---
+
+#--- END OF FILE main (1).py ---
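
Context on the fix: RunnablePassthrough.assign hands the entire input dict to the runnable that computes context, so the old wiring context=retriever | format_docs passed {"input": "..."} straight to the retriever, which expects a plain query string. Prepending (lambda x: x["input"]) unwraps the question first. A minimal runnable sketch of the pattern, using only langchain_core; the stand-in retriever and formatter below are hypothetical, just so the example runs without FAISS or Groq:

# Minimal sketch of the fixed .assign wiring (hypothetical stand-ins replace
# the FAISS retriever and the format_docs helper from main.py).
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# A "retriever" that expects a plain string query, and a formatter that
# joins the retrieved docs into one context string.
fake_retriever = RunnableLambda(lambda query: [f"doc matching: {query}"])
join_docs = RunnableLambda(lambda docs: "\n\n".join(docs))

chain = RunnablePassthrough.assign(
    # .assign receives the whole input dict; the lambda pulls out the query
    context=(lambda x: x["input"]) | fake_retriever | join_docs
)

print(chain.invoke({"input": "What does the PDF say?"}))
# {'input': 'What does the PDF say?', 'context': 'doc matching: What does the PDF say?'}

Without the lambda, the retriever would receive the whole dict rather than the query string, which is exactly the bug this commit removes.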
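
On the streaming side nothing else had to change: the endpoint keeps calling current_chain.astream({"input": request.content}), and because the chain ends in StrOutputParser each streamed chunk is a plain string that can be yielded straight into the text/plain StreamingResponse. A small sketch of consuming such a stream, with a hypothetical fake_llm lambda in place of ChatGroq so it runs offline (a real streaming chat model yields many small chunks rather than one):

# Sketch: consuming a chain's async stream, assuming only langchain_core.
import asyncio

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda

# Hypothetical stand-in for ChatGroq; real chat models stream token by token.
fake_llm = RunnableLambda(lambda question: f"Answer to: {question}")

pipeline = (lambda x: x["input"]) | fake_llm | StrOutputParser()

async def demo():
    async for chunk in pipeline.astream({"input": "hello"}):
        if chunk:  # skip empty chunks, as stream_generator does
            print(chunk, end="", flush=True)

asyncio.run(demo())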