|
|
import os |
|
|
import streamlit as st |
|
|
from groq import Groq |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain.text_splitter import CharacterTextSplitter |
|
|
from langchain_community.document_loaders import TextLoader, PyPDFLoader |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.llms.base import LLM |
|
|
from tempfile import NamedTemporaryFile |
|
|
|
|
|
|
|
|
# Read the Groq API key from the environment. The app cannot work without it,
# so fail fast with a visible error instead of crashing on the first API call.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

if not GROQ_API_KEY:
    # Mojibake repaired: the message previously contained a garbled byte
    # where this emoji belongs.
    st.error("❌ GROQ_API_KEY is not set in environment variables.")
    st.stop()  # halts this Streamlit script run immediately

# Single shared Groq client, reused by every GroqLLM._call below.
groq_client = Groq(api_key=GROQ_API_KEY)
|
|
|
|
|
|
|
|
class GroqLLM(LLM):
    """Minimal LangChain LLM wrapper around the Groq chat-completions API.

    LangChain's ``LLM`` base class is a Pydantic model, so configuration
    must be declared as a class-level field rather than assigned inside a
    custom ``__init__``: the original override assigned the undeclared
    attribute ``self.model_name`` (a Pydantic error) and skipped Pydantic's
    own initialization entirely. Construct with keyword arguments, e.g.
    ``GroqLLM()`` or ``GroqLLM(model_name="...")``.
    """

    # Groq model identifier; override per instance via the constructor.
    model_name: str = "llama3-8b-8192"

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Send a single-turn user prompt to Groq and return the reply text.

        ``run_manager``/``**kwargs`` are accepted for LangChain
        compatibility; ``stop`` sequences are now forwarded to the API
        (the original silently discarded them).
        """
        response = groq_client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            stop=stop,
        )
        return response.choices[0].message.content.strip()

    @property
    def _llm_type(self):
        """Identifier LangChain uses for logging/serialization."""
        return "groq_llm"
|
|
|
|
|
|
|
|
# --- Streamlit page scaffolding --------------------------------------------
st.set_page_config(page_title="Groq RAG App", layout="centered")

# Mojibake repaired in the title: a garbled byte stood where the emoji belongs.
st.title("📚 RAG App with Groq + LangChain + FAISS")

st.write("Upload a PDF or TXT file, ask a question, and get smart answers.")

# File picker; `uploaded_file` stays None until the user provides a document.
uploaded_file = st.file_uploader("Upload your document", type=["pdf", "txt"])
|
|
|
|
|
if uploaded_file:
    # Persist the upload to disk because the LangChain loaders take a file
    # path, not a file-like object. Keep a matching extension so tools that
    # sniff by suffix behave correctly (the original temp file had none).
    suffix = ".pdf" if uploaded_file.type == "application/pdf" else ".txt"
    with NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        tmp_file.write(uploaded_file.read())
        tmp_path = tmp_file.name

    try:
        # Pick the loader that matches the uploaded document type.
        if uploaded_file.type == "application/pdf":
            loader = PyPDFLoader(tmp_path)
        else:
            loader = TextLoader(tmp_path)

        docs = loader.load()
    finally:
        # The temp file was created with delete=False, so remove it ourselves;
        # the original leaked one file per Streamlit rerun.
        os.unlink(tmp_path)

    # Chunk the document; the overlap preserves context across boundaries.
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = splitter.split_documents(docs)

    # Embed the chunks and index them in an in-memory FAISS store.
    # NOTE(review): this re-embeds the whole document on every Streamlit
    # rerun; consider st.cache_resource/session_state if that becomes slow.
    embeddings = HuggingFaceEmbeddings()
    db = FAISS.from_documents(texts, embeddings)

    # Wire retrieval + the Groq LLM into a question-answering chain.
    retriever = db.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(llm=GroqLLM(), retriever=retriever)

    # Mojibake repaired in the two UI strings below (garbled emoji bytes).
    query = st.text_input("🔍 Ask a question about the document:")

    if query:
        with st.spinner("Thinking..."):
            result = qa_chain.run(query)
        st.markdown("### 🧠 Answer:")
        st.success(result)
|
|
|