File size: 6,185 Bytes
f33a5dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# app.py
import os
import io
import tempfile
import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
import base64

# ---------- Configuration ----------
# Credentials are read from the environment (e.g. HF Space secrets).
HF_TOKEN = os.environ.get("HF_TOKEN")  # required
GROQ_KEY = os.environ.get("GROQ_API_KEY")  # optional: if you want to call Groq directly
# When True, requests are routed through the "groq" inference provider;
# note the client is still authenticated with HF_TOKEN either way.
USE_GROQ_PROVIDER = True  # set False to route to default HF provider

# model IDs (change if you prefer other models)
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"        # Groq Llama model on HF
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"       # a HF-hosted TTS model example
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model

# create Inference client (route via HF token by default)
# NOTE(review): runs at import time; if HF_TOKEN is unset this constructs a
# client with api_key=None and calls will fail later — confirm desired behavior.
if USE_GROQ_PROVIDER:
    client = InferenceClient(provider="groq", api_key=HF_TOKEN)
else:
    client = InferenceClient(api_key=HF_TOKEN)

# ---------- Helpers ----------
def pdf_to_text(uploaded_file) -> str:
    """Extract text from every page of an uploaded PDF.

    Pages that yield no text (e.g. scanned images) are skipped; the
    remaining page texts are joined with blank lines between them.
    """
    with pdfplumber.open(uploaded_file) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
    return "\n\n".join(chunk for chunk in page_texts if chunk)

def llama_summarize(text, max_tokens=512):
    """Summarize *text* into at most 8 bullet points using the Llama model.

    Parameters
    ----------
    text : str
        Full document text to summarize.
    max_tokens : int
        Cap on generated tokens. Bug fix: this parameter was previously
        accepted but never forwarded to the API call; it is now passed
        through as ``max_tokens``.

    Returns
    -------
    str
        The model's summary text (or a stringified response if the
        response shape is unexpected).
    """
    prompt = [
        {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"}
    ]
    # Use chat completion endpoint style
    resp = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt, max_tokens=max_tokens)
    try:
        summary = resp.choices[0].message["content"]
    except (AttributeError, IndexError, KeyError, TypeError):
        # fallback: some response shapes expose a plain `text` field instead
        # of a chat message; as a last resort stringify the whole response.
        summary = resp.choices[0].text if hasattr(resp.choices[0], "text") else str(resp)
    return summary

def llama_chat(chat_history, user_question):
    """Answer *user_question* given the prior *chat_history* messages.

    *chat_history* is a list of ``{"role": ..., "content": ...}`` dicts;
    it is not mutated — the question is appended to a new list.
    """
    conversation = [*chat_history, {"role": "user", "content": user_question}]
    response = client.chat.completions.create(model=LLAMA_MODEL, messages=conversation)
    return response.choices[0].message["content"]

def tts_synthesize(text) -> bytes:
    """Synthesize speech for *text* via the HF Inference API.

    Returns raw audio bytes (WAV).

    Bug fix: ``InferenceClient.text_to_speech`` takes the input string as
    its first positional argument (named ``text``); the previous
    ``inputs=text`` keyword raised ``TypeError: unexpected keyword
    argument 'inputs'``.
    """
    audio_bytes = client.text_to_speech(text, model=TTS_MODEL)
    return audio_bytes

def generate_image(prompt_text) -> Image.Image:
    """Generate an image from *prompt_text* using the SDXL model.

    Bug fix: ``InferenceClient.text_to_image`` returns a
    ``PIL.Image.Image``, not raw bytes, so the original
    ``Image.open(io.BytesIO(result))`` raised ``TypeError``. We now return
    the image directly and only decode via BytesIO if a bytes payload is
    ever returned (older client versions).
    """
    result = client.text_to_image(prompt_text, model=SDXL_MODEL)
    if isinstance(result, Image.Image):
        return result
    return Image.open(io.BytesIO(result))

def audio_download_button(wav_bytes, filename="summary.wav"):
    """Render an HTML link that downloads *wav_bytes* as a WAV file.

    Bug fix: the ``download`` attribute was the hard-coded literal
    ``"(unknown)"`` and the *filename* parameter was ignored; the link now
    downloads under *filename*.
    """
    b64 = base64.b64encode(wav_bytes).decode()
    href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
    # unsafe_allow_html is required for Streamlit to render a raw anchor tag.
    st.markdown(href, unsafe_allow_html=True)

# ---------- Streamlit UI ----------
# Flat script body: Streamlit re-executes this top-to-bottom on every
# interaction, so per-run results that must survive reruns (summary, chat
# history) are kept in st.session_state.
st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
st.title("PDF → Summary + Speech + Chat + Diagram (Groq + HF)")

uploaded = st.file_uploader("Upload PDF", type=["pdf"])
if uploaded:
    # Text is re-extracted on every rerun while a file is uploaded.
    with st.spinner("Extracting text from PDF..."):
        text = pdf_to_text(uploaded)
    st.subheader("Extracted text (preview)")
    # Only the first 1,000 characters are shown as a preview.
    st.text_area("Document text", value=text[:1000], height=200)

    if st.button("Create summary (Groq Llama)"):
        with st.spinner("Summarizing with Groq Llama..."):
            summary = llama_summarize(text)
        st.subheader("Summary")
        st.write(summary)
        # Persist so the TTS section below still sees it on later reruns.
        st.session_state["summary"] = summary

    if "summary" in st.session_state:
        summary = st.session_state["summary"]
        if st.button("Synthesize audio from summary (TTS)"):
            with st.spinner("Creating audio..."):
                try:
                    audio = tts_synthesize(summary)
                    st.audio(audio)
                    audio_download_button(audio)
                except Exception as e:
                    # Best-effort: surface TTS failures in the UI rather
                    # than crashing the whole app run.
                    st.error(f"TTS failed: {e}")

    st.markdown("---")
    st.subheader("Chat with your PDF (ask questions about document)")
    if "chat_history" not in st.session_state:
        # start with system + doc context (shortened)
        # NOTE(review): context is truncated to the first 4,000 chars of the
        # document; answers about later pages may be wrong.
        doc_context = (text[:4000] + "...") if len(text) > 4000 else text
        st.session_state["chat_history"] = [
            {"role":"system","content":"You are a helpful assistant that answers questions based on the provided document."},
            {"role":"user","content": f"Document context:\n{doc_context}"}
        ]

    user_q = st.text_input("Ask a question about the PDF")
    if st.button("Ask") and user_q:
        with st.spinner("Getting answer from Groq Llama..."):
            answer = llama_chat(st.session_state["chat_history"], user_q)
            # "convo" records (speaker, text) pairs for display purposes;
            # NOTE(review): it is appended to but never rendered below.
            st.session_state.setdefault("convo", []).append(("You", user_q))
            st.session_state.setdefault("convo", []).append(("Assistant", answer))
            # append to history for next calls
            st.session_state["chat_history"].append({"role":"user","content":user_q})
            st.session_state["chat_history"].append({"role":"assistant","content":answer})
            st.write(answer)

    st.markdown("---")
    st.subheader("Generate a diagram from your question (SDXL)")
    diagram_prompt = st.text_input("Describe the diagram or scene to generate")
    if st.button("Generate diagram") and diagram_prompt:
        with st.spinner("Generating image (SDXL)..."):
            try:
                img = generate_image(diagram_prompt)
                st.image(img, use_column_width=True)
                # allow download
                # Re-encode the PIL image to PNG bytes for the download button.
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
            except Exception as e:
                st.error(f"Image generation failed: {e}")

# Sidebar: static settings/info panel, rendered regardless of upload state.
st.sidebar.title("Settings")
st.sidebar.write("Models in use:")
st.sidebar.write(f"LLM: {LLAMA_MODEL}")
st.sidebar.write(f"TTS: {TTS_MODEL}")
st.sidebar.write(f"Image: {SDXL_MODEL}")

st.sidebar.markdown("**Notes**\n- Set HF_TOKEN in Space secrets or environment before starting.\n- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly.")