ChatBotsTA commited on
Commit
4e85813
·
verified ·
1 Parent(s): a65d0ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -73
app.py CHANGED
@@ -9,10 +9,10 @@ from typing import List, Optional
9
 
10
  # ============ CONFIG =============
11
  OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY")
12
- OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "gpt-4o-mini") # change if you prefer
13
  ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")
14
- HUGGINGFACE_KEY = os.getenv("HUGGINGFACE_API_KEY") # optional: if set, we'll call a HF mermaid model
15
- HF_MERMAID_MODEL = os.getenv("HF_MERMAID_MODEL", "TroyDoesAI/MermaidStable3B") # example community model
16
 
17
  # ============ HELPERS ============
18
  def clean_text(text: str) -> str:
@@ -38,59 +38,43 @@ def chunk_text_by_chars(text: str, chunk_size: int = 3000, overlap: int = 200) -
38
  start = max(end - overlap, end)
39
  return chunks
40
 
41
- # ---------- OpenRouter chat (replacement for openai.ChatCompletion) ----------
42
  def openrouter_chat(messages: List[dict], model: str = OPENROUTER_MODEL, max_tokens: int = 800, temperature: float = 0.2) -> str:
43
- """
44
- Send messages (OpenAI-style) to OpenRouter's chat completions endpoint.
45
- Requires OPENROUTER_API_KEY in ENV.
46
- """
47
  if not OPENROUTER_KEY:
48
  raise RuntimeError("OPENROUTER_API_KEY not set")
49
-
50
  url = "https://api.openrouter.ai/v1/chat/completions"
51
  headers = {"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"}
52
- payload = {
53
- "model": model,
54
- "messages": messages,
55
- "max_tokens": max_tokens,
56
- "temperature": temperature,
57
- }
58
  resp = requests.post(url, json=payload, headers=headers, timeout=60)
59
  try:
60
  resp.raise_for_status()
61
  except Exception as e:
62
  raise RuntimeError(f"OpenRouter API error: {resp.status_code} {resp.text}") from e
63
-
64
  data = resp.json()
65
- # robustly extract text
66
  text = ""
67
- try:
68
- choices = data.get("choices", [])
69
- if choices:
70
- c = choices[0]
71
- # OpenRouter returns similar shape to OpenAI
72
- if "message" in c and "content" in c["message"]:
73
- text = c["message"]["content"]
74
- elif "text" in c:
75
- text = c["text"]
76
- except Exception:
77
- text = ""
 
 
78
  return text or ""
79
 
80
  def ask_model_for_summary(text: str) -> str:
81
  prompt = f"Summarize the following text clearly and concisely (bullet points, 5-8 bullets max):\n\n{text}"
82
- messages = [
83
- {"role": "system", "content": "You are a concise summarizer."},
84
- {"role": "user", "content": prompt},
85
- ]
86
  return openrouter_chat(messages, max_tokens=400)
87
 
88
  def ask_model_question(question: str, context: str) -> str:
89
  prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in a concise helpful way."
90
- messages = [
91
- {"role": "system", "content": "You are a helpful assistant."},
92
- {"role": "user", "content": prompt},
93
- ]
94
  return openrouter_chat(messages, max_tokens=600)
95
 
96
  # ---------- ElevenLabs TTS ----------
@@ -107,12 +91,8 @@ def text_to_speech_eleven(text: str, voice_id: str = "pnCWbS8Aqipqqr5wzjuy") ->
107
  st.warning(f"ElevenLabs TTS failed: {r.status_code} {r.text[:300]}")
108
  return None
109
 
110
- # ---------- Mermaid generation (Hugging Face model optional) ----------
111
  def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]:
112
- """
113
- If HUGGINGFACE_KEY is set, call Hugging Face Inference API for model that outputs Mermaid or Mermaid-like code.
114
- Many community models/Spaces are simple text-output LLMs that can be prompted to return mermaid code.
115
- """
116
  if not HUGGINGFACE_KEY:
117
  return None
118
  url = f"https://api-inference.huggingface.co/models/{model}"
@@ -123,16 +103,13 @@ def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]
123
  st.warning(f"Hugging Face model call failed: {r.status_code} {r.text[:300]}")
124
  return None
125
  j = r.json()
126
- # shape varies by model; try to extract text
127
  if isinstance(j, dict) and "error" in j:
128
  st.warning(f"Hugging Face error: {j['error']}")
129
  return None
130
  if isinstance(j, list) and len(j) > 0 and isinstance(j[0], dict) and "generated_text" in j[0]:
131
  return j[0]["generated_text"]
132
- # some models return plain text in str
133
  if isinstance(j, str):
134
  return j
135
- # fallback: try to get 'output' key
136
  if isinstance(j, dict):
137
  for k in ("generated_text", "output", "text"):
138
  if k in j:
@@ -140,47 +117,32 @@ def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]
140
  return None
141
 
142
  def generate_mermaid_from_summary(summary: str) -> str:
143
- """
144
- Try HF model first (if key provided). If not available or fails, produce a clean Mermaid flowchart locally.
145
- We'll create a simple flow: split summary into sentences / bullets and link them sequentially.
146
- """
147
- # first try HF
148
- prompt = (
149
- "Given the following concise summary, produce a Mermaid flowchart (use 'graph TD' or 'flowchart TD' syntax). "
150
- "Output only the Mermaid code block (no extra explanation). Summary:\n\n" + summary
151
- )
152
  hf_output = call_hf_mermaid(prompt)
153
  if hf_output:
154
- # try to extract just the mermaid text
155
- # if the model wrapped in ```mermaid ... ``` try to strip
156
  m = re.search(r"```(?:mermaid)?\n([\s\S]+?)```", hf_output, re.IGNORECASE)
157
  if m:
158
  return m.group(1).strip()
159
  return hf_output.strip()
160
 
161
- # fallback local generator
162
- # split by bullet/newline or sentences
163
  lines = re.split(r"\n+|-{1,}\s*|•\s*", summary)
164
  nodes = [clean_text(l) for l in lines if clean_text(l)]
165
- # keep a reasonable number
166
  nodes = nodes[:8]
167
  if not nodes:
168
  nodes = ["Summary empty"]
169
  mermaid = "flowchart TD\n"
170
- # create nodes with safe ids
171
  for i, n in enumerate(nodes):
172
- # short id
173
- mermaid += f' A{i}["{n.replace(\'"\', "\\\'")[:80]}"]\n'
 
 
174
  for i in range(len(nodes) - 1):
175
  mermaid += f" A{i} --> A{i+1}\n"
176
  return mermaid
177
 
178
- # ---------- Render mermaid in browser ----------
179
  def render_mermaid(mermaid_code: str, height: int = 400):
180
- """
181
- Render Mermaid chart client-side using mermaid.js in an HTML component.
182
- """
183
- # wrap in HTML that loads mermaid CDN
184
  html = f"""
185
  <div id="mermaid-target">
186
  <pre class="mermaid">
@@ -192,13 +154,13 @@ def render_mermaid(mermaid_code: str, height: int = 400):
192
  mermaid.initialize({{startOnLoad:true}});
193
  </script>
194
  """
 
195
  st.components.v1.html(html, height=height, scrolling=True)
196
 
197
  # ============ STREAMLIT UI ============
198
  st.set_page_config(page_title="PDF Q&A + Summary Diagram", layout="wide")
199
  st.title("📄 PDF Q&A + Summary Diagram + Audio")
200
 
201
- # API status
202
  c1, c2, c3 = st.columns(3)
203
  with c1:
204
  if OPENROUTER_KEY:
@@ -232,31 +194,26 @@ if uploaded_file:
232
  if st.button("Summarize and generate diagram"):
233
  try:
234
  with st.spinner("Summarizing with OpenRouter..."):
235
- # limit to avoid huge inputs
236
  to_sum = raw_text[:15000]
237
  summary = ask_model_for_summary(to_sum)
238
  st.subheader("📌 Summary")
239
  st.write(summary)
240
 
241
- # TTS summary
242
  audio = text_to_speech_eleven(summary)
243
  if audio:
244
  st.audio(audio, format="audio/mp3")
245
  elif not ELEVEN_API_KEY:
246
  st.info("TTS not available (ELEVEN_API_KEY missing).")
247
 
248
- # produce mermaid
249
  mermaid_code = generate_mermaid_from_summary(summary)
250
  st.subheader("🗺️ Summary Diagram (Mermaid)")
251
  render_mermaid(mermaid_code, height=480)
252
- # also show the raw mermaid code for copy/paste
253
  st.markdown("**Mermaid code (copy/paste):**")
254
  st.code(mermaid_code, language="mermaid")
255
 
256
  except Exception as e:
257
  st.error(f"Summarize/diagram generation failed: {e}")
258
 
259
- # Q&A box
260
  query = st.text_input("Ask a question about the PDF (use Enter):")
261
  if query:
262
  if not OPENROUTER_KEY:
@@ -266,7 +223,7 @@ if uploaded_file:
266
  with st.spinner("Answering via OpenRouter..."):
267
  chunks = chunk_text_by_chars(raw_text, chunk_size=3000, overlap=200)
268
  answers = []
269
- for c in chunks[:3]: # limit to 3 chunks
270
  a = ask_model_question(query, c)
271
  if a:
272
  answers.append(a)
 
9
 
10
# ============ CONFIG =============
# All credentials and model choices are read from environment variables.
OPENROUTER_KEY = os.getenv("OPENROUTER_API_KEY")  # required for summarization / Q&A
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "gpt-4o-mini")  # override to use another OpenRouter model
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY")  # optional: enables ElevenLabs TTS
HUGGINGFACE_KEY = os.getenv("HUGGINGFACE_API_KEY")  # optional: enables the HF mermaid model path
HF_MERMAID_MODEL = os.getenv("HF_MERMAID_MODEL", "TroyDoesAI/MermaidStable3B")  # community mermaid-generation model
16
 
17
  # ============ HELPERS ============
18
  def clean_text(text: str) -> str:
 
38
  start = max(end - overlap, end)
39
  return chunks
40
 
41
+ # ---------- OpenRouter chat ----------
42
def openrouter_chat(messages: List[dict], model: str = OPENROUTER_MODEL, max_tokens: int = 800, temperature: float = 0.2) -> str:
    """Send OpenAI-style chat messages to OpenRouter and return the reply text.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts.
        model: OpenRouter model slug (defaults to ``OPENROUTER_MODEL`` from env).
        max_tokens: completion token cap.
        temperature: sampling temperature.

    Returns:
        The assistant's reply text, or ``""`` when none could be extracted.

    Raises:
        RuntimeError: if ``OPENROUTER_API_KEY`` is unset or the HTTP call fails.
    """
    if not OPENROUTER_KEY:
        raise RuntimeError("OPENROUTER_API_KEY not set")
    # Bug fix: OpenRouter's documented base URL is https://openrouter.ai/api/v1,
    # not https://api.openrouter.ai/v1 — the old URL never reaches the API.
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {"Authorization": f"Bearer {OPENROUTER_KEY}", "Content-Type": "application/json"}
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature}
    resp = requests.post(url, json=payload, headers=headers, timeout=60)
    try:
        resp.raise_for_status()
    except Exception as e:
        raise RuntimeError(f"OpenRouter API error: {resp.status_code} {resp.text}") from e
    data = resp.json()
    # Robustly extract the reply text; the response shape mirrors OpenAI's.
    text = ""
    choices = data.get("choices", [])
    if choices:
        c = choices[0]
        message = c.get("message")
        if isinstance(message, dict) and "content" in message:
            content = message["content"]
            # content is normally a plain string, but tolerate a nested dict shape
            if isinstance(content, dict) and "content" in content:
                text = content["content"]
            elif isinstance(content, str):
                text = content
        elif "text" in c:
            text = c["text"]
    return text or ""
69
 
70
def ask_model_for_summary(text: str) -> str:
    """Return a concise bullet-point summary of *text* via OpenRouter."""
    system_msg = {"role": "system", "content": "You are a concise summarizer."}
    user_msg = {
        "role": "user",
        "content": f"Summarize the following text clearly and concisely (bullet points, 5-8 bullets max):\n\n{text}",
    }
    return openrouter_chat([system_msg, user_msg], max_tokens=400)
74
 
75
def ask_model_question(question: str, context: str) -> str:
    """Answer *question* grounded in *context* using the OpenRouter chat model."""
    user_prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer in a concise helpful way."
    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]
    return openrouter_chat(chat, max_tokens=600)
79
 
80
  # ---------- ElevenLabs TTS ----------
 
91
  st.warning(f"ElevenLabs TTS failed: {r.status_code} {r.text[:300]}")
92
  return None
93
 
94
+ # ---------- Hugging Face mermaid (optional) ----------
95
  def call_hf_mermaid(prompt: str, model: str = HF_MERMAID_MODEL) -> Optional[str]:
 
 
 
 
96
  if not HUGGINGFACE_KEY:
97
  return None
98
  url = f"https://api-inference.huggingface.co/models/{model}"
 
103
  st.warning(f"Hugging Face model call failed: {r.status_code} {r.text[:300]}")
104
  return None
105
  j = r.json()
 
106
  if isinstance(j, dict) and "error" in j:
107
  st.warning(f"Hugging Face error: {j['error']}")
108
  return None
109
  if isinstance(j, list) and len(j) > 0 and isinstance(j[0], dict) and "generated_text" in j[0]:
110
  return j[0]["generated_text"]
 
111
  if isinstance(j, str):
112
  return j
 
113
  if isinstance(j, dict):
114
  for k in ("generated_text", "output", "text"):
115
  if k in j:
 
117
  return None
118
 
119
def generate_mermaid_from_summary(summary: str) -> str:
    """Turn a text summary into Mermaid flowchart code.

    Tries the Hugging Face mermaid model first (if HUGGINGFACE_KEY is set);
    otherwise builds a simple sequential ``flowchart TD`` locally by splitting
    the summary into bullet/sentence nodes and chaining them.

    Args:
        summary: the summary text to diagram.

    Returns:
        Mermaid source code (never empty; falls back to a placeholder node).
    """
    prompt = ("Given the following concise summary, produce a Mermaid flowchart (use 'graph TD' or 'flowchart TD' syntax). "
              "Output only the Mermaid code block (no extra explanation). Summary:\n\n" + summary)
    hf_output = call_hf_mermaid(prompt)
    if hf_output:
        # Strip a ```mermaid ... ``` fence if the model wrapped its answer in one.
        m = re.search(r"```(?:mermaid)?\n([\s\S]+?)```", hf_output, re.IGNORECASE)
        if m:
            return m.group(1).strip()
        return hf_output.strip()

    # fallback: create simple sequential flowchart
    lines = re.split(r"\n+|-{1,}\s*|•\s*", summary)
    # Clean each fragment once, then drop empties (avoids calling clean_text twice per line).
    cleaned = [clean_text(l) for l in lines]
    nodes = [c for c in cleaned if c]
    nodes = nodes[:8]  # keep the diagram readable
    if not nodes:
        nodes = ["Summary empty"]
    mermaid = "flowchart TD\n"
    for i, n in enumerate(nodes):
        # Sanitize node text: double quotes would break the mermaid "..." label.
        node_text = n.replace('"', "'")[:80]
        mermaid += f' A{i}["{node_text}"]\n'
    for i in range(len(nodes) - 1):
        mermaid += f" A{i} --> A{i+1}\n"
    return mermaid
144
 
 
145
  def render_mermaid(mermaid_code: str, height: int = 400):
 
 
 
 
146
  html = f"""
147
  <div id="mermaid-target">
148
  <pre class="mermaid">
 
154
  mermaid.initialize({{startOnLoad:true}});
155
  </script>
156
  """
157
+ # render
158
  st.components.v1.html(html, height=height, scrolling=True)
159
 
160
  # ============ STREAMLIT UI ============
161
  st.set_page_config(page_title="PDF Q&A + Summary Diagram", layout="wide")
162
  st.title("📄 PDF Q&A + Summary Diagram + Audio")
163
 
 
164
  c1, c2, c3 = st.columns(3)
165
  with c1:
166
  if OPENROUTER_KEY:
 
194
  if st.button("Summarize and generate diagram"):
195
  try:
196
  with st.spinner("Summarizing with OpenRouter..."):
 
197
  to_sum = raw_text[:15000]
198
  summary = ask_model_for_summary(to_sum)
199
  st.subheader("📌 Summary")
200
  st.write(summary)
201
 
 
202
  audio = text_to_speech_eleven(summary)
203
  if audio:
204
  st.audio(audio, format="audio/mp3")
205
  elif not ELEVEN_API_KEY:
206
  st.info("TTS not available (ELEVEN_API_KEY missing).")
207
 
 
208
  mermaid_code = generate_mermaid_from_summary(summary)
209
  st.subheader("🗺️ Summary Diagram (Mermaid)")
210
  render_mermaid(mermaid_code, height=480)
 
211
  st.markdown("**Mermaid code (copy/paste):**")
212
  st.code(mermaid_code, language="mermaid")
213
 
214
  except Exception as e:
215
  st.error(f"Summarize/diagram generation failed: {e}")
216
 
 
217
  query = st.text_input("Ask a question about the PDF (use Enter):")
218
  if query:
219
  if not OPENROUTER_KEY:
 
223
  with st.spinner("Answering via OpenRouter..."):
224
  chunks = chunk_text_by_chars(raw_text, chunk_size=3000, overlap=200)
225
  answers = []
226
+ for c in chunks[:3]:
227
  a = ask_model_question(query, c)
228
  if a:
229
  answers.append(a)