Krish_GPT_Pro / app.py
krishbaresha's picture
Update app.py
dfd9cf6 verified
import streamlit as st
from groq import Groq
import os
from PyPDF2 import PdfReader
import requests
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# ---------------------------
# PAGE CONFIG
# ---------------------------
# Shared by the browser-tab title and the on-page heading so the two cannot
# drift apart.
_APP_NAME = "Krish GPT Multi-Modal RAG"

# Page configuration is issued before any other Streamlit element is drawn.
st.set_page_config(layout="wide", page_title=_APP_NAME)
st.title("🤖 " + _APP_NAME)
st.caption("PDF + Image OCR + RAG using Groq LLM 🚀")
# ---------------------------
# API KEYS
# ---------------------------
# Prefer keys from the environment; fall back to masked text inputs so the
# app can still be used when the variables are not set.
groq_api_key = os.getenv("GROQ_API_KEY")
ocr_api_key = os.getenv("OCR_API_KEY")

if not groq_api_key:
    groq_api_key = st.text_input("Enter GROQ API Key", type="password")
if not ocr_api_key:
    ocr_api_key = st.text_input("Enter OCR.Space API Key", type="password")

# Pasted keys often carry stray whitespace or a trailing newline, which would
# make every authenticated request fail.
groq_api_key = (groq_api_key or "").strip()
ocr_api_key = (ocr_api_key or "").strip()

if not groq_api_key or not ocr_api_key:
    # Explain why the rest of the page is missing instead of halting silently.
    st.info("Please provide both the GROQ and OCR.Space API keys to continue.")
    st.stop()

client = Groq(api_key=groq_api_key)
# ---------------------------
# EMBEDDING MODEL
# ---------------------------
@st.cache_resource
def load_embedder():
    """Build the sentence-embedding model used for chunk and query vectors.

    Decorated with ``st.cache_resource`` so the model object can be reused
    instead of being constructed again each time the script runs.
    """
    model_name = "all-MiniLM-L6-v2"
    return SentenceTransformer(model_name)


embedder = load_embedder()
# ---------------------------
# OCR Function
# ---------------------------
def ocr_space_image(file, api_key, timeout=60):
    """Extract text from an image via the OCR.Space ``parse/image`` endpoint.

    Args:
        file: File-like object (or any value ``requests`` accepts for a
            multipart file field) holding the image to recognise.
        api_key: OCR.Space API key sent in the form data.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``ParsedText`` of the first parsed result, or ``""`` when the
        request fails or the response holds no usable result.
    """
    url = "https://api.ocr.space/parse/image"
    payload = {"apikey": api_key, "language": "eng"}

    try:
        # A timeout keeps a stalled OCR service from hanging the app forever.
        response = requests.post(
            url, files={"file": file}, data=payload, timeout=timeout
        )
        result = response.json()
        text = result["ParsedResults"][0]["ParsedText"]
    except (
        requests.RequestException,
        ValueError,
        KeyError,
        IndexError,
        TypeError,
    ):
        # Best-effort contract: network failures, non-JSON bodies and error
        # payloads without a usable ParsedResults entry all mean "no text".
        return ""
    # Guard against a null ParsedText so callers always receive a string.
    return text or ""
# ---------------------------
# FILE UPLOAD
# ---------------------------
uploaded_file = st.file_uploader(
    "Upload PDF or Image", type=["pdf", "png", "jpg", "jpeg"]
)

# Text extracted from the upload; stays empty when nothing usable was found.
file_text = ""
if uploaded_file:
    if uploaded_file.type == "application/pdf":
        try:
            reader = PdfReader(uploaded_file)
            page_texts = [page.extract_text() for page in reader.pages]
        except Exception as e:
            # An unreadable PDF should not take the whole app down; report it
            # and continue without document context.
            st.error(f"❌ Could not read PDF: {e}")
        else:
            # Skip pages that yielded no text, and separate pages with a
            # newline so words at page boundaries do not run together.
            file_text = "\n".join(t for t in page_texts if t)
    elif "image" in uploaded_file.type:
        file_text = ocr_space_image(uploaded_file, ocr_api_key)
# ---------------------------
# TEXT CHUNKING & FAISS
# ---------------------------
def chunk_text(text, chunk_size=500):
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk.

    Returns:
        A list of non-overlapping slices of *text* in order; the last one may
        be shorter than *chunk_size*. An empty *text* yields ``[]``.

    Raises:
        ValueError: If *chunk_size* is not positive. A negative size would
            otherwise produce an empty list and silently drop the text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
def build_index(chunks):
    """Embed *chunks* and load the vectors into a flat L2 FAISS index.

    Args:
        chunks: Sequence of text pieces to embed with the module-level
            ``embedder``.

    Returns:
        A ``(index, embeddings)`` pair: the populated ``faiss.IndexFlatL2``
        and the array returned by ``embedder.encode``.
    """
    vectors = embedder.encode(chunks)
    # The index dimensionality is taken from the second axis of the vectors.
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(np.array(vectors))
    return flat_index, vectors
def search(query, chunks, index, top_k=3):
    """Return the chunks nearest to *query*, joined by newlines.

    Args:
        query: The user's question; embedded with the module-level
            ``embedder``.
        chunks: The text pieces whose vectors were added to *index*, in the
            same order.
        index: Object exposing ``search(vectors, k=...)`` that returns
            ``(distances, labels)``, as the FAISS index built by
            ``build_index`` does.
        top_k: Maximum number of chunks to retrieve.

    Returns:
        The retrieved chunks joined with ``"\\n"``, or ``""`` when there is
        nothing to search (no chunks, or a non-positive *top_k*).
    """
    k = min(top_k, len(chunks))
    if k <= 0:
        # Nothing to retrieve; avoid asking the index for zero neighbours.
        return ""

    q_emb = embedder.encode([query])
    _, labels = index.search(np.array(q_emb), k=k)
    # Skip out-of-range labels: a -1 label would otherwise wrap around and
    # silently select the last chunk.
    hits = [chunks[i] for i in labels[0] if 0 <= i < len(chunks)]
    return "\n".join(hits)
# ---------------------------
# PROCESS FILE
# ---------------------------
if uploaded_file and file_text:
    # The script is re-executed on each interaction, so fingerprint the
    # extracted text and only re-embed when the document actually changed.
    text_key = hash(file_text)
    if (
        "rag_data" in st.session_state
        and st.session_state.get("rag_key") == text_key
    ):
        # Same document as last run: reuse the stored chunks and index.
        chunks, index = st.session_state.rag_data
    else:
        chunks = chunk_text(file_text)
        index, _embeddings = build_index(chunks)
        st.session_state.rag_data = (chunks, index)
        st.session_state.rag_key = text_key
# ---------------------------
# CHAT MEMORY
# ---------------------------
if "messages" not in st.session_state:
st.session_state.messages = []
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
# ---------------------------
# USER PROMPT
# ---------------------------
prompt = st.chat_input("Ask anything...")

if prompt:
    # Record and echo the user's turn before generating a reply.
    history = st.session_state.messages
    history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve supporting text only when a document has been indexed.
    context = ""
    if "rag_data" in st.session_state:
        chunks, index = st.session_state.rag_data
        context = search(prompt, chunks, index)

    # The retrieved context goes first, followed by the whole conversation
    # (which already includes the prompt appended above).
    system_message = {"role": "system", "content": f"Context:\n{context}"}

    with st.chat_message("assistant"):
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=[system_message] + list(history),
                temperature=0.7,
                max_tokens=1024,
            )
            reply = response.choices[0].message.content
        except Exception as e:
            # Any failure is shown in the chat as the assistant's reply.
            reply = f"❌ Error: {e!s}"
        st.markdown(reply)

    history.append({"role": "assistant", "content": reply})