krishbaresha commited on
Commit
dfd9cf6
·
verified ·
1 Parent(s): 2881d3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -175
app.py CHANGED
@@ -1,188 +1,147 @@
1
  import streamlit as st
2
- import os
3
  from groq import Groq
 
4
  from PyPDF2 import PdfReader
5
  import requests
 
 
 
6
 
7
  # ---------------------------
8
  # PAGE CONFIG
9
  # ---------------------------
10
- st.set_page_config(page_title="Krish GPT Pro", layout="wide")
11
-
12
- # ---------------------------
13
- # THEME & CSS
14
- # ---------------------------
15
- st.markdown("""
16
- <style>
17
- body {
18
- background-color: #0a0b0d;
19
- color: #d1d5db;
20
- }
21
- .chat-container {
22
- max-width: 800px;
23
- margin: auto;
24
- padding-bottom: 180px;
25
- height: 70vh;
26
- overflow-y: auto;
27
- }
28
- .chat-bubble {
29
- padding: 12px 16px;
30
- border-radius: 15px;
31
- margin-bottom: 8px;
32
- display: inline-block;
33
- max-width: 75%;
34
- word-wrap: break-word;
35
- font-size: 15px;
36
- line-height: 1.4;
37
- }
38
- .user {
39
- background-color: #4f46e5;
40
- color: white;
41
- margin-left: auto;
42
- }
43
- .assistant {
44
- background-color: #1f2937;
45
- color: #d1d5db;
46
- margin-right: auto;
47
- }
48
- .input-container {
49
- display: flex;
50
- gap: 10px;
51
- position: fixed;
52
- bottom: 20px;
53
- width: 80%;
54
- max-width: 800px;
55
- margin-left: auto;
56
- margin-right: auto;
57
- background-color: #111827;
58
- padding: 10px;
59
- border-radius: 12px;
60
- box-shadow: 0 0 10px rgba(0,0,0,0.5);
61
- }
62
- textarea {
63
- border-radius: 12px;
64
- padding: 10px;
65
- flex: 1;
66
- background-color: #1f2937;
67
- color: #d1d5db;
68
- border: none;
69
- font-size: 16px;
70
- resize: none;
71
- min-height: 40px;
72
- max-height: 150px;
73
- }
74
- button {
75
- background-color: #4f46e5;
76
- color: white;
77
- padding: 10px 16px;
78
- border-radius: 12px;
79
- border: none;
80
- cursor: pointer;
81
- }
82
- button:hover {
83
- background-color: #4338ca;
84
- }
85
- input[type="file"] {
86
- border-radius: 12px;
87
- padding: 5px;
88
- background-color: #1f2937;
89
- color: #d1d5db;
90
- }
91
- </style>
92
- """, unsafe_allow_html=True)
93
-
94
- # ---------------------------
95
- # SESSION STATE
 
 
 
 
 
 
 
 
 
96
  # ---------------------------
97
  if "messages" not in st.session_state:
98
  st.session_state.messages = []
99
 
 
 
 
 
 
 
100
  # ---------------------------
101
- # GROQ CLIENT
102
- # ---------------------------
103
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
104
- OCR_API_KEY = os.getenv("OCR_API_KEY")
105
-
106
- # ---------------------------
107
- # DISPLAY CHAT
108
- # ---------------------------
109
- chat_container = st.empty()
110
- with chat_container.container():
111
- for msg in st.session_state.messages:
112
- role_class = "user" if msg["role"] == "user" else "assistant"
113
- st.markdown(
114
- f'<div class="chat-bubble {role_class}">{msg["content"]}</div>',
115
- unsafe_allow_html=True
116
- )
117
-
118
- # ---------------------------
119
- # INPUT + FILE UPLOAD (Merged)
120
- # ---------------------------
121
- with st.form("chat_form", clear_on_submit=True):
122
- st.markdown('<div class="input-container">', unsafe_allow_html=True)
123
- col1, col2, col3 = st.columns([6, 3, 1])
124
- with col1:
125
- prompt = st.text_area(
126
- "Type a message...",
127
- key="input_text",
128
- placeholder="Press Enter to send, Ctrl+Enter for new line",
129
- height=50
130
- )
131
- with col2:
132
- uploaded_file = st.file_uploader("", label_visibility="collapsed")
133
- with col3:
134
- submitted = st.form_submit_button("Send")
135
- st.markdown('</div>', unsafe_allow_html=True)
136
-
137
- if submitted and (prompt.strip() != "" or uploaded_file):
138
- context = ""
139
- if uploaded_file:
140
- if uploaded_file.type == "application/pdf":
141
- reader = PdfReader(uploaded_file)
142
- for page in reader.pages:
143
- text = page.extract_text()
144
- if text:
145
- context += text
146
- else:
147
- try:
148
- res = requests.post(
149
- "https://api.ocr.space/parse/image",
150
- files={"file": uploaded_file},
151
- data={"apikey": OCR_API_KEY}
152
- )
153
- context = res.json()['ParsedResults'][0]['ParsedText']
154
- except:
155
- context = ""
156
-
157
- st.session_state.messages.append({"role": "user", "content": prompt})
158
-
159
- final_prompt = (prompt or "") + "\n" + context[:2000]
160
-
161
- with st.spinner("🤖 Thinking..."):
162
- try:
163
- response = client.chat.completions.create(
164
- model="llama-3.3-70b-versatile",
165
- messages=[{"role": "user", "content": final_prompt}]
166
- )
167
- reply = response.choices[0].message.content
168
- except:
169
- reply = "⚠️ Something went wrong. Try again."
170
-
171
- st.session_state.messages.append({"role": "assistant", "content": reply})
172
-
173
- # ---------------------------
174
- # JS Trick: Enter = Send, Ctrl+Enter = New line
175
- # ---------------------------
176
- st.markdown("""
177
- <script>
178
- const textarea = window.parent.document.querySelector('textarea');
179
- if (textarea) {
180
- textarea.addEventListener('keydown', function(e){
181
- if(e.key === 'Enter' && !e.ctrlKey){
182
- e.preventDefault();
183
- document.querySelector('button[kind="primary"]').click();
184
- }
185
- });
186
- }
187
- </script>
188
- """, unsafe_allow_html=True)
 
1
  import streamlit as st
 
2
  from groq import Groq
3
+ import os
4
  from PyPDF2 import PdfReader
5
  import requests
6
+ import numpy as np
7
+ import faiss
8
+ from sentence_transformers import SentenceTransformer
9
 
10
  # ---------------------------
11
  # PAGE CONFIG
12
  # ---------------------------
13
+ st.set_page_config(page_title="Krish GPT Multi-Modal RAG", layout="wide")
14
+ st.title("🤖 Krish GPT Multi-Modal RAG")
15
+ st.caption("PDF + Image OCR + RAG using Groq LLM 🚀")
16
+
17
+ # ---------------------------
18
+ # API KEYS
19
+ # ---------------------------
20
+ groq_api_key = os.getenv("GROQ_API_KEY")
21
+ ocr_api_key = os.getenv("OCR_API_KEY")
22
+
23
+ if not groq_api_key:
24
+ groq_api_key = st.text_input("Enter GROQ API Key", type="password")
25
+
26
+ if not ocr_api_key:
27
+ ocr_api_key = st.text_input("Enter OCR.Space API Key", type="password")
28
+
29
+ if not groq_api_key or not ocr_api_key:
30
+ st.stop()
31
+
32
+ client = Groq(api_key=groq_api_key)
33
+
34
+ # ---------------------------
35
+ # EMBEDDING MODEL
36
+ # ---------------------------
37
+ @st.cache_resource
38
+ def load_embedder():
39
+ return SentenceTransformer("all-MiniLM-L6-v2")
40
+
41
+ embedder = load_embedder()
42
+
43
+ # ---------------------------
44
+ # OCR Function
45
+ # ---------------------------
46
+ def ocr_space_image(file, api_key):
47
+ url = "https://api.ocr.space/parse/image"
48
+ files = {'file': file}
49
+ data = {'apikey': api_key, 'language': 'eng'}
50
+ r = requests.post(url, files=files, data=data)
51
+ try:
52
+ result = r.json()
53
+ text = result['ParsedResults'][0]['ParsedText']
54
+ except:
55
+ text = ""
56
+ return text
57
+
58
+ # ---------------------------
59
+ # FILE UPLOAD
60
+ # ---------------------------
61
+ uploaded_file = st.file_uploader(
62
+ "Upload PDF or Image", type=["pdf", "png", "jpg", "jpeg"]
63
+ )
64
+ file_text = ""
65
+
66
+ if uploaded_file:
67
+ if uploaded_file.type == "application/pdf":
68
+ reader = PdfReader(uploaded_file)
69
+ for page in reader.pages:
70
+ t = page.extract_text()
71
+ if t:
72
+ file_text += t
73
+ elif "image" in uploaded_file.type:
74
+ file_text = ocr_space_image(uploaded_file, ocr_api_key)
75
+
76
+ # ---------------------------
77
+ # TEXT CHUNKING & FAISS
78
+ # ---------------------------
79
+ def chunk_text(text, chunk_size=500):
80
+ chunks = []
81
+ for i in range(0, len(text), chunk_size):
82
+ chunks.append(text[i:i+chunk_size])
83
+ return chunks
84
+
85
+ def build_index(chunks):
86
+ embeddings = embedder.encode(chunks)
87
+ dim = embeddings.shape[1]
88
+ index = faiss.IndexFlatL2(dim)
89
+ index.add(np.array(embeddings))
90
+ return index, embeddings
91
+
92
+ def search(query, chunks, index):
93
+ q_emb = embedder.encode([query])
94
+ D, I = index.search(np.array(q_emb), k=min(3, len(chunks)))
95
+ results = [chunks[i] for i in I[0]]
96
+ return "\n".join(results)
97
+
98
+ # ---------------------------
99
+ # PROCESS FILE
100
+ # ---------------------------
101
+ if uploaded_file and file_text:
102
+ chunks = chunk_text(file_text)
103
+ index, embeddings = build_index(chunks)
104
+ st.session_state.rag_data = (chunks, index)
105
+
106
+ # ---------------------------
107
+ # CHAT MEMORY
108
  # ---------------------------
109
  if "messages" not in st.session_state:
110
  st.session_state.messages = []
111
 
112
+ for msg in st.session_state.messages:
113
+ with st.chat_message(msg["role"]):
114
+ st.markdown(msg["content"])
115
+
116
+ # ---------------------------
117
+ # USER PROMPT
118
  # ---------------------------
119
+ prompt = st.chat_input("Ask anything...")
120
+
121
+ if prompt:
122
+ st.session_state.messages.append({"role": "user", "content": prompt})
123
+ with st.chat_message("user"):
124
+ st.markdown(prompt)
125
+
126
+ context = ""
127
+ if "rag_data" in st.session_state:
128
+ chunks, index = st.session_state.rag_data
129
+ context = search(prompt, chunks, index)
130
+
131
+ with st.chat_message("assistant"):
132
+ try:
133
+ response = client.chat.completions.create(
134
+ model="llama-3.3-70b-versatile",
135
+ messages=[
136
+ {"role": "system", "content": f"Context:\n{context}"},
137
+ *st.session_state.messages
138
+ ],
139
+ temperature=0.7,
140
+ max_tokens=1024
141
+ )
142
+ reply = response.choices[0].message.content
143
+ except Exception as e:
144
+ reply = f" Error: {str(e)}"
145
+
146
+ st.markdown(reply)
147
+ st.session_state.messages.append({"role": "assistant", "content": reply})