kajibuku3 / app.py
Bofandra's picture
Update app.py
b7e000c verified
import os
import gradio as gr
import faiss
import pickle
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import pdfplumber
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Use FLAN-T5 instead of DeepSeek
model_id = "google/flan-t5-small"

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the seq2seq weights, place them on the chosen device and switch the
# module to evaluation mode, since the app only runs generation.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
model.eval()
def generate_answer(
    prompt,
    *,
    max_new_tokens=512,
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
):
    """Generate a sampled completion for ``prompt`` with the module-level model.

    The prompt is tokenized with the module-level ``tokenizer``, moved to
    ``device`` and passed to ``model.generate`` with sampling enabled.

    Args:
        prompt: Full text prompt handed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (higher = more varied output).
        top_k: Sample only from the ``top_k`` most likely tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Values above 1 penalize repeated phrases.

    Returns:
        The first generated sequence decoded to text, special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # No gradients are needed for generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # Needed for temperature/top-k/top-p to take effect
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Root folder for indexed PDFs; created at import time if it is missing.
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Sentence embeddings
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Save uploaded PDF and index
def save_pdf(file, title, chunk_size=500):
    """Extract text from an uploaded PDF, embed it and persist a FAISS index.

    Each page's text is cut into ``chunk_size``-character pieces. The pieces
    are embedded with the module-level ``embedder`` and stored in a flat L2
    FAISS index. The index (``index.faiss``) and the chunk/page metadata
    (``chunks.pkl``) are written to ``DATA_DIR/<title>``.

    The target folder is only created after extraction and embedding have
    succeeded, so a failed upload does not leave an empty folder behind that
    would block the title.

    Args:
        file: The upload from the file component: either an object exposing
            the path as ``.name`` or a plain path string.
        title: User-chosen name; surrounding whitespace is stripped and the
            result is used as the folder name under ``DATA_DIR``.
        chunk_size: Number of characters per text chunk.

    Returns:
        A status message for the UI (success, duplicate title, or a
        description of why the upload was rejected).

    Raises:
        ValueError: If the embedder does not return a 2-D array.
    """
    import shutil

    if file is None:
        return "❗ Please choose a PDF file to upload."

    clean_title = (title or "").strip()
    # The title becomes a folder name: reject empty names and anything that
    # could resolve outside DATA_DIR (path separators, "." and "..").
    if (
        not clean_title
        or clean_title in {".", ".."}
        or {"/", "\\"} & set(clean_title)
    ):
        return "❗ Please enter a valid title (not empty, no path separators)."

    folder = os.path.join(DATA_DIR, clean_title)
    if os.path.exists(folder):
        return f"'{title}' already exists. Use a different title."

    # Accept both file-like upload objects and plain path strings.
    pdf_path = getattr(file, "name", file)

    chunks = []
    page_numbers = []
    with pdfplumber.open(pdf_path) as pdf:
        for page_no, page in enumerate(pdf.pages, start=1):
            text = page.extract_text()
            if not text:
                continue
            for start in range(0, len(text), chunk_size):
                chunks.append(text[start:start + chunk_size])
                page_numbers.append(page_no)

    if not chunks:
        return f"❗ No extractable text found in '{title}'. Nothing was indexed."

    embeddings = embedder.encode(chunks)
    if len(embeddings.shape) != 2:
        raise ValueError(f"Expected 2D embeddings, got shape {embeddings.shape}")

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    os.makedirs(folder, exist_ok=True)
    try:
        faiss.write_index(index, os.path.join(folder, "index.faiss"))
        with open(os.path.join(folder, "chunks.pkl"), "wb") as f:
            pickle.dump({"chunks": chunks, "page_numbers": page_numbers}, f)
    except Exception:
        # Do not leave a half-written entry that would show up as a title.
        shutil.rmtree(folder, ignore_errors=True)
        raise

    return f"✅ Saved and indexed '{title}'. You can now ask questions."
def list_titles():
    """Return the names of all sub-directories of ``DATA_DIR``, sorted.

    Each sub-directory corresponds to one indexed PDF. The result is sorted
    because the operating system returns directory entries in arbitrary
    order, which would otherwise make the checkbox list jump around.
    """
    with os.scandir(DATA_DIR) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())
# Ask question using PDF context
def ask_question(message, history, selected_titles, top_k=3):
    """Answer ``message`` once per selected PDF using retrieved context.

    For every title, the stored FAISS index and chunk metadata are loaded
    from ``DATA_DIR/<title>``. The ``top_k`` nearest chunks to the question
    embedding are assembled into a prompt, which is passed to
    ``generate_answer``. A failure for one title is reported inline and does
    not stop the remaining titles.

    Args:
        message: The user's question.
        history: Chat history supplied by the chat interface (unused).
        selected_titles: Titles to query; empty/None yields a hint message.
        top_k: Maximum number of chunks retrieved per PDF.

    Returns:
        A Markdown string with one ``**title**:`` section per selected PDF.
    """
    if not selected_titles:
        return "❗ Please select at least one PDF."

    sections = []
    q_embed = None  # Computed once, lazily, and shared by all titles.
    for title in selected_titles:
        folder = os.path.join(DATA_DIR, title)
        try:
            if q_embed is None:
                q_embed = embedder.encode([message])

            index = faiss.read_index(os.path.join(folder, "index.faiss"))
            # NOTE(security): pickle.load executes arbitrary code from the
            # file. These files are written by save_pdf; never point this at
            # data from an untrusted source.
            with open(os.path.join(folder, "chunks.pkl"), "rb") as f:
                data = pickle.load(f)
            chunks = data["chunks"]
            page_numbers = data["page_numbers"]

            # Never ask for more neighbours than the index holds: FAISS pads
            # missing results with -1, which would silently select the last
            # chunk through negative indexing.
            k = min(top_k, index.ntotal)
            if k < 1:
                raise ValueError("index contains no chunks")
            _, neighbours = index.search(q_embed, k=k)
            hits = [int(i) for i in neighbours[0] if 0 <= i < len(chunks)]

            context = "\n".join(
                f"(Page {page_numbers[i]}): {chunks[i]}" for i in hits
            )
            prompt = (
                "You are a helpful assistant. Provide a thorough and detailed "
                "answer to the following question using only the context.\n"
                "Context:\n"
                f"{context}\n"
                f"Question: {message}\n"
                "Answer in detail:\n"
            )
            response = generate_answer(prompt)
            sections.append(f"**{title}**:\n{response.strip()}\n\n")
        except Exception as e:
            # Per-title boundary: report the problem and keep going.
            sections.append(f"⚠️ Error with {title}: {str(e)}\n\n")
    return "".join(sections).strip()
# Gradio UI
def _refresh_titles():
    """Return an update that replaces the checkbox *choices* with the current titles.

    Returning the bare list would set the component's selected value instead
    of its available choices, so newly indexed PDFs would never become
    selectable.
    """
    return gr.update(choices=list_titles())


with gr.Blocks() as demo:
    # Tab 1: upload a PDF and build its index.
    with gr.Tab("📄 Upload PDF"):
        file = gr.File(label="PDF File", file_types=[".pdf"])
        title = gr.Textbox(label="Title for PDF")
        upload_btn = gr.Button("Upload and Index")
        upload_status = gr.Textbox(label="Status")
        upload_btn.click(fn=save_pdf, inputs=[file, title], outputs=upload_status)

    # Tab 2: pick indexed PDFs and chat against them.
    with gr.Tab("💭 Chat with PDFs"):
        pdf_selector = gr.CheckboxGroup(label="Select PDFs", choices=list_titles())
        refresh_btn = gr.Button("🔄 Refresh PDF List")
        refresh_btn.click(fn=_refresh_titles, outputs=pdf_selector)
        # The current checkbox selection is passed to ask_question as its
        # third argument.
        chat = gr.ChatInterface(fn=ask_question, additional_inputs=[pdf_selector])

demo.launch()