# pdf-summarize / app.py (Hugging Face Space by ChatBotsTA, commit cf0600b, verified)
# app.py
import os
import io
import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
import base64
from typing import Optional
# ----------------- CONFIG -----------------
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
HF_TOKEN = os.environ.get("HF_TOKEN")
GROQ_TOKEN = os.environ.get("GROQ_TOKEN")


def _build_client() -> Optional["InferenceClient"]:
    """Build the inference client: prefer Groq when its token is set, else HF.

    Returns None when neither token is available or construction fails,
    so callers can degrade gracefully instead of crashing at import time.
    """
    try:
        if GROQ_TOKEN:
            return InferenceClient(provider="groq", api_key=GROQ_TOKEN)
        if HF_TOKEN:
            return InferenceClient(api_key=HF_TOKEN)
    except Exception:
        return None
    return None


# Prefer Groq if token present, otherwise HF token
client: Optional["InferenceClient"] = _build_client()
# ----------------- PAGE STYLE -----------------
st.set_page_config(page_title="PDF Buddy β€” Summarize β€’ Speak β€’ Chat β€’ Draw", layout="wide")

# Small CSS theme for the page; Streamlit needs unsafe_allow_html for raw markup.
_PAGE_CSS = """
<style>
.main > .block-container { padding: 1.5rem 2rem; max-width: 1100px; }
.title { font-size:28px; font-weight:700; color:#0f172a; }
.subtitle { color:#6b7280; margin-bottom:12px; }
.big-btn { font-weight:600; padding:10px 18px; border-radius:10px; }
.small-muted { color:#9ca3af; font-size:12px; }
</style>
"""

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.markdown('<div class="title">πŸ“„ PDF Buddy β€” Summarize β€’ Speak β€’ Chat β€’ Draw</div>', unsafe_allow_html=True)
st.markdown('<div class="subtitle">Upload a PDF, get a concise summary, speak it, ask questions, or generate diagrams from prompts.</div>', unsafe_allow_html=True)
# ----------------- FUNCTIONS -----------------
def pdf_to_text_bytes(file_bytes: bytes) -> tuple[str, int]:
    """Extract text from a PDF given as raw bytes.

    Args:
        file_bytes: The full PDF file contents.

    Returns:
        A (full_text, page_count) tuple; pages are joined with blank lines,
        and image-only pages (where extract_text() returns None) contribute
        empty strings.

    Raises:
        RuntimeError: if pdfplumber cannot parse the bytes.
    """
    text_chunks: list[str] = []
    total = 0  # initialized so the return is well-defined in all paths
    try:
        with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
            total = len(pdf.pages)
            for page in pdf.pages:
                # extract_text() may return None (e.g. scanned/image pages).
                text_chunks.append(page.extract_text() or "")
    except Exception as e:
        # Chain the parser error so the original traceback is preserved.
        raise RuntimeError(f"PDF parsing failed: {e}") from e
    return "\n\n".join(text_chunks), total
def llama_summarize(text: str) -> str:
    """Ask the LLM for a six-bullet summary of `text`.

    Raises RuntimeError when no inference client is configured.
    """
    if client is None:
        raise RuntimeError("LLM client not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    prompt = [
        {"role": "system", "content": "You are a concise summarizer. Give 6 short bullet points."},
        {"role": "user", "content": f"Summarize this document in 6 concise bullet points:\n\n{text}"},
    ]
    completion = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt)
    # NOTE(review): dict-style access assumes the message object is dict-like;
    # newer huggingface_hub versions also expose .content — verify against the
    # installed version.
    return completion.choices[0].message["content"]
def llama_chat(chat_history: list, user_question: str) -> str:
    """Answer `user_question` given the prior `chat_history` (not mutated here).

    Raises RuntimeError when no inference client is configured.
    """
    if client is None:
        raise RuntimeError("LLM client not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    conversation = [*chat_history, {"role": "user", "content": user_question}]
    reply = client.chat.completions.create(model=LLAMA_MODEL, messages=conversation)
    return reply.choices[0].message["content"]
def tts_synthesize(text: str) -> bytes:
    """Synthesize `text` to speech and return the raw audio bytes.

    Raises RuntimeError when no inference client is configured.
    """
    if client is None:
        raise RuntimeError("TTS client not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    # Bug fix: InferenceClient.text_to_speech takes the text as its first
    # positional argument — there is no `inputs=` keyword, so the previous
    # call (`inputs=text`) raised TypeError on every invocation.
    return client.text_to_speech(text, model=TTS_MODEL)
def generate_image(prompt_text: str) -> Image.Image:
    """Generate an image for `prompt_text` using the SDXL model.

    Raises RuntimeError when no inference client is configured.
    """
    if client is None:
        raise RuntimeError("Image generation client not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    # Bug fix: InferenceClient.text_to_image already returns a PIL.Image.Image,
    # not raw bytes — wrapping it in Image.open(io.BytesIO(...)) raised
    # TypeError. Return the image directly.
    return client.text_to_image(prompt_text, model=SDXL_MODEL)
def make_download_link_bytes(data: bytes, filename: str, mime: str) -> str:
    """Return an HTML anchor that downloads `data` as `filename`.

    The payload is embedded in a base64 `data:` URI so no server-side file
    is needed; render the result with st.markdown(..., unsafe_allow_html=True).

    Args:
        data: Raw bytes to offer for download.
        filename: Name the browser should save the file as.
        mime: MIME type for the data URI (e.g. "audio/wav").
    """
    b64 = base64.b64encode(data).decode()
    # Bug fix: the anchor previously hard-coded a "(unknown)" placeholder
    # instead of using the `filename` argument, so every download link was
    # mislabelled and saved under the wrong name.
    return f'<a href="data:{mime};base64,{b64}" download="{filename}">⬇️ Download {filename}</a>'
# ----------------- STATE -----------------
# Seed every session-state key once so later reads never KeyError; values
# persist across Streamlit reruns for the lifetime of the user session.
_STATE_DEFAULTS = {
    "uploaded_name": None,
    "extracted_text": "",
    "summary": "",
    "chat_history": [],
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ----------------- Uploader Column -----------------
# Two equal-width columns: left = upload + preview, right = actions.
col_left, col_right = st.columns([1, 1])
with col_left:
    uploaded = st.file_uploader("Upload PDF (single file)", type=["pdf"], help="Drag & drop or click to choose a PDF.")
    if uploaded is not None:
        # immediate feedback to user
        st.success(f"Uploaded file: **{uploaded.name}** β€” {round(len(uploaded.getvalue())/1024,1)} KB")
        st.session_state.uploaded_name = uploaded.name
        # extract text with progress
        with st.spinner("Extracting text from PDF..."):
            try:
                bytes_in = uploaded.getvalue()
                text, pages = pdf_to_text_bytes(bytes_in)
                st.session_state.extracted_text = text
                st.success(f"Extraction complete β€” {pages} pages processed. Preview shown below.")
            except Exception as e:
                # Clear any stale text so the right-column actions stay disabled.
                st.session_state.extracted_text = ""
                st.error(f"Failed to extract PDF text: {e}")
    # show a preview (or hint); extracted_text lives in session state, so the
    # preview survives reruns even when the uploader widget is empty.
    if st.session_state.extracted_text:
        st.subheader("Document preview (first 3000 chars)")
        st.text_area("", value=(st.session_state.extracted_text[:3000] + ("..." if len(st.session_state.extracted_text) > 3000 else "")), height=240)
    else:
        st.info("No document loaded. Upload a PDF to get started. If your file is large, extraction may take a few seconds.")
with col_right:
    # Controls: disabled until extraction is available
    disabled = not bool(st.session_state.extracted_text)
    st.subheader("Actions")

    # --- Summarize: feeds (truncated) extracted text to the LLM ---
    if st.button("πŸ“ Create summary", key="summarize", disabled=disabled):
        with st.spinner("Creating summary..."):
            try:
                summary = llama_summarize(st.session_state.extracted_text[:30000])  # limit prompt length
                st.session_state.summary = summary
                st.success("Summary created.")
            except Exception as e:
                st.error(f"Summarization failed: {e}")
    if st.session_state.summary:
        st.markdown("**Summary:**")
        st.markdown(st.session_state.summary)

    # --- TTS: only enabled once a summary exists ---
    if st.button("πŸ”Š Synthesize summary to audio", key="tts", disabled=disabled or not st.session_state.summary):
        with st.spinner("Synthesizing audio..."):
            try:
                wav = tts_synthesize(st.session_state.summary)
                st.audio(wav)
                st.markdown(make_download_link_bytes(wav, "summary.wav", "audio/wav"), unsafe_allow_html=True)
            except Exception as e:
                st.error(f"TTS failed: {e}")

    st.markdown("---")
    st.subheader("Chat with document")
    # Seed the chat once per session with a short slice of the document as
    # context; later turns are appended below.
    if "chat_history" not in st.session_state or not st.session_state.chat_history:
        # initialize with document context (short)
        context = st.session_state.extracted_text[:4000] if st.session_state.extracted_text else ""
        st.session_state.chat_history = [
            {"role": "system", "content": "You are a helpful assistant. Answer strictly using the document context."},
            {"role": "user", "content": f"Document context:\n{context}"}
        ]
    user_q = st.text_input("Ask a question about the PDF", key="user_q", disabled=disabled)
    if st.button("❓ Ask", key="ask_btn", disabled=disabled or not user_q):
        with st.spinner("Getting answer..."):
            try:
                # Ask first, then record both turns so a failed call does not
                # leave a user message without an answer in the history.
                ans = llama_chat(st.session_state.chat_history, user_q)
                st.session_state.chat_history.append({"role": "user", "content": user_q})
                st.session_state.chat_history.append({"role": "assistant", "content": ans})
                st.markdown(f"**You:** {user_q}")
                st.markdown(f"**Assistant:** {ans}")
            except Exception as e:
                st.error(f"Chat failed: {e}")

    st.markdown("---")
    st.subheader("Generate diagram from prompt (SDXL)")
    diagram_prompt = st.text_input("Describe diagram or scene", key="diagram_prompt", disabled=disabled)
    if st.button("πŸ–ΌοΈ Generate diagram", key="gen_img", disabled=disabled or not diagram_prompt):
        with st.spinner("Generating image..."):
            try:
                img = generate_image(diagram_prompt)
                st.image(img, use_column_width=True)
                # Re-encode the PIL image as PNG bytes for the download button.
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
            except Exception as e:
                st.error(f"Image generation failed: {e}")
# ----------------- FOOTER / NOTES -----------------
st.markdown("---")

# User-facing operational notes rendered at the bottom of the page.
_FOOTER_NOTES = """
**Notes**
- API keys are read from environment variables (HF_TOKEN and/or GROQ_TOKEN). They are NOT displayed here.
- If nothing happens after upload, try a small PDF (1–2 pages) to test extraction first.
- If you get errors about the LLM/TTS/Image calls, confirm the tokens are set in your Space settings or `.env` (don’t commit `.env` publicly).
"""

st.markdown(_FOOTER_NOTES)