zeeshan4801 commited on
Commit
74b3e59
·
verified ·
1 Parent(s): cf8303f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ============================================
# Civil Engineering RAG (ASTM) - app.py
# ============================================
# NOTE(fix): the file previously began with the Jupyter cell magic
# "%%writefile app.py", which is not valid Python — importing or running
# this module directly would raise a SyntaxError. It has been removed.
import os
import fitz  # PyMuPDF
import faiss
import numpy as np
import gradio as gr
from typing import List
from groq import Groq
from sentence_transformers import SentenceTransformer

# --------------------------
# Config
# --------------------------
# Fail fast at import time if the Groq credential is missing.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY missing. Set it before running: os.environ['GROQ_API_KEY']='...'")

# Change these if your filenames differ:
DOC_PATHS = [
    "docs/ASTM1.pdf",
    "docs/ASTM2.pdf",
]

# Embedding model (free & small; good for Colab)
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
29
+
30
# --------------------------
# Clients / Models
# --------------------------
# Module-level singletons shared by the indexing and retrieval code below.
client = Groq(api_key=GROQ_API_KEY)  # Groq chat-completions client
# NOTE: SentenceTransformer downloads model weights on first use — expect a
# delay (and network access) the first time this module is imported.
embedder = SentenceTransformer(EMBED_MODEL)
35
+
36
# --------------------------
# PDF text extraction
# --------------------------
def extract_text_from_pdf(file_path: str) -> str:
    """Return the plain text of every page in the PDF, pages joined by newlines."""
    with fitz.open(file_path) as doc:
        pages = [page.get_text("text") for page in doc]
    return "\n".join(pages)
45
+
46
# --------------------------
# Simple character-based chunking with overlap
# --------------------------
def chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
    """Split `text` into ~chunk_size-character pieces that overlap by `overlap`.

    BUG FIX: the original advanced with `start = end - overlap`, which stops
    advancing once `end` is clamped to len(text) (and the `start < 0 -> 0`
    reset re-enters from position 0), so ANY non-empty input looped forever.
    We now advance by a fixed positive step and stop explicitly after the
    final (clamped) window has been emitted.

    Args:
        text: The text to split.
        chunk_size: Maximum characters per chunk (must exceed `overlap`).
        overlap: Characters shared between consecutive chunks.

    Returns:
        Non-empty, whitespace-stripped chunks in document order.

    Raises:
        ValueError: If `overlap >= chunk_size` (the window could never advance).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks: List[str] = []
    step = chunk_size - overlap  # guaranteed positive by the guard above
    n = len(text)
    start = 0
    while start < n:
        end = min(start + chunk_size, n)
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= n:  # final window reached — do not slide back over the tail
            break
        start += step
    return chunks
62
+
63
# --------------------------
# Build FAISS index
# --------------------------
def build_faiss_index(paths: List[str]):
    """Embed every chunk of every PDF in `paths` and build an L2 FAISS index.

    Persists the index and the chunk corpus under faiss_index/ so later runs
    can skip re-embedding, then returns (index, corpus).
    """
    corpus: List[str] = []
    embedding_blocks = []

    for p in paths:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Document not found: {p}")
        chunks = chunk_text(extract_text_from_pdf(p))
        if not chunks:
            continue
        vecs = embedder.encode(chunks, convert_to_numpy=True, show_progress_bar=True)
        corpus.extend(chunks)
        embedding_blocks.append(vecs.astype("float32"))

    if not corpus:
        raise RuntimeError("No text extracted from provided PDFs.")

    matrix = np.vstack(embedding_blocks).astype("float32")
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)

    # Persist (optional)
    os.makedirs("faiss_index", exist_ok=True)
    faiss.write_index(index, "faiss_index/index.faiss")
    np.save("faiss_index/corpus.npy", np.array(corpus, dtype=object))

    return index, corpus
94
+
95
def load_or_build_index(paths: List[str]):
    """Reuse the persisted index/corpus when both files exist; otherwise build fresh."""
    idx_path = "faiss_index/index.faiss"
    corpus_path = "faiss_index/corpus.npy"
    if not (os.path.exists(idx_path) and os.path.exists(corpus_path)):
        return build_faiss_index(paths)
    index = faiss.read_index(idx_path)
    corpus = np.load(corpus_path, allow_pickle=True).tolist()
    return index, corpus
103
+
104
# Build on import (so Gradio has it)
# NOTE: module-level side effect — importing this module reads/creates the
# faiss_index/ cache and may embed all PDFs, which can take a while.
INDEX, CORPUS = load_or_build_index(DOC_PATHS)
106
+
107
# --------------------------
# Retrieval
# --------------------------
def retrieve_context(query: str, top_k: int = 4) -> str:
    """Embed `query`, search the FAISS index, and join the top_k matching chunks.

    Out-of-range ids (FAISS pads with -1 when there are fewer results than
    top_k) are silently skipped.
    """
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _distances, hit_ids = INDEX.search(query_vec, top_k)
    hits = [CORPUS[i] for i in hit_ids[0] if 0 <= i < len(CORPUS)]
    return "\n\n---\n\n".join(hits)
118
+
119
# --------------------------
# LLM call via Groq
# --------------------------
# Instruction text sent with every request; constrains the model to answer
# only from the retrieved ASTM excerpts.
SYSTEM_PROMPT = (
    "You are a helpful Civil Engineering assistant. "
    "Use ONLY the provided ASTM context to answer. "
    "If the answer isn't in context, say you cannot find it in the provided documents."
)
127
+
128
def ask_groq(query: str, top_k: int = 4, model: str = "llama-3.3-70b-versatile") -> str:
    """Answer `query` via Groq chat completion, grounded in retrieved ASTM context.

    Args:
        query: The user's question.
        top_k: Number of corpus chunks to retrieve as context.
        model: Groq model identifier.

    Returns:
        The model's answer text.
    """
    context = retrieve_context(query, top_k=top_k)
    # Improvement: send the instructions in the "system" role instead of
    # concatenating them into the user message — the chat-completions API
    # weights system instructions more reliably.
    user_prompt = (
        f"Context (ASTM excerpts):\n{context}\n\n"
        f"Question:\n{query}\n\n"
        "Answer clearly and cite phrases only if present in the context above."
    )
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.2,  # low temperature keeps answers factual/grounded
    )
    return completion.choices[0].message.content
146
+
147
# --------------------------
# Gradio UI
# --------------------------
def ui_ask(query: str, top_k: int):
    """Gradio callback: answer `query`, returning any failure as an error string."""
    try:
        answer = ask_groq(query, top_k=top_k)
    except Exception as exc:  # surface the failure in the UI instead of crashing
        return f"Error: {exc}"
    return answer
155
+
156
# Declarative UI layout: question + top-k slider on one row, answer below.
with gr.Blocks(title="Civil Engineering RAG (ASTM)") as demo:
    gr.Markdown("# 🏗️ Civil Engineering RAG (ASTM)\nAsk questions grounded in your uploaded ASTM PDFs.")
    with gr.Row():
        question_box = gr.Textbox(
            label="Your question",
            placeholder="e.g., What is the acceptable slump range for Class A concrete?",
        )
        top_k_slider = gr.Slider(1, 10, value=4, step=1, label="Top-K passages to retrieve")
    answer_box = gr.Textbox(label="Answer")
    ask_button = gr.Button("Ask")
    ask_button.click(ui_ask, inputs=[question_box, top_k_slider], outputs=[answer_box])
    gr.Markdown("Tip: If you change PDFs, **restart runtime** and re-run cells to rebuild the index.")
165
+
166
if __name__ == "__main__":
    # share=True requests a public Gradio link — intended for Colab/notebook
    # use; drop it for local-only serving.
    demo.launch(share=True)