File size: 4,221 Bytes
33d6cab
889699f
dfd9cf6
606fd71
889699f
dfd9cf6
 
 
33d6cab
889699f
8e391ed
889699f
dfd9cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889699f
0ed38c7
 
 
dfd9cf6
 
 
 
 
 
0ed38c7
dfd9cf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
from groq import Groq
import os
from PyPDF2 import PdfReader
import requests
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# ---------------------------
# PAGE CONFIG
# ---------------------------
st.set_page_config(page_title="Krish GPT Multi-Modal RAG", layout="wide")
st.title("🤖 Krish GPT Multi-Modal RAG")
st.caption("PDF + Image OCR + RAG using Groq LLM 🚀")

# ---------------------------
# API KEYS
# ---------------------------
# Each key comes from the environment when set; otherwise a password-style
# text box is rendered so the user can type it in.
groq_api_key = os.getenv("GROQ_API_KEY") or st.text_input(
    "Enter GROQ API Key", type="password"
)
ocr_api_key = os.getenv("OCR_API_KEY") or st.text_input(
    "Enter OCR.Space API Key", type="password"
)

# Halt this script run until both keys are available.
if not (groq_api_key and ocr_api_key):
    st.stop()

client = Groq(api_key=groq_api_key)

# ---------------------------
# EMBEDDING MODEL
# ---------------------------
@st.cache_resource
def load_embedder():
    """Instantiate the sentence-embedding model used for retrieval.

    The function is wrapped by ``st.cache_resource``, so Streamlit's
    resource cache decides when the body actually runs.

    Returns:
        A ``SentenceTransformer`` for the "all-MiniLM-L6-v2" checkpoint.
    """
    model_name = "all-MiniLM-L6-v2"
    return SentenceTransformer(model_name)


# Module-level handle used by build_index() and search().
embedder = load_embedder()

# ---------------------------
# OCR Function
# ---------------------------
def ocr_space_image(file, api_key):
    """Extract text from an image via the OCR.Space HTTP API.

    Args:
        file: The image to upload; passed to ``requests`` as the multipart
            ``file`` field.
        api_key: OCR.Space API key, sent as the ``apikey`` form field.

    Returns:
        The ``ParsedText`` of the first entry in ``ParsedResults``, or an
        empty string when the response is not JSON, lacks that entry, or
        carries no text.

    Raises:
        requests.RequestException: If the HTTP request itself fails
            (connection error, timeout, ...).
    """
    url = "https://api.ocr.space/parse/image"
    files = {'file': file}
    data = {'apikey': api_key, 'language': 'eng'}
    # Without a timeout a stalled OCR service would hang the app forever.
    r = requests.post(url, files=files, data=data, timeout=60)
    try:
        result = r.json()
        text = result['ParsedResults'][0]['ParsedText']
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError: body is not JSON. KeyError/IndexError/TypeError: the
        # payload does not have the expected ParsedResults structure
        # (e.g. an error response).
        return ""
    # Normalise a null/empty ParsedText to "" so callers always get a str.
    return text or ""

# ---------------------------
# FILE UPLOAD
# ---------------------------
uploaded_file = st.file_uploader(
    "Upload PDF or Image", type=["pdf", "png", "jpg", "jpeg"]
)
file_text = ""

if uploaded_file:
    # Streamlit re-executes this script on every interaction (including each
    # chat message). Remember the text extracted for the current upload so
    # the PDF is not re-parsed, and the OCR service not re-called, each time.
    file_key = (uploaded_file.name, uploaded_file.size)
    if (
        "extracted_text" in st.session_state
        and st.session_state.extracted_text[0] == file_key
    ):
        file_text = st.session_state.extracted_text[1]
    else:
        if uploaded_file.type == "application/pdf":
            reader = PdfReader(uploaded_file)
            page_texts = (page.extract_text() for page in reader.pages)
            # Pages without extractable text yield a falsy value; skip them.
            file_text = "".join(t for t in page_texts if t)
        elif "image" in uploaded_file.type:
            file_text = ocr_space_image(uploaded_file, ocr_api_key)

        # Only cache successful extractions so a failed OCR call is retried
        # on the next rerun instead of being pinned as "no text".
        if file_text:
            st.session_state.extracted_text = (file_key, file_text)

# ---------------------------
# TEXT CHUNKING & FAISS
# ---------------------------
def chunk_text(text, chunk_size=500):
    """Split ``text`` into consecutive slices of at most ``chunk_size`` items.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each slice.

    Returns:
        A list of slices in original order; the last one may be shorter.
        An empty ``text`` yields an empty list.
    """
    starts = range(0, len(text), chunk_size)
    return [text[start:start + chunk_size] for start in starts]

def build_index(chunks):
    """Embed ``chunks`` and load the vectors into a flat L2 FAISS index.

    Args:
        chunks: Sequence of text chunks to embed with the module-level
            ``embedder``.

    Returns:
        A ``(index, embeddings)`` tuple: the populated ``faiss.IndexFlatL2``
        and the embedding array returned by the embedder.
    """
    vectors = embedder.encode(chunks)
    # The index dimensionality is taken from the second axis of the output.
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(np.array(vectors))
    return flat_index, vectors

def search(query, chunks, index, top_k=3):
    """Return the chunks nearest to ``query``, joined by newlines.

    Args:
        query: The user's question; embedded with the module-level
            ``embedder``.
        chunks: The text chunks that were added to ``index``, in the same
            order.
        index: A FAISS index built over the embeddings of ``chunks``.
        top_k: Maximum number of chunks to retrieve (default 3).

    Returns:
        The retrieved chunks joined with "\\n", or "" when there is nothing
        to search (no chunks, or a non-positive ``top_k``).
    """
    # Nothing to retrieve: avoid asking the index for zero neighbours.
    if not chunks or top_k <= 0:
        return ""

    q_emb = embedder.encode([query])
    k = min(top_k, len(chunks))
    _, neighbours = index.search(np.array(q_emb), k=k)

    # FAISS pads the label row with -1 when it finds fewer than k hits;
    # indexing chunks[-1] would silently return the wrong chunk, so drop
    # any label that is not a valid position in ``chunks``.
    hits = [chunks[i] for i in neighbours[0] if 0 <= i < len(chunks)]
    return "\n".join(hits)

# ---------------------------
# PROCESS FILE
# ---------------------------
if uploaded_file and file_text:
    # Streamlit re-executes this script on every interaction, so without a
    # guard every chat message would re-embed all chunks and rebuild the
    # index. Rebuild only when the extracted text actually changed.
    text_key = (len(file_text), hash(file_text))
    already_indexed = (
        "rag_data" in st.session_state
        and "rag_key" in st.session_state
        and st.session_state.rag_key == text_key
    )
    if not already_indexed:
        chunks = chunk_text(file_text)
        # build_index also returns the raw embeddings; only the index and
        # the chunks are needed for retrieval.
        index, _ = build_index(chunks)
        st.session_state.rag_data = (chunks, index)
        st.session_state.rag_key = text_key

# ---------------------------
# CHAT MEMORY
# ---------------------------
# Create the conversation history the first time it is needed.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay the stored conversation so it is visible on every run.
for entry in st.session_state["messages"]:
    st.chat_message(entry["role"]).markdown(entry["content"])

# ---------------------------
# USER PROMPT
# ---------------------------
prompt = st.chat_input("Ask anything...")

if prompt:
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve supporting context when a document has been indexed.
    context = ""
    if "rag_data" in st.session_state:
        chunks, index = st.session_state.rag_data
        context = search(prompt, chunks, index)

    # Build the API payload from the history, keeping only the role/content
    # fields and leaving out entries flagged as errors: a past failure
    # message is not something the model said and must not be fed back to it.
    api_messages = [
        {"role": m["role"], "content": m["content"]}
        for m in st.session_state.messages
        if not m.get("error")
    ]
    # Only send a context system message when there is context to send.
    if context:
        api_messages.insert(
            0, {"role": "system", "content": f"Context:\n{context}"}
        )

    with st.chat_message("assistant"):
        failed = False
        try:
            response = client.chat.completions.create(
                model="llama-3.3-70b-versatile",
                messages=api_messages,
                temperature=0.7,
                max_tokens=1024
            )
            # The content may be empty/None; always store a string.
            reply = response.choices[0].message.content or ""
        except Exception as e:
            # Top-level UI boundary: surface any API failure in the chat
            # instead of crashing the app.
            failed = True
            reply = f"❌ Error: {str(e)}"

        st.markdown(reply)
        entry = {"role": "assistant", "content": reply}
        if failed:
            # Keep the error visible in the transcript, but mark it so it is
            # excluded from future requests.
            entry["error"] = True
        st.session_state.messages.append(entry)