ChatBotsTA commited on
Commit
f33a5dc
·
verified ·
1 Parent(s): bfc1bdb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import io
4
+ import tempfile
5
+ import streamlit as st
6
+ from huggingface_hub import InferenceClient
7
+ import pdfplumber
8
+ from PIL import Image
9
+ import base64
10
+
11
+ # ---------- Configuration ----------
12
+ HF_TOKEN = os.environ.get("HF_TOKEN") # required
13
+ GROQ_KEY = os.environ.get("GROQ_API_KEY") # optional: if you want to call Groq directly
14
+ USE_GROQ_PROVIDER = True # set False to route to default HF provider
15
+
16
+ # model IDs (change if you prefer other models)
17
+ LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use" # Groq Llama model on HF
18
+ TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits" # a HF-hosted TTS model example
19
+ SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" # SDXL base model
20
+
21
+ # create Inference client (route via HF token by default)
22
+ if USE_GROQ_PROVIDER:
23
+ client = InferenceClient(provider="groq", api_key=HF_TOKEN)
24
+ else:
25
+ client = InferenceClient(api_key=HF_TOKEN)
26
+
27
+ # ---------- Helpers ----------
28
+ def pdf_to_text(uploaded_file) -> str:
29
+ text_chunks = []
30
+ with pdfplumber.open(uploaded_file) as pdf:
31
+ for page in pdf.pages:
32
+ ptext = page.extract_text()
33
+ if ptext:
34
+ text_chunks.append(ptext)
35
+ return "\n\n".join(text_chunks)
36
+
37
+ def llama_summarize(text, max_tokens=512):
38
+ prompt = [
39
+ {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
40
+ {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"}
41
+ ]
42
+ # Use chat completion endpoint style
43
+ resp = client.chat.completions.create(model=LLAMA_MODEL, messages=prompt)
44
+ try:
45
+ summary = resp.choices[0].message["content"]
46
+ except Exception:
47
+ # fallback: try text generation field
48
+ summary = resp.choices[0].text if hasattr(resp.choices[0], "text") else str(resp)
49
+ return summary
50
+
51
+ def llama_chat(chat_history, user_question):
52
+ messages = chat_history + [{"role":"user","content":user_question}]
53
+ resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
54
+ return resp.choices[0].message["content"]
55
+
56
+ def tts_synthesize(text) -> bytes:
57
+ # InferenceClient offers text->audio utilities. This returns raw audio bytes (wav).
58
+ audio_bytes = client.text_to_speech(model=TTS_MODEL, inputs=text)
59
+ return audio_bytes
60
+
61
+ def generate_image(prompt_text) -> Image.Image:
62
+ img_bytes = client.text_to_image(prompt_text, model=SDXL_MODEL)
63
+ return Image.open(io.BytesIO(img_bytes))
64
+
65
+ def audio_download_button(wav_bytes, filename="summary.wav"):
66
+ b64 = base64.b64encode(wav_bytes).decode()
67
+ href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
68
+ st.markdown(href, unsafe_allow_html=True)
69
+
70
+ # ---------- Streamlit UI ----------
71
+ st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
72
+ st.title("PDF → Summary + Speech + Chat + Diagram (Groq + HF)")
73
+
74
+ uploaded = st.file_uploader("Upload PDF", type=["pdf"])
75
+ if uploaded:
76
+ with st.spinner("Extracting text from PDF..."):
77
+ text = pdf_to_text(uploaded)
78
+ st.subheader("Extracted text (preview)")
79
+ st.text_area("Document text", value=text[:1000], height=200)
80
+
81
+ if st.button("Create summary (Groq Llama)"):
82
+ with st.spinner("Summarizing with Groq Llama..."):
83
+ summary = llama_summarize(text)
84
+ st.subheader("Summary")
85
+ st.write(summary)
86
+ st.session_state["summary"] = summary
87
+
88
+ if "summary" in st.session_state:
89
+ summary = st.session_state["summary"]
90
+ if st.button("Synthesize audio from summary (TTS)"):
91
+ with st.spinner("Creating audio..."):
92
+ try:
93
+ audio = tts_synthesize(summary)
94
+ st.audio(audio)
95
+ audio_download_button(audio)
96
+ except Exception as e:
97
+ st.error(f"TTS failed: {e}")
98
+
99
+ st.markdown("---")
100
+ st.subheader("Chat with your PDF (ask questions about document)")
101
+ if "chat_history" not in st.session_state:
102
+ # start with system + doc context (shortened)
103
+ doc_context = (text[:4000] + "...") if len(text) > 4000 else text
104
+ st.session_state["chat_history"] = [
105
+ {"role":"system","content":"You are a helpful assistant that answers questions based on the provided document."},
106
+ {"role":"user","content": f"Document context:\n{doc_context}"}
107
+ ]
108
+
109
+ user_q = st.text_input("Ask a question about the PDF")
110
+ if st.button("Ask") and user_q:
111
+ with st.spinner("Getting answer from Groq Llama..."):
112
+ answer = llama_chat(st.session_state["chat_history"], user_q)
113
+ st.session_state.setdefault("convo", []).append(("You", user_q))
114
+ st.session_state.setdefault("convo", []).append(("Assistant", answer))
115
+ # append to history for next calls
116
+ st.session_state["chat_history"].append({"role":"user","content":user_q})
117
+ st.session_state["chat_history"].append({"role":"assistant","content":answer})
118
+ st.write(answer)
119
+
120
+ st.markdown("---")
121
+ st.subheader("Generate a diagram from your question (SDXL)")
122
+ diagram_prompt = st.text_input("Describe the diagram or scene to generate")
123
+ if st.button("Generate diagram") and diagram_prompt:
124
+ with st.spinner("Generating image (SDXL)..."):
125
+ try:
126
+ img = generate_image(diagram_prompt)
127
+ st.image(img, use_column_width=True)
128
+ # allow download
129
+ buf = io.BytesIO()
130
+ img.save(buf, format="PNG")
131
+ st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png")
132
+ except Exception as e:
133
+ st.error(f"Image generation failed: {e}")
134
+
135
+ st.sidebar.title("Settings")
136
+ st.sidebar.write("Models in use:")
137
+ st.sidebar.write(f"LLM: {LLAMA_MODEL}")
138
+ st.sidebar.write(f"TTS: {TTS_MODEL}")
139
+ st.sidebar.write(f"Image: {SDXL_MODEL}")
140
+
141
+ st.sidebar.markdown("**Notes**\n- Set HF_TOKEN in Space secrets or environment before starting.\n- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly.")