# app.py
import os
import io
import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
import base64
from typing import Optional
st.set_page_config(page_title="PDF β†’ Summary + TTS + Chat + Diagram", layout="wide")
# ---------- Config (models - change if you prefer others) ----------
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use" # Groq Llama model on HF (example)
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits" # example TTS model on HF
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # SDXL model on HF
# ---------- Secrets: HF_TOKEN and GROQ_TOKEN ----------
HF_TOKEN = os.environ.get("HF_TOKEN")
GROQ_TOKEN = os.environ.get("GROQ_TOKEN")
# ---------- Init InferenceClient ----------
client: Optional[InferenceClient] = None
client_info = ""
try:
    if GROQ_TOKEN:
        # Prefer the Groq provider when GROQ_TOKEN is present
        client = InferenceClient(provider="groq", api_key=GROQ_TOKEN)
        client_info = "Using Groq provider (GROQ_TOKEN)"
    elif HF_TOKEN:
        client = InferenceClient(api_key=HF_TOKEN)
        client_info = "Using Hugging Face Inference (HF_TOKEN)"
    else:
        client_info = "NO TOKEN FOUND"
except Exception as e:
    client_info = f"Failed to initialize InferenceClient: {e}"
    client = None
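# Quick sanity check (a sketch for a Python shell, not meant to run inside
# the app; the "ping" prompt is just an example):
#   resp = client.chat.completions.create(
#       model=LLAMA_MODEL,
#       messages=[{"role": "user", "content": "ping"}],
#   )
#   print(resp.choices[0].message.content)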
# ---------- Helpers ----------
def pdf_to_text_bytes(file_bytes: bytes) -> str:
    """Extract the text layer from a PDF given as raw bytes."""
    text_chunks = []
    with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
        for page in pdf.pages:
            ptext = page.extract_text()
            if ptext:
                text_chunks.append(ptext)
    return "\n\n".join(text_chunks)
def llama_summarize(text: str) -> str:
    """Summarize document text with the configured Llama model."""
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    # Simple system + user prompt
    messages = [
        {"role": "system", "content": "You are a concise summarizer. Provide a short summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in 6-8 concise bullet points:\n\n{text}"}
    ]
    # Try the chat-completions API first; fall back to plain text generation
    try:
        resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
        return resp.choices[0].message.content
    except Exception:
        try:
            # Fallback: text_generation takes the prompt as its first
            # positional argument (there is no `inputs=` keyword) and,
            # with default settings, returns a plain string.
            resp2 = client.text_generation("Summarize:\n\n" + text, model=LLAMA_MODEL, max_new_tokens=512)
            return resp2 if isinstance(resp2, str) else str(resp2)
        except Exception as e:
            raise RuntimeError(f"Summarization failed: {e}")
def llama_chat(chat_history: list, user_question: str) -> str:
    """Answer a question using the accumulated chat history as context."""
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    messages = chat_history + [{"role": "user", "content": user_question}]
    try:
        resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
        return resp.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"Chat completion failed: {e}")
def tts_synthesize(text: str) -> bytes:
    """Synthesize speech for `text`; returns raw audio bytes."""
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        # text_to_speech takes the text as its first positional argument
        audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
        return audio_bytes
    except Exception as e:
        raise RuntimeError(f"TTS failed: {e}")
def generate_image(prompt_text: str) -> Image.Image:
    """Generate an image for the prompt. text_to_image already returns a
    PIL Image, so no manual decoding from bytes is needed."""
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        return client.text_to_image(prompt_text, model=SDXL_MODEL)
    except Exception as e:
        raise RuntimeError(f"Image generation failed: {e}")
def make_download_link_bytes(data: bytes, filename: str, mime: str) -> str:
    """Build an HTML download link embedding `data` as a base64 data URI."""
    b64 = base64.b64encode(data).decode()
    href = f'<a href="data:{mime};base64,{b64}" download="{filename}">Download {filename}</a>'
    return href
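# For example, make_download_link_bytes(b"hi", "a.txt", "text/plain")
# returns (base64 of b"hi" is "aGk="):
#   <a href="data:text/plain;base64,aGk=" download="a.txt">Download a.txt</a>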
# ---------- UI ----------
st.title("PDF β†’ Summary + TTS + Chat + Diagram (Groq/HF)")
st.sidebar.markdown("### Runtime info")
st.sidebar.write(client_info)
st.sidebar.markdown("**Required env vars**: `HF_TOKEN` and/or `GROQ_TOKEN`. Prefer `GROQ_TOKEN` for Groq provider.")
if client is None:
    st.error("Inference client not initialized. Set HF_TOKEN or GROQ_TOKEN as environment variables in your Space.")
    st.stop()
uploaded = st.file_uploader("Upload a PDF to analyze", type=["pdf"])
if uploaded:
    file_bytes = uploaded.read()
    with st.spinner("Extracting text from PDF..."):
        try:
            text = pdf_to_text_bytes(file_bytes)
        except Exception as e:
            st.error(f"Failed to extract text from PDF: {e}")
            text = ""
    st.subheader("Document preview (first 2000 chars)")
    st.text_area("Document preview", value=(text[:2000] + ("..." if len(text) > 2000 else "")), height=220, label_visibility="collapsed")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Create summary"):
            if not text.strip():
                st.error("Document text is empty or extraction failed.")
            else:
                with st.spinner("Summarizing with Llama..."):
                    try:
                        summary = llama_summarize(text)
                        st.session_state["summary"] = summary
                        st.subheader("Summary")
                        st.markdown(summary)
                    except Exception as e:
                        st.error(str(e))
        if "summary" in st.session_state:
            summary = st.session_state["summary"]
            if st.button("Synthesize summary to audio"):
                with st.spinner("Generating speech..."):
                    try:
                        wav = tts_synthesize(summary)
                        st.audio(wav)
                        st.markdown(make_download_link_bytes(wav, "summary.wav", "audio/wav"), unsafe_allow_html=True)
                    except Exception as e:
                        st.error(str(e))
    with col2:
        st.subheader("Chat with the document")
        if "chat_history" not in st.session_state:
            doc_context = text[:4000] if text else ""
            st.session_state["chat_history"] = [
                {"role": "system", "content": "You are an assistant that answers questions based only on the provided document context."},
                {"role": "user", "content": f"Document context:\n{doc_context}"}
            ]
            st.session_state["convo_display"] = []
        user_q = st.text_input("Ask a question about the PDF")
        if st.button("Ask question") and user_q.strip():
            with st.spinner("Getting answer from Llama..."):
                try:
                    answer = llama_chat(st.session_state["chat_history"], user_q)
                    # Show and store the exchange
                    st.session_state["convo_display"].append(("You", user_q))
                    st.session_state["convo_display"].append(("Assistant", answer))
                    st.session_state["chat_history"].append({"role": "user", "content": user_q})
                    st.session_state["chat_history"].append({"role": "assistant", "content": answer})
                except Exception as e:
                    st.error(str(e))
        # Render the conversation so far
        for speaker, textline in st.session_state.get("convo_display", []):
            if speaker == "You":
                st.markdown(f"**You:** {textline}")
            else:
                st.markdown(f"**Assistant:** {textline}")
st.markdown("---")
st.subheader("Generate diagram/image from prompt (SDXL)")
diagram_prompt = st.text_input("Describe the diagram or scene to generate")
if st.button("Generate diagram") and diagram_prompt.strip():
with st.spinner("Generating image..."):
try:
img = generate_image(diagram_prompt)
st.image(img, use_column_width=True)
buf = io.BytesIO()
img.save(buf, format="PNG")
st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
except Exception as e:
st.error(str(e))
st.sidebar.markdown("---")
st.sidebar.markdown("### Model IDs (change in app.py if you want)")
st.sidebar.write(f"LLM: {LLAMA_MODEL}")
st.sidebar.write(f"TTS: {TTS_MODEL}")
st.sidebar.write(f"Image: {SDXL_MODEL}")