# NOTE: removed Hugging Face Spaces page residue ("Spaces:" / "Build error"
# banner lines) captured with the source — they are not part of app.py.
# app.py
import os
import io
import tempfile
import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
import base64
from typing import Optional

# NOTE(review): the stray "β" in the page title looks like a mis-encoded
# arrow/emoji from the original source — confirm intended character.
st.set_page_config(page_title="PDF β Summary + TTS + Chat + Diagram", layout="wide")

# ---------- Config (models - change if you prefer others) ----------
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"  # Groq Llama model on HF (example)
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"  # example TTS model on HF
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL model on HF

# ---------- Secrets: HF_TOKEN and GROQ_TOKEN ----------
HF_TOKEN = os.environ.get("HF_TOKEN")
GROQ_TOKEN = os.environ.get("GROQ_TOKEN")

# ---------- Init InferenceClient ----------
# client stays None when neither token is set or construction fails;
# the UI section below checks this and halts with an error message.
client: Optional[InferenceClient] = None
client_info = ""
try:
    if GROQ_TOKEN:
        # Prefer Groq provider if GROQ_TOKEN present
        client = InferenceClient(provider="groq", api_key=GROQ_TOKEN)
        client_info = "Using Groq provider (GROQ_TOKEN)"
    elif HF_TOKEN:
        client = InferenceClient(api_key=HF_TOKEN)
        client_info = "Using Hugging Face Inference (HF_TOKEN)"
    else:
        client_info = "NO TOKEN FOUND"
except Exception as e:
    client_info = f"Failed to initialize InferenceClient: {e}"
    client = None
| # ---------- Helpers ---------- | |
def pdf_to_text_bytes(file_bytes: bytes) -> str:
    """Extract all page text from a PDF given as raw bytes.

    Pages that yield no text (e.g. scanned images) are skipped; the
    remaining page texts are joined with blank lines between pages.
    """
    with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
        # Keep only pages whose extraction produced a non-empty string.
        pages = [extracted for page in pdf.pages if (extracted := page.extract_text())]
    return "\n\n".join(pages)
def llama_summarize(text: str) -> str:
    """Summarize *text* into 6-8 bullet points using the configured LLM.

    Tries the OpenAI-style chat-completions API first; on any failure
    falls back to plain text generation.

    Raises:
        RuntimeError: when no InferenceClient is configured, or when both
            the chat and the fallback text-generation paths fail.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    # Create simple system+user prompt
    messages = [
        {"role": "system", "content": "You are a concise summarizer. Provide a short summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in 6-8 concise bullet points:\n\n{text}"},
    ]
    # Try chat completions API path, fallback to text generation if necessary
    try:
        resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
        return resp.choices[0].message["content"]
    except Exception:
        try:
            # Bug fix: huggingface_hub's text_generation() takes the prompt as
            # its first positional argument ("prompt"), not an "inputs=" keyword.
            # The original call always raised TypeError, so the fallback never
            # actually worked.
            resp2 = client.text_generation(
                "Summarize:\n\n" + text, model=LLAMA_MODEL, max_new_tokens=512
            )
            # text_generation() returns a plain string by default; keep a
            # defensive dict check in case a provider hands back raw JSON.
            if isinstance(resp2, dict) and "generated_text" in resp2:
                return resp2["generated_text"]
            return str(resp2)
        except Exception as e:
            # Chain the cause so the original failure stays debuggable.
            raise RuntimeError(f"Summarization failed: {e}") from e
def llama_chat(chat_history: list, user_question: str) -> str:
    """Answer *user_question* given the accumulated *chat_history*.

    *chat_history* is a list of {"role", "content"} message dicts; the new
    question is appended (without mutating the caller's list) and the full
    conversation is sent to the chat-completions endpoint.

    Raises:
        RuntimeError: when no client is configured or the API call fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    conversation = [*chat_history, {"role": "user", "content": user_question}]
    try:
        completion = client.chat.completions.create(model=LLAMA_MODEL, messages=conversation)
        return completion.choices[0].message["content"]
    except Exception as e:
        raise RuntimeError(f"Chat completion failed: {e}")
def tts_synthesize(text: str) -> bytes:
    """Synthesize *text* to speech and return the raw audio bytes.

    The audio container format depends on the TTS model/provider
    (typically WAV or FLAC).

    Raises:
        RuntimeError: when no client is configured or synthesis fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        # Bug fix: huggingface_hub's text_to_speech() takes the text as its
        # first positional argument ("text"), not an "inputs=" keyword — the
        # original keyword call raised TypeError every time.
        return client.text_to_speech(text, model=TTS_MODEL)
    except Exception as e:
        raise RuntimeError(f"TTS failed: {e}") from e
def generate_image(prompt_text: str) -> Image.Image:
    """Generate an image from *prompt_text* with the configured SDXL model.

    Raises:
        RuntimeError: when no client is configured or generation fails.
    """
    if client is None:
        raise RuntimeError("InferenceClient not initialized (missing HF_TOKEN/GROQ_TOKEN).")
    try:
        result = client.text_to_image(prompt_text, model=SDXL_MODEL)
        # Bug fix: huggingface_hub's text_to_image() already returns a
        # PIL.Image — wrapping it in Image.open(io.BytesIO(...)) raised
        # TypeError. Only decode when a provider hands back raw bytes.
        if isinstance(result, Image.Image):
            return result
        return Image.open(io.BytesIO(result))
    except Exception as e:
        raise RuntimeError(f"Image generation failed: {e}") from e
def make_download_link_bytes(data: bytes, filename: str, mime: str) -> str:
    """Build an HTML anchor that downloads *data* as *filename*.

    The payload is embedded as a base64 ``data:`` URI so Streamlit can render
    the link via ``st.markdown(..., unsafe_allow_html=True)``.
    """
    b64 = base64.b64encode(data).decode()
    # Bug fix: the anchor previously hard-coded a placeholder instead of
    # interpolating the requested filename, so every link downloaded under
    # the wrong name.
    return f'<a href="data:{mime};base64,{b64}" download="{filename}">Download {filename}</a>'
# ---------- UI ----------
st.title("PDF β Summary + TTS + Chat + Diagram (Groq/HF)")
st.sidebar.markdown("### Runtime info")
st.sidebar.write(client_info)
st.sidebar.markdown("**Required env vars**: `HF_TOKEN` and/or `GROQ_TOKEN`. Prefer `GROQ_TOKEN` for Groq provider.")
# Without a client nothing below can work — surface the error and halt the
# script run entirely.
if client is None:
    st.error("Inference client not initialized. Set HF_TOKEN or GROQ_TOKEN as environment variables in your Space.")
    st.stop()
uploaded = st.file_uploader("Upload a PDF to analyze", type=["pdf"])
if uploaded:
    file_bytes = uploaded.read()
    with st.spinner("Extracting text from PDF..."):
        try:
            text = pdf_to_text_bytes(file_bytes)
        except Exception as e:
            st.error(f"Failed to extract text from PDF: {e}")
            # Keep text defined so the rest of the page can still render.
            text = ""
    st.subheader("Document preview (first 2000 chars)")
    st.text_area("", value=(text[:2000] + ("..." if len(text) > 2000 else "")), height=220)
    col1, col2 = st.columns(2)
    with col1:
        # --- Summary + TTS column ---
        if st.button("Create summary"):
            if not text.strip():
                st.error("Document text empty or extraction failed.")
            else:
                with st.spinner("Summarizing with Llama..."):
                    try:
                        summary = llama_summarize(text)
                        # Persist the summary across reruns so the TTS button
                        # below still has it after Streamlit re-executes.
                        st.session_state["summary"] = summary
                        st.subheader("Summary")
                        st.markdown(summary)
                    except Exception as e:
                        st.error(str(e))
        if "summary" in st.session_state:
            summary = st.session_state["summary"]
            if st.button("Synthesize summary to audio"):
                with st.spinner("Generating speech..."):
                    try:
                        wav = tts_synthesize(summary)
                        st.audio(wav)
                        st.markdown(make_download_link_bytes(wav, "summary.wav", "audio/wav"), unsafe_allow_html=True)
                    except Exception as e:
                        st.error(str(e))
    with col2:
        # --- Document Q&A column ---
        st.subheader("Chat with the document")
        if "chat_history" not in st.session_state:
            # Seed the conversation once per session with (a truncated slice
            # of) the document as context for the assistant.
            doc_context = text[:4000] if text else ""
            st.session_state["chat_history"] = [
                {"role":"system","content":"You are an assistant that answers questions based only on the provided document context."},
                {"role":"user","content": f"Document context:\n{doc_context}"}
            ]
            st.session_state["convo_display"] = []
        user_q = st.text_input("Ask a question about the PDF")
        if st.button("Ask question") and user_q.strip():
            with st.spinner("Getting answer from Llama..."):
                try:
                    answer = llama_chat(st.session_state["chat_history"], user_q)
                    # show and store
                    st.session_state["convo_display"].append(("You", user_q))
                    st.session_state["convo_display"].append(("Assistant", answer))
                    st.session_state["chat_history"].append({"role":"user","content":user_q})
                    st.session_state["chat_history"].append({"role":"assistant","content":answer})
                except Exception as e:
                    st.error(str(e))
        # show conversation
        for speaker, textline in st.session_state.get("convo_display", []):
            if speaker == "You":
                st.markdown(f"**You:** {textline}")
            else:
                st.markdown(f"**Assistant:** {textline}")
    # --- Diagram generation section ---
    # NOTE(review): source indentation was lost in extraction; this section is
    # placed inside the `if uploaded:` block (full width, below the two
    # columns). Confirm it was not meant to be available without an upload.
    st.markdown("---")
    st.subheader("Generate diagram/image from prompt (SDXL)")
    diagram_prompt = st.text_input("Describe the diagram or scene to generate")
    if st.button("Generate diagram") and diagram_prompt.strip():
        with st.spinner("Generating image..."):
            try:
                img = generate_image(diagram_prompt)
                st.image(img, use_column_width=True)
                # Re-encode to PNG so the download button serves stable bytes.
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
            except Exception as e:
                st.error(str(e))
# Sidebar footer: show which model IDs this deployment is wired to.
st.sidebar.markdown("---")
st.sidebar.markdown("### Model IDs (change in app.py if you want)")
st.sidebar.write(f"LLM: {LLAMA_MODEL}")
st.sidebar.write(f"TTS: {TTS_MODEL}")
st.sidebar.write(f"Image: {SDXL_MODEL}")