# REFORMER_AI / rag_processor.py
# Source: Hugging Face Space by bsmith3715 — "Update rag_processor.py"
# (commit 97857a2, verified)
import asyncio
import os
from typing import AsyncGenerator, Generator, List

import faiss
import numpy as np
import openai
from dotenv import load_dotenv
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
load_dotenv()
class RAGProcessor:
    """Minimal retrieval-augmented generation pipeline.

    Documents are embedded with a SentenceTransformer model and stored in an
    in-memory FAISS flat (exact, L2) index; queries retrieve the nearest
    documents and feed them as context to an OpenAI chat completion.
    """

    def __init__(self, model_name: str = "bsmith3715/legal-ft-demo_final"):
        """Load the embedding model and create the OpenAI client.

        Args:
            model_name: Hugging Face id of the sentence-embedding model.
        """
        self.model = SentenceTransformer(model_name)
        self.index = None  # faiss.IndexFlatL2, built lazily in add_documents()
        self.documents: List[str] = []  # raw text, parallel to index rows
        self.openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def add_documents(self, documents: List[Document]) -> None:
        """Add documents to the RAG system.

        Accumulates across calls (the previous version silently *replaced*
        the corpus on every call). The flat index is rebuilt from scratch,
        which is acceptable for the small corpora this class targets.

        Args:
            documents: LangChain documents; only ``page_content`` is used.
        """
        if not documents:
            return  # nothing to index; keep any existing index intact
        self.documents.extend(doc.page_content for doc in documents)
        embeddings = self.model.encode(self.documents)
        # IndexFlatL2 = exact nearest-neighbor search by squared L2 distance.
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(np.asarray(embeddings, dtype="float32"))

    def retrieve_relevant_context(self, query: str, k: int = 3) -> List[str]:
        """Return up to ``k`` documents most similar to ``query``.

        Returns an empty list when no documents have been indexed yet.
        """
        # `is None` — a FAISS SWIG object is always truthy, so `not self.index`
        # would only ever catch the never-initialized case by accident.
        if self.index is None or not self.documents:
            return []
        # FAISS pads missing neighbors with index -1 when k > ntotal, and
        # documents[-1] would then silently return the *last* document.
        k = min(k, len(self.documents))
        query_embedding = self.model.encode([query])
        _, indices = self.index.search(np.asarray(query_embedding, dtype="float32"), k)
        return [self.documents[i] for i in indices[0] if i != -1]

    async def generate_response(self, query: str) -> AsyncGenerator[str, None]:
        """Stream an answer to ``query`` grounded in retrieved context.

        Yields the response token-deltas as they arrive from the OpenAI API.
        The synchronous OpenAI client is driven through asyncio.to_thread so
        the event loop is never blocked (the previous version iterated the
        sync stream directly inside the async generator, stalling the loop
        for the whole completion).
        """
        relevant_docs = self.retrieve_relevant_context(query)
        context = "\n".join(relevant_docs)
        prompt = f"""Context information is below.
---------------------
{context}
---------------------
Given the context information, please answer the following question. If the context doesn't contain relevant information, say so.
Question: {query}
Answer:"""
        # Run the blocking request setup off the event loop.
        stream = await asyncio.to_thread(
            self.openai_client.chat.completions.create,
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful Pilates instructor assistant. Use the provided context to answer questions accurately."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=1000,
            stream=True,
        )
        # Each next() on the sync stream blocks on the network, so hop to a
        # worker thread per chunk; None is the exhaustion sentinel.
        stream_iter = iter(stream)
        while (chunk := await asyncio.to_thread(next, stream_iter, None)) is not None:
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content