|
|
import os |
|
|
import re |
|
|
import numpy as np |
|
|
import faiss |
|
|
import gradio as gr |
|
|
|
|
|
from pypdf import PdfReader |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from openai import OpenAI |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default TOKENIZERS_PARALLELISM to "false" unless the environment already sets it.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


def _env_str(name: str, default: str = "") -> str:
    """Return the stripped value of environment variable *name*.

    Falls back to *default* when the variable is unset or blank, so an empty
    setting cannot override a usable default.
    """
    return (os.getenv(name) or "").strip() or default


# Together.ai credentials, endpoint and chat model.
TOGETHER_API_KEY = _env_str("TOGETHER_API_KEY")
TOGETHER_BASE_URL = _env_str("TOGETHER_BASE_URL", "https://api.together.xyz/v1")
TOGETHER_MODEL = _env_str("TOGETHER_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1")

# Name of the sentence-embedding model.
EMBED_MODEL_NAME = _env_str("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")

# Number of chunks retrieved per question. A non-integer value still fails
# fast with ValueError; a non-positive one is clamped because asking for
# fewer than one result is meaningless.
TOP_K = max(1, int(_env_str("TOP_K", "4")))
|
|
|
|
|
|
|
|
# Instantiate the sentence-embedding model once, at module import; the same
# instance encodes both the document chunks and the user queries.
embedder = SentenceTransformer(EMBED_MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_text(s: str) -> str:
    """Collapse every run of whitespace in *s* to a single space and trim both ends."""
    collapsed = re.compile(r"\s+").sub(" ", s)
    return collapsed.strip()
|
|
|
|
|
|
|
|
def chunk_text(text: str, chunk_size=900, overlap=150):
    """Split *text* into overlapping, whitespace-normalised chunks.

    Windows of ``chunk_size`` characters are taken every
    ``chunk_size - overlap`` characters, stopping with the window that
    reaches the end of the text. Each window is passed through
    ``clean_text`` and kept only if the cleaned result is longer than 30
    characters.

    Args:
        text: The text to split.
        chunk_size: Window length in characters; must be positive.
        overlap: Number of characters consecutive windows share; must be
            smaller than ``chunk_size``.

    Returns:
        A list of cleaned chunk strings (empty when *text* is empty).

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``. Such settings cannot advance
            through the text and previously caused an endless loop.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    total = len(text)
    stride = chunk_size - overlap  # always >= 1 after the checks above
    windows = []
    for start in range(0, total, stride):
        end = start + chunk_size
        windows.append(text[start:end])
        if end >= total:
            # This window already reaches the end of the text; a further
            # one would only repeat its tail.
            break

    cleaned = (clean_text(window) for window in windows)
    return [chunk for chunk in cleaned if len(chunk) > 30]
|
|
|
|
|
|
|
|
def pdf_to_text(pdf_path: str) -> str:
    """Extract the text of every page of the PDF at *pdf_path*.

    Pages whose extracted text is missing or whitespace-only are skipped;
    the remaining page texts are joined with newlines.
    """
    page_texts = (page.extract_text() or "" for page in PdfReader(pdf_path).pages)
    return "\n".join(content for content in page_texts if content.strip())
|
|
|
|
|
|
|
|
def build_faiss_index(chunks):
    """Embed *chunks* and return a flat inner-product FAISS index over them.

    The embeddings are requested normalised, so inner-product scores from
    this index correspond to cosine similarity.
    """
    embeddings = embedder.encode(
        chunks,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # The index dimensionality is the embedding width (second axis).
    flat_index = faiss.IndexFlatIP(embeddings.shape[1])
    flat_index.add(embeddings.astype(np.float32))
    return flat_index
|
|
|
|
|
|
|
|
def retrieve(query, index, chunks, k=TOP_K):
    """Return the chunks that score highest against *query*.

    Args:
        query: The question text to embed.
        index: The index to search (as built by ``build_faiss_index``).
        chunks: The chunk list the index was built from; result ids are
            used as positions in this list.
        k: How many results to request from the index.

    Returns:
        A list of ``(score, chunk_text)`` tuples in the order the index
        returned them. Result slots whose id is -1 are skipped.
    """
    query_vec = embedder.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    scores, ids = index.search(query_vec.astype(np.float32), k)
    # Only one query was submitted, so row 0 holds all of its results.
    return [
        (float(score), chunks[int(chunk_id)])
        for score, chunk_id in zip(scores[0], ids[0])
        if chunk_id != -1
    ]
|
|
|
|
|
|
|
|
def llm_generate(prompt: str) -> str:
    """Send *prompt* to the configured chat model and return its reply.

    Args:
        prompt: The complete user message to send.

    Returns:
        The stripped completion text. If ``TOGETHER_API_KEY`` is empty, or
        if creating the client or performing the request fails, a
        human-readable error message is returned instead of raising.
    """
    if not TOGETHER_API_KEY:
        return (
            "❌ TOGETHER_API_KEY not found.\n\n"
            "Go to Space → Settings → Variables and secrets → New secret:\n"
            "Name: TOGETHER_API_KEY\n"
            "Value: your Together key\n"
            "Then restart the Space."
        )

    messages = [
        {"role": "system", "content": "You are a helpful assistant. Follow instructions carefully."},
        {"role": "user", "content": prompt},
    ]

    # The result is rendered directly in the UI, so any failure (including a
    # client-construction error) is reported as text rather than raised.
    try:
        client = OpenAI(api_key=TOGETHER_API_KEY, base_url=TOGETHER_BASE_URL)
        resp = client.chat.completions.create(
            model=TOGETHER_MODEL,
            messages=messages,
            temperature=0.2,
            top_p=0.9,
            max_tokens=450,
        )
        return (resp.choices[0].message.content or "").strip()
    except Exception as e:
        return (
            "❌ LLM call failed.\n\n"
            f"Base URL: {TOGETHER_BASE_URL}\n"
            f"Model: {TOGETHER_MODEL}\n"
            f"Error: {type(e).__name__}: {e}"
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def index_pdf(pdf_file):
    """Build the retrieval index for an uploaded PDF.

    Args:
        pdf_file: Path of the uploaded PDF, or ``None`` when no file is
            selected.

    Returns:
        A ``(index, chunks, status)`` tuple. On success ``index`` is the
        built index, ``chunks`` the list of text chunks and ``status`` a
        confirmation message. On any failure the first two items are
        ``None`` and ``status`` explains what went wrong.
    """
    if pdf_file is None:
        return None, None, "Please upload a PDF."

    # Report an unreadable file in the status area, like the other failure
    # paths, instead of letting the exception escape the UI callback.
    try:
        text = pdf_to_text(pdf_file)
    except Exception as e:
        return None, None, f"❌ Could not read the PDF.\n\nError: {type(e).__name__}: {e}"

    if not text.strip():
        return None, None, "Could not extract text. If it’s scanned, you need OCR."

    chunks = chunk_text(text)
    if len(chunks) < 2:
        return None, None, "Not enough text to build RAG index."

    index = build_faiss_index(chunks)
    return index, chunks, f"✅ Indexed {len(chunks)} chunks. Now ask a question."
|
|
|
|
|
|
|
|
def answer_question(index, chunks, question):
    """Answer *question* from the indexed document and list the sources used.

    Args:
        index: The index built for the uploaded PDF, or ``None``.
        chunks: The chunk list belonging to *index*, or ``None``.
        question: The user's question text.

    Returns:
        A Markdown string with the model's answer followed by the retrieved
        source passages, or a short instruction when no document has been
        indexed or the question is blank.
    """
    if index is None or chunks is None:
        return "Upload a PDF first and wait for indexing."
    if not (question and question.strip()):
        return "Type a question."

    hits = retrieve(question, index, chunks, k=TOP_K)

    # Number the passages from 1 so the context reads as [1], [2], ...
    context = "\n\n".join(
        f"[{rank}] {passage}" for rank, (_, passage) in enumerate(hits, start=1)
    )

    prompt = (
        "You are a helpful assistant. Answer using ONLY the context.\n"
        'If the answer is not in the context, say: "I don\'t know from the provided document."\n'
        "\n"
        f"Question: {question}\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        "Answer:"
    )

    ans = llm_generate(prompt)

    # Each source preview shows at most the first 700 characters of the passage.
    sources = "\n\n".join(
        f"**Source {rank} (score={score:.3f})**\n{passage[:700]}..."
        for rank, (score, passage) in enumerate(hits, start=1)
    )

    return f"### Answer\n{ans}\n\n---\n### Retrieved Sources\n{sources}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: upload a PDF to build the index, then ask questions against it.
with gr.Blocks(title="PDF RAG (Together.ai)") as demo:
    gr.Markdown(
        "# 📄 PDF RAG (Together.ai)\n"
        "Upload a PDF, build a FAISS index, and ask questions.\n\n"
        # Two trailing spaces make a Markdown hard line break, so the LLM and
        # Embedder lines render on separate lines.
        f"**LLM:** `{TOGETHER_MODEL}`  \n"
        f"**Embedder:** `{EMBED_MODEL_NAME}`"
    )

    pdf = gr.File(label="Upload PDF", type="filepath")
    status = gr.Markdown()

    # Per-session holders for the built index and its chunk list.
    index_state = gr.State(None)
    chunks_state = gr.State(None)

    # Re-index whenever the uploaded file changes (including when it is cleared).
    pdf.change(fn=index_pdf, inputs=[pdf], outputs=[index_state, chunks_state, status])

    question = gr.Textbox(label="Question", placeholder="e.g., Summarize the document")
    out = gr.Markdown()
    btn = gr.Button("Ask")

    btn.click(fn=answer_question, inputs=[index_state, chunks_state, question], outputs=[out])

demo.launch()
|
|
|