# StudySense / app.py
# Author: npaleti2002 — "Create app.py" (commit 7761e22, verified)
import re
from pathlib import Path
import gradio as gr
import numpy as np
import pdfplumber
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline
# ---------- Models ----------
# Both pipelines are loaded once at import time; the first run downloads the
# model weights, so startup can be slow.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
# ---------- Global state (will be stored in gr.State) ----------
# lecture_chunks, vectorizer, X_matrix will live in state
# ---------- Helpers ----------
def load_text_from_file(file_obj) -> str:
    """Extract and clean text from an uploaded .pdf or .txt file.

    Parameters
    ----------
    file_obj : object or None
        Upload handle exposing a ``.name`` attribute holding the on-disk
        path (as Gradio's File component provides). ``None`` yields "".

    Returns
    -------
    str
        Cleaned text (via ``clean_text``), or "" when ``file_obj`` is None.

    Raises
    ------
    ValueError
        If the file extension is neither .pdf nor .txt.
    """
    if file_obj is None:
        return ""
    path = Path(file_obj.name)
    suffix = path.suffix.lower()
    if suffix == ".pdf":
        texts = []
        # Open via the resolved path; pdfplumber accepts path-likes directly.
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                # extract_text() returns None for image-only pages; coerce to "".
                texts.append(page.extract_text() or "")
        raw_text = "\n".join(texts)
    elif suffix == ".txt":
        # Read from the resolved path instead of file_obj.read(): Gradio's
        # file value is not guaranteed to be an open binary handle.
        raw_text = path.read_bytes().decode("utf-8", errors="ignore")
    else:
        raise ValueError("Only .pdf and .txt files are supported.")
    return clean_text(raw_text)
def clean_text(text: str) -> str:
    """Normalize whitespace: drop carriage returns, collapse blank lines and
    runs of spaces/tabs, and trim leading/trailing whitespace."""
    without_cr = text.replace("\r", " ")
    single_newlines = re.sub(r"\n+", "\n", without_cr)
    collapsed = re.sub(r"[ \t]+", " ", single_newlines)
    return collapsed.strip()
def chunk_text(text: str, chunk_words: int = 350, overlap_words: int = 50):
    """Split *text* into overlapping word-window chunks.

    Parameters
    ----------
    text : str
        Input text; split on whitespace. Empty/whitespace-only text yields [].
    chunk_words : int
        Maximum number of words per chunk.
    overlap_words : int
        Words shared between consecutive chunks; must be < chunk_words.

    Returns
    -------
    list[dict]
        Dicts with keys "chunk_id" ("C1", "C2", ...) and "text".

    Raises
    ------
    ValueError
        If overlap_words >= chunk_words (the window could never advance).
    """
    if overlap_words >= chunk_words:
        raise ValueError("overlap_words must be smaller than chunk_words")
    words = text.split()
    chunks = []
    start = 0
    chunk_id = 1
    while start < len(words):
        end = start + chunk_words
        chunks.append(
            {
                "chunk_id": f"C{chunk_id}",
                "text": " ".join(words[start:end]),
            }
        )
        chunk_id += 1
        if end >= len(words):
            # The last window already covers the tail; stepping back by the
            # overlap would emit a redundant chunk fully contained in it.
            break
        start = end - overlap_words
    return chunks
def build_retriever(chunks):
    """Fit a TF-IDF retriever over the chunk texts.

    Returns a (vectorizer, document_matrix) pair used later for
    cosine-similarity retrieval.
    """
    corpus = [chunk["text"] for chunk in chunks]
    tfidf = TfidfVectorizer(
        min_df=1,
        ngram_range=(1, 2),
        max_features=10000,
    )
    document_matrix = tfidf.fit_transform(corpus)
    return tfidf, document_matrix
def generate_summary(text: str, max_words: int = 300) -> str:
    """Summarize *text* with the BART summarization pipeline.

    The model has a maximum input length, so the text is sliced into
    fixed-size character windows (a rough token proxy), each window is
    summarized, and — only when several windows exist — the partial
    summaries are merged and summarized once more.

    Parameters
    ----------
    text : str
        Cleaned lecture text. Empty input returns a placeholder message.
    max_words : int
        Currently unused length hint; kept for interface stability.

    Returns
    -------
    str
        Final summary text.
    """
    if not text:
        return "No text found in the uploaded file."
    max_chunk_chars = 2500
    windows = [
        text[i : i + max_chunk_chars]
        for i in range(0, len(text), max_chunk_chars)
    ]
    partial_summaries = []
    for w in windows[:3]:  # hard cap, don't explode runtime
        s = summarizer(
            w,
            max_length=180,
            min_length=60,
            do_sample=False,
            truncation=True,
        )[0]["summary_text"]
        partial_summaries.append(s)
    # A single window's summary is already final; running the model again
    # would only re-compress (and degrade) it while doubling runtime.
    if len(partial_summaries) == 1:
        return partial_summaries[0]
    combined = " ".join(partial_summaries)
    final = summarizer(
        combined,
        max_length=220,
        min_length=80,
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]
    return final
def retrieve_chunks(question, chunks, vectorizer, X, top_k: int = 5):
    """Rank lecture chunks by TF-IDF cosine similarity to *question*.

    Returns up to *top_k* dicts with keys "rank", "chunk_id", "text" and
    "similarity" (descending similarity), or [] when the retriever has not
    been built yet.
    """
    if not chunks or vectorizer is None or X is None:
        return []
    question_vec = vectorizer.transform([question])
    scores = cosine_similarity(question_vec, X)[0]
    best_indices = np.argsort(-scores)[:top_k]
    return [
        {
            "rank": position,
            "chunk_id": chunks[i]["chunk_id"],
            "text": chunks[i]["text"],
            "similarity": float(scores[i]),
        }
        for position, i in enumerate(best_indices, start=1)
    ]
def answer_question(question, chunks, vectorizer, X):
    """Answer *question* against the retrieved lecture chunks.

    Returns an (answer_text, sources_text) pair; sources_text lists the
    chunk ids and similarity scores that formed the QA context. Empty
    questions and a missing retriever yield instructional messages.
    """
    if not question.strip():
        return "Please enter a question.", ""
    hits = retrieve_chunks(question, chunks, vectorizer, X, top_k=3)
    if not hits:
        return "Please upload and process a lecture first.", ""
    context_text = "\n\n".join(hit["text"] for hit in hits)
    try:
        result = qa_pipeline(
            {
                "question": question,
                "context": context_text,
            }
        )
        answer = result.get("answer", "").strip()
    except Exception as e:
        # Surface model failures to the UI instead of crashing the app.
        answer = f"Error from QA model: {e}"
    # Build a short “sources” string
    source_info = "; ".join(
        f"{hit['chunk_id']} (sim={hit['similarity']:.3f})" for hit in hits
    )
    return answer, source_info
# ---------- Gradio Callbacks ----------
def process_lecture(file):
    """Full pipeline for a newly uploaded lecture file.

    Steps: read the PDF/TXT, chunk the text, fit the TF-IDF retriever,
    and generate a summary.

    Returns (summary_or_error_message, chunks, vectorizer, X); on any
    failure the retriever slots come back as ([], None, None).
    """
    if file is None:
        return "Please upload a lecture file.", [], None, None
    try:
        text = load_text_from_file(file)
    except Exception as e:
        return f"Error reading file: {e}", [], None, None
    if len(text) < 100:
        return "File text is too short or empty after extraction.", [], None, None
    lecture_chunks = chunk_text(text, chunk_words=350, overlap_words=50)
    tfidf, doc_matrix = build_retriever(lecture_chunks)
    lecture_summary = generate_summary(text)
    return lecture_summary, lecture_chunks, tfidf, doc_matrix
def chat_fn(question, chunks, vectorizer, X):
    """Gradio chat callback: answer a question, then append a sources
    footer when any source chunks were used."""
    answer, sources = answer_question(question, chunks, vectorizer, X)
    if not sources:
        return answer
    return f"{answer}\n\n_Sources: {sources}_"
# ---------- Gradio UI ----------
# Declarative UI layout: upload + process row, summary view, then a chat
# section. Retriever artifacts live in gr.State so they persist per session.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Lecture Summarizer + Chatbot\nUpload a PDF/TXT lecture, get a summary, then ask questions about it.")
    with gr.Row():
        file_input = gr.File(label="Upload lecture (.pdf or .txt)")
        process_btn = gr.Button("Process Lecture")
    summary_box = gr.Textbox(
        label="Lecture Summary",
        lines=12,
        interactive=False,
    )
    # State: saved across chat turns
    chunks_state = gr.State([])
    vectorizer_state = gr.State(None)
    X_state = gr.State(None)
    # Processing fills the summary box and all three retriever state slots.
    process_btn.click(
        fn=process_lecture,
        inputs=[file_input],
        outputs=[summary_box, chunks_state, vectorizer_state, X_state],
    )
    gr.Markdown("## 💬 Chat with the Lecture")
    with gr.Row():
        question_box = gr.Textbox(label="Your Question")
        answer_box = gr.Textbox(label="Answer", lines=6, interactive=False)
    ask_btn = gr.Button("Ask")
    # Each question is answered against the state saved by process_btn.
    ask_btn.click(
        fn=chat_fn,
        inputs=[question_box, chunks_state, vectorizer_state, X_state],
        outputs=[answer_box],
    )
if __name__ == "__main__":
    demo.launch()