"""Streamlit chat app: retrieval-augmented answers over an uploaded PDF/image.

Text is extracted from the upload (PyPDF2 for PDFs, the OCR.Space HTTP API for
images), split into fixed-size chunks, embedded with a SentenceTransformer and
indexed in FAISS. Each user prompt retrieves the closest chunks and sends them
as system context to a Groq-hosted chat model.
"""

import os

import faiss
import numpy as np
import requests
import streamlit as st
from groq import Groq
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer

# ---------------------------
# PAGE CONFIG
# ---------------------------
st.set_page_config(page_title="Krish GPT Multi-Modal RAG", layout="wide")
st.title("🤖 Krish GPT Multi-Modal RAG")
st.caption("PDF + Image OCR + RAG using Groq LLM 🚀")

# ---------------------------
# API KEYS
# ---------------------------
# Environment variables win; otherwise ask for the keys in the UI.
groq_api_key = os.getenv("GROQ_API_KEY")
ocr_api_key = os.getenv("OCR_API_KEY")

if not groq_api_key:
    groq_api_key = st.text_input("Enter GROQ API Key", type="password")
if not ocr_api_key:
    ocr_api_key = st.text_input("Enter OCR.Space API Key", type="password")

# Nothing below can work without both keys, so halt this script run here.
if not groq_api_key or not ocr_api_key:
    st.stop()

client = Groq(api_key=groq_api_key)


# ---------------------------
# EMBEDDING MODEL
# ---------------------------
@st.cache_resource
def load_embedder():
    """Return the sentence-embedding model (cached via ``st.cache_resource``)."""
    return SentenceTransformer("all-MiniLM-L6-v2")


embedder = load_embedder()

# ---------------------------
# OCR Function
# ---------------------------
# Upper bound (seconds) for the OCR HTTP call so a stalled request cannot
# hang the script run indefinitely.
OCR_TIMEOUT_SECONDS = 30


def ocr_space_image(file, api_key):
    """Send an image to the OCR.Space API and return the recognised text.

    Args:
        file: File-like object holding the image; passed to ``requests`` as
            the multipart ``file`` field.
        api_key: OCR.Space API key.

    Returns:
        The ``ParsedText`` of the first parsed result, or ``""`` when the
        request fails, times out, or the response lacks that field.
    """
    url = "https://api.ocr.space/parse/image"
    try:
        response = requests.post(
            url,
            files={"file": file},
            data={"apikey": api_key, "language": "eng"},
            timeout=OCR_TIMEOUT_SECONDS,
        )
        parsed = response.json()["ParsedResults"][0]["ParsedText"]
    except (
        requests.RequestException,  # network failure / timeout
        ValueError,  # body is not valid JSON
        KeyError,  # unexpected response layout
        IndexError,  # empty ParsedResults
        TypeError,  # ParsedResults is null
    ):
        # Best-effort by design: callers treat "" as "no text extracted".
        return ""
    return parsed or ""


def _extract_pdf_text(file):
    """Return the concatenated text of every page PyPDF2 can extract."""
    reader = PdfReader(file)
    page_texts = (page.extract_text() for page in reader.pages)
    return "".join(text for text in page_texts if text)


# ---------------------------
# FILE UPLOAD
# ---------------------------
uploaded_file = st.file_uploader(
    "Upload PDF or Image",
    type=["pdf", "png", "jpg", "jpeg"],
)

file_text = ""


# ---------------------------
# TEXT CHUNKING & FAISS
# ---------------------------
def chunk_text(text, chunk_size=500):
    """Split ``text`` into consecutive slices of at most ``chunk_size`` chars.

    Returns a list of strings; an empty ``text`` yields an empty list.
    """
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def build_index(chunks):
    """Embed ``chunks`` and load them into a flat L2 FAISS index.

    Args:
        chunks: Non-empty list of strings.

    Returns:
        ``(index, embeddings)`` where ``embeddings`` is the float32 matrix
        that was added to ``index`` (row ``i`` corresponds to ``chunks[i]``).

    Raises:
        ValueError: If ``chunks`` is empty (there is nothing to index).
    """
    if not chunks:
        raise ValueError("build_index() requires at least one chunk")
    # FAISS only accepts float32 input; make that explicit.
    embeddings = np.asarray(embedder.encode(chunks), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings


def search(query, chunks, index):
    """Return up to three chunks nearest to ``query``, joined by newlines.

    Args:
        query: User question to embed and look up.
        chunks: The strings that were indexed, in insertion order.
        index: FAISS index built from ``chunks``.

    Returns:
        The matching chunks, closest first, separated by ``"\\n"``; ``""``
        when ``chunks`` is empty.
    """
    if not chunks:
        return ""
    query_vec = np.asarray(embedder.encode([query]), dtype="float32")
    _, neighbour_ids = index.search(query_vec, k=min(3, len(chunks)))
    # FAISS pads missing neighbours with -1, which would otherwise silently
    # index the *last* chunk; drop anything outside the valid range.
    hits = [chunks[i] for i in neighbour_ids[0] if 0 <= i < len(chunks)]
    return "\n".join(hits)


# ---------------------------
# PROCESS FILE
# ---------------------------
# Streamlit re-executes this script on every interaction, so extraction and
# embedding are keyed on the upload's identity and only redone when the file
# actually changes (this also avoids one OCR API call per chat message).
if uploaded_file is None:
    # No file selected: make sure context from an earlier upload is not reused.
    st.session_state.pop("rag_data", None)
    st.session_state.pop("rag_file_key", None)
else:
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("rag_file_key") != file_key:
        if uploaded_file.type == "application/pdf":
            file_text = _extract_pdf_text(uploaded_file)
        elif "image" in uploaded_file.type:
            file_text = ocr_space_image(uploaded_file, ocr_api_key)

        if file_text:
            chunks = chunk_text(file_text)
            index, embeddings = build_index(chunks)
            st.session_state.rag_data = (chunks, index)
            st.session_state.rag_file_key = file_key
        else:
            # Extraction failed: drop stale context and leave the key unset
            # so the next rerun retries.
            st.session_state.pop("rag_data", None)
            st.session_state.pop("rag_file_key", None)
            st.warning("No text could be extracted from the uploaded file.")

# ---------------------------
# CHAT MEMORY
# ---------------------------
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation so it stays visible across reruns.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# ---------------------------
# USER PROMPT
# ---------------------------
prompt = st.chat_input("Ask anything...")

if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve supporting chunks only when a document has been indexed.
    context = ""
    if "rag_data" in st.session_state:
        chunks, index = st.session_state.rag_data
        context = search(prompt, chunks, index)

    with st.chat_message("assistant"):
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=[
                    {"role": "system", "content": f"Context:\n{context}"},
                    *st.session_state.messages,
                ],
                temperature=0.7,
                max_tokens=1024,
            )
            reply = response.choices[0].message.content
        except Exception as e:
            # Top-level UI boundary: show the failure in the chat instead of
            # crashing the app.
            reply = f"❌ Error: {str(e)}"

        st.markdown(reply)

    st.session_state.messages.append({"role": "assistant", "content": reply})