Spaces:
Sleeping
Sleeping
import hashlib
import os

import streamlit as st
from groq import Groq
from PyPDF2 import PdfReader
import requests
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# ---------------------------
# PAGE CONFIG
# ---------------------------
# Streamlit requires set_page_config to be the first Streamlit call of a
# script run, so it must stay ahead of the title and caption below.
st.set_page_config(layout="wide", page_title="Krish GPT Multi-Modal RAG")

# Header shown at the top of the page.
st.title("🤖 Krish GPT Multi-Modal RAG")
st.caption("PDF + Image OCR + RAG using Groq LLM 🚀")
# ---------------------------
# API KEYS
# ---------------------------
# Environment variables win; when one is unset or empty, fall back to a
# masked text input so the key can be typed in the UI instead.
groq_api_key = os.getenv("GROQ_API_KEY") or st.text_input(
    "Enter GROQ API Key", type="password"
)
ocr_api_key = os.getenv("OCR_API_KEY") or st.text_input(
    "Enter OCR.Space API Key", type="password"
)

# Halt this script run until both keys are present; editing either input
# triggers a new run that re-evaluates this check.
if not (groq_api_key and ocr_api_key):
    st.stop()

client = Groq(api_key=groq_api_key)
# ---------------------------
# EMBEDDING MODEL
# ---------------------------
@st.cache_resource
def load_embedder():
    """Return the sentence-transformer used to embed chunks and queries.

    Streamlit re-executes this entire script on every widget interaction
    (including every chat message). Without ``st.cache_resource`` the model
    would be loaded from disk again on each of those reruns. The decorator
    keeps a single shared instance instead.
    """
    return SentenceTransformer("all-MiniLM-L6-v2")


embedder = load_embedder()
| # --------------------------- | |
| # OCR Function | |
| # --------------------------- | |
| def ocr_space_image(file, api_key): | |
| url = "https://api.ocr.space/parse/image" | |
| files = {'file': file} | |
| data = {'apikey': api_key, 'language': 'eng'} | |
| r = requests.post(url, files=files, data=data) | |
| try: | |
| result = r.json() | |
| text = result['ParsedResults'][0]['ParsedText'] | |
| except: | |
| text = "" | |
| return text | |
# ---------------------------
# FILE UPLOAD
# ---------------------------
uploaded_file = st.file_uploader(
    "Upload PDF or Image", type=["pdf", "png", "jpg", "jpeg"]
)


def _extract_text(upload, api_key):
    """Return the text contained in an uploaded PDF or image ("" if none).

    PDF pages are read locally with PyPDF2 and joined with newlines, so the
    last word of one page is not fused to the first word of the next. Images
    (MIME type ``image/...``) are sent to OCR.Space. Any other type gives "".
    """
    if upload.type == "application/pdf":
        reader = PdfReader(upload)
        page_texts = (page.extract_text() for page in reader.pages)
        return "\n".join(text for text in page_texts if text)
    if upload.type.startswith("image/"):
        return ocr_space_image((upload.name, upload.getvalue()), api_key)
    return ""


file_text = ""
if uploaded_file:
    # Streamlit re-runs the whole script on every chat message. Remember the
    # extracted text per file content so the same upload is not sent to the
    # OCR API (or re-parsed) on every interaction.
    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    cached = st.session_state.get("extracted_text")
    if cached and cached[0] == digest:
        file_text = cached[1]
    else:
        file_text = _extract_text(uploaded_file, ocr_api_key)
        if file_text:
            # Only successful extractions are cached, so a failed OCR call
            # is retried on the next run instead of sticking.
            st.session_state.extracted_text = (digest, file_text)
        else:
            st.warning("No text could be extracted from the uploaded file.")
# ---------------------------
# TEXT CHUNKING & FAISS
# ---------------------------
def chunk_text(text, chunk_size=500, overlap=0):
    """Split *text* into consecutive character windows for embedding.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be > 0.
        overlap: Number of characters each chunk repeats from the end of the
            previous one (``0 <= overlap < chunk_size``). With an overlap,
            text near a boundary appears in both adjacent chunks.

    Returns:
        The chunks in order; ``[]`` for empty *text*. With ``overlap=0`` the
        chunks partition *text* exactly.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Once a window reaches the end of the text, any further windows
        # (possible when overlap > 0) would only repeat text already covered.
        if start + chunk_size >= len(text):
            break
    return chunks
def build_index(chunks):
    """Embed *chunks* and store the vectors in an exact L2 FAISS index.

    Args:
        chunks: Non-empty list of text chunks to embed.

    Returns:
        ``(index, embeddings)``. Vector id ``i`` in ``index``, and row ``i`` of
        the float32 ``embeddings`` matrix, correspond to ``chunks[i]``.

    Raises:
        ValueError: If *chunks* is empty. There is nothing to index, and the
            embedding dimension cannot be determined.
    """
    if not chunks:
        raise ValueError("build_index() needs at least one chunk")
    # FAISS operates on C-contiguous float32 matrices; coerce explicitly so
    # this does not depend on the encoder's output dtype.
    embeddings = np.ascontiguousarray(embedder.encode(chunks), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
def search(query, chunks, index, k=3):
    """Return up to *k* chunks most similar to *query*, joined by newlines.

    Args:
        query: Free-text question to embed and look up.
        chunks: The chunk list the index was built from (id ``i`` maps to
            ``chunks[i]``).
        index: FAISS index built from *chunks* (see ``build_index``).
        k: Maximum number of chunks to retrieve.

    Returns:
        The retrieved chunks, nearest first, separated by newlines. Returns
        "" when *chunks* is empty or ``k <= 0``.
    """
    k = min(k, len(chunks))
    if k <= 0:
        return ""
    query_vec = np.ascontiguousarray(embedder.encode([query]), dtype="float32")
    _, ids = index.search(query_vec, k=k)
    # FAISS pads missing results with id -1. A plain chunks[i] would turn
    # that into chunks[-1] and silently return the wrong chunk, so drop any
    # id that is not a valid position.
    return "\n".join(chunks[i] for i in ids[0] if 0 <= i < len(chunks))
# ---------------------------
# PROCESS FILE
# ---------------------------
if uploaded_file and file_text:
    chunks = chunk_text(file_text)
    cached = st.session_state.get("rag_data")
    # Streamlit re-runs this script on every chat message; re-embed only when
    # the document's chunks actually changed, not on every interaction.
    if cached is None or cached[0] != chunks:
        index, embeddings = build_index(chunks)
        st.session_state.rag_data = (chunks, index)
    else:
        index = cached[1]
else:
    # No usable document (file removed, or no text extracted): discard any
    # index from a previous upload so answers are not grounded in stale text.
    st.session_state.pop("rag_data", None)
# ---------------------------
# CHAT MEMORY
# ---------------------------
# The conversation is kept in session state so it survives Streamlit's
# reruns; replay every stored turn so the transcript stays on screen.
st.session_state.setdefault("messages", [])

for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
# ---------------------------
# USER PROMPT
# ---------------------------
# Only the most recent turns are sent to the model, so a long conversation
# cannot grow past the model's context window.
MAX_HISTORY_MESSAGES = 20

prompt = st.chat_input("Ask anything...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve supporting chunks only when a document index is available.
    context = ""
    if "rag_data" in st.session_state:
        chunks, index = st.session_state.rag_data
        context = search(prompt, chunks, index)

    llm_messages = list(st.session_state.messages[-MAX_HISTORY_MESSAGES:])
    if context:
        # Add a system message only when there is context to ground on;
        # otherwise the model would receive an empty "Context:" header.
        llm_messages.insert(
            0,
            {
                "role": "system",
                "content": (
                    "Answer using the document context below when it is "
                    f"relevant.\n\nContext:\n{context}"
                ),
            },
        )

    with st.chat_message("assistant"):
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=llm_messages,
                temperature=0.7,
                max_tokens=1024,
            )
            reply = response.choices[0].message.content or ""
        except Exception as e:  # UI boundary: surface any API failure.
            # Show the error but keep it out of the history; otherwise the
            # error text would be sent back to the model as an assistant
            # turn on every subsequent request.
            st.error(f"❌ Error: {str(e)}")
        else:
            st.markdown(reply)
            st.session_state.messages.append(
                {"role": "assistant", "content": reply}
            )