|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import fitz |
|
|
import faiss |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import tempfile |
|
|
from typing import List |
|
|
from groq import Groq |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Groq API key is read from the environment (on Hugging Face Spaces this is
# configured as a repository secret); fail fast at import time if missing.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    raise RuntimeError("β Missing GROQ_API_KEY. Please add it in Hugging Face β Settings β Secrets.")

# Shared singletons: the Groq chat client and the sentence-embedding model
# used for both corpus chunks and queries (all-MiniLM-L6-v2, 384-dim vectors).
client = Groq(api_key=GROQ_API_KEY)
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Global retrieval state, (re)populated by rebuild_index_from_upload:
# INDEX is a FAISS index (or None before any upload); CORPUS is the parallel
# list of text chunks, indexed by FAISS row id.
INDEX, CORPUS = None, []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_text_from_pdf(file_path: str) -> str:
    """Return the plain text of every page in the PDF at *file_path*.

    On any failure this returns an error-message string prefixed with
    "Error extracting text" instead of raising; callers detect failure
    by checking that prefix.
    """
    try:
        with fitz.open(file_path) as doc:
            pages = [page.get_text("text") for page in doc]
        return "".join(pages)
    except Exception as e:
        return f"Error extracting text from {file_path}: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chunk_text(text: str, chunk_size: int = 800, overlap: int = 120) -> List[str]:
    """Split *text* into overlapping character chunks.

    Args:
        text: The text to split.
        chunk_size: Maximum length of each chunk, in characters.
        overlap: Number of characters shared between consecutive chunks.

    Returns:
        A list of non-empty, whitespace-stripped chunks; empty input yields [].

    Raises:
        ValueError: If overlap >= chunk_size (the window could never advance).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks: List[str] = []
    start, n = 0, len(text)
    while start < n:
        end = min(start + chunk_size, n)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # Bug fix: the original unconditionally set `start = end - overlap`,
        # so once `end` reached the end of the text the window stopped
        # advancing and the loop ran forever (for short texts the
        # `start < 0 -> 0` reset looped the same way). Stop as soon as the
        # final chunk has been emitted.
        if end == n:
            break
        start = end - overlap
    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_faiss_index(paths: List[str]):
    """Embed the text of each PDF in *paths* and build an exact L2 FAISS index.

    Returns:
        (index, texts): the populated faiss.IndexFlatL2 and the list of text
        chunks whose positions match the index's row ids.

    Raises:
        RuntimeError: If any PDF fails to extract, or no usable text is found.
    """
    all_chunks: List[str] = []
    emb_blocks = []
    for path in paths:
        raw = extract_text_from_pdf(path)
        # extract_text_from_pdf signals failure via a prefixed message string.
        if raw.startswith("Error extracting text"):
            raise RuntimeError(raw)
        pieces = chunk_text(raw)
        if not pieces:
            continue
        encoded = embedder.encode(pieces, convert_to_numpy=True, show_progress_bar=False)
        emb_blocks.append(encoded.astype("float32"))
        all_chunks.extend(pieces)

    if not all_chunks:
        raise RuntimeError("β No valid text extracted from PDFs.")

    matrix = np.vstack(emb_blocks).astype("float32")
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)
    return index, all_chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def rebuild_index_from_upload(files):
    """Gradio callback: persist uploaded PDFs and (re)build the global index.

    Args:
        files: Value of the gr.File component — a list of uploaded file
            objects (tempfile-backed objects with a .name path, or
            file-like objects exposing .read()).

    Returns:
        A human-readable status string shown in the Status textbox; never
        raises (errors are reported as strings).
    """
    if not files:
        return "β οΈ Please upload at least one PDF."

    paths = []
    for f in files:
        try:
            # Recent Gradio versions hand over tempfile-backed objects that
            # already exist on disk — reuse that path directly.
            if hasattr(f, "name") and os.path.exists(f.name):
                temp_path = f.name
            else:
                # Otherwise spill the file-like object's bytes to a temp
                # .pdf so PyMuPDF can open it by path.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(f.read())
                    temp_path = tmp.name
            paths.append(temp_path)
        except Exception as e:
            return f"β Error while saving uploaded file: {e}"

    try:
        global INDEX, CORPUS
        INDEX, CORPUS = build_faiss_index(paths)
        # Bug fix: in the original this f-string literal was broken across
        # two physical lines, which is a SyntaxError in Python; it is now a
        # single line.
        return f"β Successfully indexed {len(paths)} PDF(s). You can now ask questions!"
    except Exception as e:
        return f"β Error while building index: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_context(query: str, top_k: int = 4) -> str:
    """Return the *top_k* corpus chunks nearest to *query*, joined by separators.

    Falls back to a warning string when no index has been built yet.
    """
    if INDEX is None:
        return "β οΈ Please upload and index PDFs first."
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _, hit_ids = INDEX.search(query_vec, top_k)
    # FAISS pads missing results with -1; keep only valid corpus positions.
    hits = []
    for idx in hit_ids[0]:
        if 0 <= idx < len(CORPUS):
            hits.append(CORPUS[idx])
    return "\n\n---\n\n".join(hits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instruction prepended to every LLM request: restricts answers to the
# retrieved document context so the model refuses rather than hallucinates.
SYSTEM_PROMPT = (
    "You are a helpful Civil Engineering assistant. "
    "Use ONLY the provided ASTM or uploaded document context to answer. "
    "If the answer isn't in context, say you cannot find it."
)
|
|
|
|
|
def ask_groq(query: str, top_k: int = 4, model: str = "llama-3.3-70b-versatile") -> str:
    """Answer *query* from the indexed PDFs via the Groq chat API.

    Args:
        query: The user's question.
        top_k: How many retrieved passages to include as context.
        model: Groq model identifier to query.

    Returns:
        The model's answer, or a warning/error string when no index exists,
        no context matched, or the API call failed.
    """
    if INDEX is None:
        return "β οΈ Please upload PDFs first."

    context = retrieve_context(query, top_k)
    if not context.strip():
        return "β οΈ No relevant information found in the uploaded PDFs."

    prompt = f"""{SYSTEM_PROMPT}

Context:
{context}

Question:
{query}
"""

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        return response.choices[0].message.content
    except Exception as exc:
        return f"β LLM Error: {exc}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ui_ask(query: str, top_k: int):
    """Thin Gradio wrapper around ask_groq that never lets an exception escape."""
    try:
        answer = ask_groq(query, top_k=top_k)
    except Exception as exc:
        answer = f"β Error: {exc}"
    return answer
|
|
|
|
|
# Gradio UI: an upload/index row on top, then a question box, a Top-K slider,
# and an answer box wired to ui_ask.
with gr.Blocks(title="Civil Engineering RAG (ASTM)") as demo:
    gr.Markdown("## ποΈ Civil Engineering RAG\nUpload ASTM or civil-engineering PDFs, build an index, and ask questions.")

    with gr.Row():
        uploader = gr.File(label="π Upload PDFs", file_count="multiple", file_types=[".pdf"])
        status = gr.Textbox(label="Status", interactive=False)
    # Rebuild the global FAISS index whenever new files are uploaded.
    uploader.upload(rebuild_index_from_upload, uploader, status)

    gr.Markdown("---")
    inp = gr.Textbox(label="Your Question", placeholder="e.g., What is the curing time for concrete as per ASTM?")
    # Number of retrieved passages passed to the LLM as context.
    k = gr.Slider(1, 10, value=4, step=1, label="Top-K passages")
    out = gr.Textbox(label="Answer")
    btn = gr.Button("Ask")
    btn.click(ui_ask, inputs=[inp, k], outputs=[out])
|
|
|
|
|
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|