Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,10 +9,10 @@ from typing import List, Optional
|
|
| 9 |
|
| 10 |
# ============ CONFIG =============
|
| 11 |
OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY")
|
| 12 |
-
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "gpt-4o-mini")
|
| 13 |
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")
|
| 14 |
-
HUGGINGFACE_KEY = os.getenv("HUGGINGFACE_API_KEY")
|
| 15 |
-
HF_MERMAID_MODEL = os.getenv("HF_MERMAID_MODEL", "TroyDoesAI/MermaidStable3B")
|
| 16 |
|
| 17 |
# ============ HELPERS ============
|
| 18 |
def clean_text(text: str) -> str:
|
|
@@ -38,59 +38,43 @@ def chunk_text_by_chars(text: str, chunk_size: int = 3000, overlap: int = 200) -
|
|
| 38 |
start = max(end - overlap, end)
|
| 39 |
return chunks
|
| 40 |
|
| 41 |
-
# ---------- OpenRouter chat
|
| 42 |
def openrouter_chat(messages: List[dict], model: str = OPENROUTER_MODEL, max_tokens: int = 800, temperature: float = 0.2) -> str:
|
| 43 |
-
"""
|
| 44 |
-
Send messages (OpenAI-style) to OpenRouter's chat completions endpoint.
|
| 45 |
-
Requires OPENROUTER_API_KEY in ENV.
|
| 46 |
-
"""
|
| 47 |
if not OPENROUTER_KEY:
|
| 48 |
raise RuntimeError("OPENROUTER_API_KEY not set")
|
| 49 |
-
|
| 50 |
url = "https://api.openrouter.ai/v1/chat/completions"
|
| 51 |
headers = {"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"}
|
| 52 |
-
payload = {
|
| 53 |
-
"model": model,
|
| 54 |
-
"messages": messages,
|
| 55 |
-
"max_tokens": max_tokens,
|
| 56 |
-
"temperature": temperature,
|
| 57 |
-
}
|
| 58 |
resp = requests.post(url, json=payload, headers=headers, timeout=60)
|
| 59 |
try:
|
| 60 |
resp.raise_for_status()
|
| 61 |
except Exception as e:
|
| 62 |
raise RuntimeError(f"OpenRouter API error: {resp.status_code} {resp.text}") from e
|
| 63 |
-
|
| 64 |
data = resp.json()
|
| 65 |
-
# robustly extract text
|
| 66 |
text = ""
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# OpenRouter
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
text =
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
return text or ""
|
| 79 |
|
| 80 |
def ask_model_for_summary(text: str) -> str:
|
| 81 |
prompt = f"Summarize the following text clearly and concisely (bullet points, 5-8 bullets max):\n\n{text}"
|
| 82 |
-
messages = [
|
| 83 |
-
{"role": "system", "content": "You are a concise summarizer."},
|
| 84 |
-
{"role": "user", "content": prompt},
|
| 85 |
-
]
|
| 86 |
return openrouter_chat(messages, max_tokens=400)
|
| 87 |
|
| 88 |
def ask_model_question(question: str, context: str) -> str:
|
| 89 |
prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in a concise helpful way."
|
| 90 |
-
messages = [
|
| 91 |
-
{"role": "system", "content": "You are a helpful assistant."},
|
| 92 |
-
{"role": "user", "content": prompt},
|
| 93 |
-
]
|
| 94 |
return openrouter_chat(messages, max_tokens=600)
|
| 95 |
|
| 96 |
# ---------- ElevenLabs TTS ----------
|
|
@@ -107,12 +91,8 @@ def text_to_speech_eleven(text: str, voice_id: str = "pnCWbS8Aqipqqr5wzjuy") ->
|
|
| 107 |
st.warning(f"ElevenLabs TTS failed: {r.status_code} {r.text[:300]}")
|
| 108 |
return None
|
| 109 |
|
| 110 |
-
# ----------
|
| 111 |
def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]:
|
| 112 |
-
"""
|
| 113 |
-
If HUGGINGFACE_KEY is set, call Hugging Face Inference API for model that outputs Mermaid or Mermaid-like code.
|
| 114 |
-
Many community models/Spaces are simple text-output LLMs that can be prompted to return mermaid code.
|
| 115 |
-
"""
|
| 116 |
if not HUGGINGFACE_KEY:
|
| 117 |
return None
|
| 118 |
url = f"https://api-inference.huggingface.co/models/{model}"
|
|
@@ -123,16 +103,13 @@ def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]
|
|
| 123 |
st.warning(f"Hugging Face model call failed: {r.status_code} {r.text[:300]}")
|
| 124 |
return None
|
| 125 |
j = r.json()
|
| 126 |
-
# shape varies by model; try to extract text
|
| 127 |
if isinstance(j, dict) and "error" in j:
|
| 128 |
st.warning(f"Hugging Face error: {j['error']}")
|
| 129 |
return None
|
| 130 |
if isinstance(j, list) and len(j) > 0 and isinstance(j[0], dict) and "generated_text" in j[0]:
|
| 131 |
return j[0]["generated_text"]
|
| 132 |
-
# some models return plain text in str
|
| 133 |
if isinstance(j, str):
|
| 134 |
return j
|
| 135 |
-
# fallback: try to get 'output' key
|
| 136 |
if isinstance(j, dict):
|
| 137 |
for k in ("generated_text", "output", "text"):
|
| 138 |
if k in j:
|
|
@@ -140,47 +117,32 @@ def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]
|
|
| 140 |
return None
|
| 141 |
|
| 142 |
def generate_mermaid_from_summary(summary: str) -> str:
|
| 143 |
-
""
|
| 144 |
-
|
| 145 |
-
We'll create a simple flow: split summary into sentences / bullets and link them sequentially.
|
| 146 |
-
"""
|
| 147 |
-
# first try HF
|
| 148 |
-
prompt = (
|
| 149 |
-
"Given the following concise summary, produce a Mermaid flowchart (use 'graph TD' or 'flowchart TD' syntax). "
|
| 150 |
-
"Output only the Mermaid code block (no extra explanation). Summary:\n\n" + summary
|
| 151 |
-
)
|
| 152 |
hf_output = call_hf_mermaid(prompt)
|
| 153 |
if hf_output:
|
| 154 |
-
# try to extract just the mermaid text
|
| 155 |
-
# if the model wrapped in ```mermaid ... ``` try to strip
|
| 156 |
m = re.search(r"```(?:mermaid)?\n([\s\S]+?)```", hf_output, re.IGNORECASE)
|
| 157 |
if m:
|
| 158 |
return m.group(1).strip()
|
| 159 |
return hf_output.strip()
|
| 160 |
|
| 161 |
-
# fallback
|
| 162 |
-
# split by bullet/newline or sentences
|
| 163 |
lines = re.split(r"\n+|-{1,}\s*|•\s*", summary)
|
| 164 |
nodes = [clean_text(l) for l in lines if clean_text(l)]
|
| 165 |
-
# keep a reasonable number
|
| 166 |
nodes = nodes[:8]
|
| 167 |
if not nodes:
|
| 168 |
nodes = ["Summary empty"]
|
| 169 |
mermaid = "flowchart TD\n"
|
| 170 |
-
# create nodes with safe ids
|
| 171 |
for i, n in enumerate(nodes):
|
| 172 |
-
#
|
| 173 |
-
|
|
|
|
|
|
|
| 174 |
for i in range(len(nodes) - 1):
|
| 175 |
mermaid += f" A{i} --> A{i+1}\n"
|
| 176 |
return mermaid
|
| 177 |
|
| 178 |
-
# ---------- Render mermaid in browser ----------
|
| 179 |
def render_mermaid(mermaid_code: str, height: int = 400):
|
| 180 |
-
"""
|
| 181 |
-
Render Mermaid chart client-side using mermaid.js in an HTML component.
|
| 182 |
-
"""
|
| 183 |
-
# wrap in HTML that loads mermaid CDN
|
| 184 |
html = f"""
|
| 185 |
<div id="mermaid-target">
|
| 186 |
<pre class="mermaid">
|
|
@@ -192,13 +154,13 @@ def render_mermaid(mermaid_code: str, height: int = 400):
|
|
| 192 |
mermaid.initialize({{startOnLoad:true}});
|
| 193 |
</script>
|
| 194 |
"""
|
|
|
|
| 195 |
st.components.v1.html(html, height=height, scrolling=True)
|
| 196 |
|
| 197 |
# ============ STREAMLIT UI ============
|
| 198 |
st.set_page_config(page_title="PDF Q&A + Summary Diagram", layout="wide")
|
| 199 |
st.title("📄 PDF Q&A + Summary Diagram + Audio")
|
| 200 |
|
| 201 |
-
# API status
|
| 202 |
c1, c2, c3 = st.columns(3)
|
| 203 |
with c1:
|
| 204 |
if OPENROUTER_KEY:
|
|
@@ -232,31 +194,26 @@ if uploaded_file:
|
|
| 232 |
if st.button("Summarize and generate diagram"):
|
| 233 |
try:
|
| 234 |
with st.spinner("Summarizing with OpenRouter..."):
|
| 235 |
-
# limit to avoid huge inputs
|
| 236 |
to_sum = raw_text[:15000]
|
| 237 |
summary = ask_model_for_summary(to_sum)
|
| 238 |
st.subheader("📌 Summary")
|
| 239 |
st.write(summary)
|
| 240 |
|
| 241 |
-
# TTS summary
|
| 242 |
audio = text_to_speech_eleven(summary)
|
| 243 |
if audio:
|
| 244 |
st.audio(audio, format="audio/mp3")
|
| 245 |
elif not ELEVEN_API_KEY:
|
| 246 |
st.info("TTS not available (ELEVEN_API_KEY missing).")
|
| 247 |
|
| 248 |
-
# produce mermaid
|
| 249 |
mermaid_code = generate_mermaid_from_summary(summary)
|
| 250 |
st.subheader("🗺️ Summary Diagram (Mermaid)")
|
| 251 |
render_mermaid(mermaid_code, height=480)
|
| 252 |
-
# also show the raw mermaid code for copy/paste
|
| 253 |
st.markdown("**Mermaid code (copy/paste):**")
|
| 254 |
st.code(mermaid_code, language="mermaid")
|
| 255 |
|
| 256 |
except Exception as e:
|
| 257 |
st.error(f"Summarize/diagram generation failed: {e}")
|
| 258 |
|
| 259 |
-
# Q&A box
|
| 260 |
query = st.text_input("Ask a question about the PDF (use Enter):")
|
| 261 |
if query:
|
| 262 |
if not OPENROUTER_KEY:
|
|
@@ -266,7 +223,7 @@ if uploaded_file:
|
|
| 266 |
with st.spinner("Answering via OpenRouter..."):
|
| 267 |
chunks = chunk_text_by_chars(raw_text, chunk_size=3000, overlap=200)
|
| 268 |
answers = []
|
| 269 |
-
for c in chunks[:3]:
|
| 270 |
a = ask_model_question(query, c)
|
| 271 |
if a:
|
| 272 |
answers.append(a)
|
|
|
|
| 9 |
|
| 10 |
# ============ CONFIG =============
# All credentials and model choices come from environment variables.
OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY")  # required: openrouter_chat() raises if missing
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "gpt-4o-mini")  # OpenRouter model slug
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")  # optional: when unset, TTS is skipped with an info message
HUGGINGFACE_KEY = os.getenv("HUGGINGFACE_API_KEY")  # optional: when unset, call_hf_mermaid() returns None
HF_MERMAID_MODEL = os.getenv("HF_MERMAID_MODEL", "TroyDoesAI/MermaidStable3B")  # HF Inference API model id
|
| 16 |
|
| 17 |
# ============ HELPERS ============
|
| 18 |
def clean_text(text: str) -> str:
|
|
|
|
| 38 |
start = max(end - overlap, end)
|
| 39 |
return chunks
|
| 40 |
|
| 41 |
+
# ---------- OpenRouter chat ----------
|
| 42 |
def openrouter_chat(messages: List[dict], model: str = OPENROUTER_MODEL, max_tokens: int = 800, temperature: float = 0.2) -> str:
    """Send OpenAI-style chat messages to OpenRouter and return the reply text.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.
        model: OpenRouter model slug (defaults to the OPENROUTER_MODEL env setting).
        max_tokens: Completion token budget.
        temperature: Sampling temperature.

    Returns:
        The assistant's reply text, or "" when none could be extracted.

    Raises:
        RuntimeError: If OPENROUTER_API_KEY is unset or the HTTP request fails.
    """
    if not OPENROUTER_KEY:
        raise RuntimeError("OPENROUTER_API_KEY not set")
    # Bug fix: OpenRouter's documented endpoint is https://openrouter.ai/api/v1/...;
    # the previous "https://api.openrouter.ai/v1/..." host is not the published API base.
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"}
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature}
    resp = requests.post(url, json=payload, headers=headers, timeout=60)
    try:
        resp.raise_for_status()
    except Exception as e:
        # Surface the server's own error body to the caller for debugging.
        raise RuntimeError(f"OpenRouter API error: {resp.status_code} {resp.text}") from e
    data = resp.json()
    # Response shapes vary slightly between models; extract the text defensively.
    text = ""
    choices = data.get("choices", [])
    if choices:
        c = choices[0]
        if "message" in c and isinstance(c["message"], dict) and "content" in c["message"]:
            content = c["message"]["content"]
            # content may itself be a dict or a plain string; handle both
            if isinstance(content, dict) and "content" in content:
                text = content["content"]
            elif isinstance(content, str):
                text = content
        elif "text" in c:
            # completion-style shape: text sits directly on the choice
            text = c["text"]
    return text or ""
|
| 69 |
|
| 70 |
def ask_model_for_summary(text: str) -> str:
    """Summarize *text* into at most 5-8 bullet points via OpenRouter."""
    system_msg = {"role": "system", "content": "You are a concise summarizer."}
    user_msg = {
        "role": "user",
        "content": f"Summarize the following text clearly and concisely (bullet points, 5-8 bullets max):\n\n{text}",
    }
    return openrouter_chat([system_msg, user_msg], max_tokens=400)
|
| 74 |
|
| 75 |
def ask_model_question(question: str, context: str) -> str:
    """Answer *question* grounded in *context*, via OpenRouter."""
    user_content = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in a concise helpful way."
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_content},
    ]
    return openrouter_chat(conversation, max_tokens=600)
|
| 79 |
|
| 80 |
# ---------- ElevenLabs TTS ----------
|
|
|
|
| 91 |
st.warning(f"ElevenLabs TTS failed: {r.status_code} {r.text[:300]}")
|
| 92 |
return None
|
| 93 |
|
| 94 |
+
# ---------- Hugging Face mermaid (optional) ----------
|
| 95 |
def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
if not HUGGINGFACE_KEY:
|
| 97 |
return None
|
| 98 |
url = f"https://api-inference.huggingface.co/models/{model}"
|
|
|
|
| 103 |
st.warning(f"Hugging Face model call failed: {r.status_code} {r.text[:300]}")
|
| 104 |
return None
|
| 105 |
j = r.json()
|
|
|
|
| 106 |
if isinstance(j, dict) and "error" in j:
|
| 107 |
st.warning(f"Hugging Face error: {j['error']}")
|
| 108 |
return None
|
| 109 |
if isinstance(j, list) and len(j) > 0 and isinstance(j[0], dict) and "generated_text" in j[0]:
|
| 110 |
return j[0]["generated_text"]
|
|
|
|
| 111 |
if isinstance(j, str):
|
| 112 |
return j
|
|
|
|
| 113 |
if isinstance(j, dict):
|
| 114 |
for k in ("generated_text", "output", "text"):
|
| 115 |
if k in j:
|
|
|
|
| 117 |
return None
|
| 118 |
|
| 119 |
def generate_mermaid_from_summary(summary: str) -> str:
    """Turn a text summary into Mermaid flowchart code.

    Tries the Hugging Face mermaid model first; when that is unavailable
    (no key, API error), falls back to a simple sequential flowchart built
    from the summary's bullets/sentences.
    """
    prompt = (
        "Given the following concise summary, produce a Mermaid flowchart (use 'graph TD' or 'flowchart TD' syntax). "
        "Output only the Mermaid code block (no extra explanation). Summary:\n\n" + summary
    )
    hf_output = call_hf_mermaid(prompt)
    if hf_output:
        # Strip a ```mermaid ... ``` fence if the model wrapped its answer in one.
        fenced = re.search(r"```(?:mermaid)?\n([\s\S]+?)```", hf_output, re.IGNORECASE)
        return fenced.group(1).strip() if fenced else hf_output.strip()

    # Fallback: one node per bullet/sentence fragment, linked sequentially.
    pieces = re.split(r"\n+|-{1,}\s*|•\s*", summary)
    cleaned = [clean_text(piece) for piece in pieces if clean_text(piece)]
    nodes = cleaned[:8] or ["Summary empty"]  # cap node count; never emit an empty chart
    mermaid = "flowchart TD\n"
    for idx, label in enumerate(nodes):
        # Double quotes would terminate the Mermaid label early; also cap length at 80.
        safe_label = label.replace('"', "'")[:80]
        # .format keeps backslash-free expressions out of the f-string
        mermaid += ' A{idx}["{text}"]\n'.format(idx=idx, text=safe_label)
    for idx in range(len(nodes) - 1):
        mermaid += f" A{idx} --> A{idx+1}\n"
    return mermaid
|
| 144 |
|
|
|
|
| 145 |
def render_mermaid(mermaid_code: str, height: int = 400):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
html = f"""
|
| 147 |
<div id="mermaid-target">
|
| 148 |
<pre class="mermaid">
|
|
|
|
| 154 |
mermaid.initialize({{startOnLoad:true}});
|
| 155 |
</script>
|
| 156 |
"""
|
| 157 |
+
# render
|
| 158 |
st.components.v1.html(html, height=height, scrolling=True)
|
| 159 |
|
| 160 |
# ============ STREAMLIT UI ============
|
| 161 |
st.set_page_config(page_title="PDF Q&A + Summary Diagram", layout="wide")
|
| 162 |
st.title("📄 PDF Q&A + Summary Diagram + Audio")
|
| 163 |
|
|
|
|
| 164 |
c1, c2, c3 = st.columns(3)
|
| 165 |
with c1:
|
| 166 |
if OPENROUTER_KEY:
|
|
|
|
| 194 |
if st.button("Summarize and generate diagram"):
|
| 195 |
try:
|
| 196 |
with st.spinner("Summarizing with OpenRouter..."):
|
|
|
|
| 197 |
to_sum = raw_text[:15000]
|
| 198 |
summary = ask_model_for_summary(to_sum)
|
| 199 |
st.subheader("📌 Summary")
|
| 200 |
st.write(summary)
|
| 201 |
|
|
|
|
| 202 |
audio = text_to_speech_eleven(summary)
|
| 203 |
if audio:
|
| 204 |
st.audio(audio, format="audio/mp3")
|
| 205 |
elif not ELEVEN_API_KEY:
|
| 206 |
st.info("TTS not available (ELEVEN_API_KEY missing).")
|
| 207 |
|
|
|
|
| 208 |
mermaid_code = generate_mermaid_from_summary(summary)
|
| 209 |
st.subheader("🗺️ Summary Diagram (Mermaid)")
|
| 210 |
render_mermaid(mermaid_code, height=480)
|
|
|
|
| 211 |
st.markdown("**Mermaid code (copy/paste):**")
|
| 212 |
st.code(mermaid_code, language="mermaid")
|
| 213 |
|
| 214 |
except Exception as e:
|
| 215 |
st.error(f"Summarize/diagram generation failed: {e}")
|
| 216 |
|
|
|
|
| 217 |
query = st.text_input("Ask a question about the PDF (use Enter):")
|
| 218 |
if query:
|
| 219 |
if not OPENROUTER_KEY:
|
|
|
|
| 223 |
with st.spinner("Answering via OpenRouter..."):
|
| 224 |
chunks = chunk_text_by_chars(raw_text, chunk_size=3000, overlap=200)
|
| 225 |
answers = []
|
| 226 |
+
for c in chunks[:3]:
|
| 227 |
a = ask_model_question(query, c)
|
| 228 |
if a:
|
| 229 |
answers.append(a)
|