tejovanth commited on
Commit
f57bdb5
·
verified ·
1 Parent(s): f45e94d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -58
app.py CHANGED
@@ -1,93 +1,202 @@
1
  import gradio as gr
2
- import fitz
3
  import torch
4
  from transformers import pipeline
5
- import re, time, io, uuid, os
6
- from PIL import Image
7
  import matplotlib
8
- import matplotlib.pyplot as plt
9
  matplotlib.use('Agg')
 
 
 
10
  from concurrent.futures import ThreadPoolExecutor
 
11
 
12
- device = 0 if torch.cuda.is_available() else -1
13
- summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
14
 
15
- def process_chunk(i, chunk):
16
- if sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
17
- return None
18
- try:
19
- summary = summarizer(chunk, max_length=80, min_length=15, do_sample=False)[0]['summary_text']
20
- return summary
21
- except:
22
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def create_summary_flowchart(summaries):
25
- filtered = [s for s in summaries if s]
 
 
 
 
26
  if not filtered:
27
- return None, None
 
 
 
 
 
 
 
 
28
 
29
- fig_height = max(2, len(filtered) * 0.8 + 1)
 
30
  fig, ax = plt.subplots(figsize=(6, fig_height))
31
  ax.axis('off')
32
 
33
- ypos = list(range(len(filtered) * 2, 0, -2))
34
- boxprops = dict(boxstyle="round,pad=0.5", facecolor="lightblue", edgecolor="black")
 
35
 
36
- for i, (y, summary_text) in enumerate(zip(ypos, filtered)):
37
- if len(summary_text) > 120:
38
- summary_text = summary_text[:120] + "..."
39
- ax.text(0.5, y, summary_text, ha='center', va='center', bbox=boxprops, fontsize=9)
40
- if i < len(filtered) - 1:
41
- ax.annotate('', xy=(0.5, y - 1.2), xytext=(0.5, y - 0.3),
42
- arrowprops=dict(arrowstyle="->", lw=1.5))
43
 
44
- # Save to in-memory buffer (for display)
 
 
 
45
  buf = io.BytesIO()
46
- fig.savefig(buf, format='png', bbox_inches='tight')
47
  plt.close(fig)
48
  buf.seek(0)
49
- image = Image.open(buf)
50
-
51
- # Save to disk (for download)
52
- filename = f"/tmp/flowchart_{uuid.uuid4().hex}.png"
53
- image.save(filename)
54
 
55
- return image, filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- def generate_flowchart(file_bytes):
58
  try:
59
  doc = fitz.open(stream=file_bytes, filetype="pdf")
60
- text = "".join(page.get_text("text") for page in doc)
61
- text = re.sub(r"\$\s*([^$]+)\s*\$", r"\1", text)
62
- text = re.sub(r"\\cap", "intersection", text)
63
- text = re.sub(r"\s+", " ", text).strip()
64
- text = "".join(c for c in text if ord(c) < 128)
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
- return f" Error reading PDF: {str(e)}", None
67
-
68
- chunks = [text[i:i+1500] for i in range(0, min(len(text), 30000), 1500)]
69
-
70
- with ThreadPoolExecutor(max_workers=4) as executor:
71
- summaries = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
72
-
73
- img, filepath = create_summary_flowchart(summaries)
74
- if img is None:
75
- return "❌ No valid summaries to display.", None
76
- return img, filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  demo = gr.Interface(
79
- fn=generate_flowchart,
80
- inputs=gr.File(label="📄 Upload PDF", type="binary"),
81
  outputs=[
82
- gr.Image(label="🔁 Flowchart", type="pil"),
83
- gr.File(label="📥 Download Flowchart (PNG)")
 
84
  ],
85
- title="📘 Summary Flowchart Generator",
86
- description="Uploads a PDF, generates summary chunks, and provides a downloadable visual flowchart."
87
  )
88
 
89
  if __name__ == "__main__":
90
- demo.launch(share=False, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
 
 
1
  import gradio as gr
2
+ import fitz # PyMuPDF
3
  import torch
4
  from transformers import pipeline
5
+ import time, logging, re
 
6
  import matplotlib
 
7
  matplotlib.use('Agg')
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ from PIL import Image
11
  from concurrent.futures import ThreadPoolExecutor
12
+ import numpy as np
13
 
14
# Configure module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Select GPU (device 0) when CUDA is available, otherwise CPU (-1),
# following the transformers-pipeline device convention.
device = 0 if torch.cuda.is_available() else -1
logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")

# Load the summarization model once at import time. Half precision on GPU
# halves memory use; CPU inference needs full float32.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        torch_dtype=torch.float16 if device == 0 else torch.float32
    )
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    # exit() is injected by the `site` module and may be absent (e.g. under
    # `python -S` or some embedded interpreters); raising SystemExit directly
    # is the reliable way to abort.
    raise SystemExit(1)
33
+
34
def visualize_chunk_status(chunk_data):
    """Render a horizontal bar chart showing how each chunk was processed.

    Each entry of chunk_data is a dict with keys 'chunk' (1-based index),
    'status' ('summarized' / 'skipped' / 'error') and optionally 'time'
    (seconds spent). Returns the chart as a PIL Image.
    """
    palette = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}

    labels = []
    bar_colors = []
    durations = []
    for entry in chunk_data:
        labels.append(f"C{entry['chunk']}")
        # Unknown statuses fall back to gray rather than raising.
        bar_colors.append(palette.get(entry['status'], 'gray'))
        durations.append(entry.get('time', 0.1))

    fig, ax = plt.subplots(figsize=(8, 2))
    ax.barh(labels, durations, color=bar_colors, height=0.4)
    ax.set_xlabel("Time (s)")
    ax.set_title("Chunk Status")
    plt.tight_layout(pad=0.5)

    # Serialize the figure to an in-memory PNG and hand it back as a PIL image.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', dpi=100)
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
50
 
51
def create_summary_flowchart(summaries):
    """Build a top-down flowchart image of the key point from each summary.

    summaries is the list of per-chunk markdown strings produced by
    process_chunk; entries marked Skipped or Error are ignored. Returns a
    PIL Image, or None when no usable summaries remain.
    """
    usable = []
    for s in summaries:
        if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s:
            usable.append(s)
    if not usable:
        return None

    # Reduce each summary to its first sentence, capped at 50 characters.
    key_points = []
    for entry in usable:
        body = entry.split("**Chunk")[1].split("\n", 1)[-1].strip()
        lead = re.split(r'(?<=[.!?])\s+', body)[0][:50]
        ellipsis = "..." if len(lead) >= 50 else ""
        key_points.append(lead + ellipsis)

    # One box per key point, stacked vertically; figure grows with the count.
    fig_height = max(1.5, len(key_points) * 0.6)
    fig, ax = plt.subplots(figsize=(6, fig_height))
    ax.axis('off')

    # Box centers from top to bottom, 1.2 data units apart.
    ypos = np.arange(len(key_points) * 1.2, 0, -1.2)
    boxprops = dict(boxstyle="round,pad=0.3", facecolor="lightgreen", edgecolor="black", alpha=0.9)

    final_index = len(key_points) - 1
    for idx, (y, point) in enumerate(zip(ypos, key_points)):
        ax.text(0.5, y, point, ha='center', va='center', fontsize=9, bbox=boxprops)
        if idx < final_index:
            # Downward arrow linking this box to the next one.
            ax.arrow(0.5, y - 0.3, 0, -0.9, head_width=0.02, head_length=0.1, fc='blue', ec='blue')

    # Title sits just above the topmost box.
    ax.text(0.5, ypos[0] + 0.8, "Key Points Summary", ha='center', va='center', fontsize=12, weight='bold')

    plt.tight_layout(pad=0.1)
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
 
 
 
 
93
 
94
def process_chunk(i, chunk):
    """Summarize a single text chunk.

    i is the zero-based chunk index; chunk is its raw text. Returns a pair
    (markdown_result, info) where info records the 1-based chunk number, a
    status of 'skipped' / 'summarized' / 'error', and wall-clock time spent.
    """
    started = time.time()
    label = i + 1
    info = {'chunk': label, 'status': '', 'time': 0}

    # Ratio of non-alphanumeric characters; a mostly-symbolic chunk is
    # likely an equation dump the model can't summarize usefully.
    symbol_ratio = 0.0
    if chunk:
        symbol_ratio = sum(1 for c in chunk if not c.isalnum()) / len(chunk)

    if not chunk.strip() or symbol_ratio > 0.5:
        result = f"**Chunk {label}**: Skipped (empty or equation-heavy)"
        info['status'] = 'skipped'
    else:
        try:
            output = summarizer(chunk, max_length=60, min_length=10, do_sample=False)
            result = f"**Chunk {label}**:\n{output[0]['summary_text']}"
            info['status'] = 'summarized'
        except Exception as e:
            # Report per-chunk failures inline instead of aborting the batch.
            result = f"**Chunk {label}**: Error: {str(e)}"
            info['status'] = 'error'

    info['time'] = time.time() - started
    return result, info
112
+
113
def summarize_file(file_bytes):
    """Summarize an uploaded PDF.

    file_bytes: raw PDF bytes from the Gradio File input.
    Returns (summary_markdown, chunk_status_image, flowchart_image); on
    failure the first element is an error message and both images are None.
    """
    start = time.time()
    summaries = []
    chunk_info = []

    try:
        doc = fitz.open(stream=file_bytes, filetype="pdf")
        text = ""
        for page in doc:
            page_text = page.get_text("text")
            if not page_text.strip():
                continue
            text += page_text
            if len(text) > 30000:  # hard cap to bound summarization time
                text = text[:30000]
                break
        doc.close()

        # Clean-up passes, applied separately so math content survives:
        # 1) strip $...$ delimiters but KEEP the inner text — a combined
        #    alternation would delete the whole span (and \cap inside $...$
        #    would never be reached),
        # 2) spell out \cap,
        # 3) collapse whitespace runs,
        # 4) drop non-ASCII characters the summarizer handles poorly.
        text = re.sub(r"\$\s*([^$]+)\s*\$", r"\1", text)
        text = text.replace("\\cap", "intersection")
        text = re.sub(r"\s+", " ", text)
        text = "".join(c for c in text if ord(c) < 128)[:30000]
    except Exception as e:
        return f"Text extraction failed: {str(e)}", None, None

    if not text.strip():
        return "No text found", None, None

    # Pack whole sentences into chunks of at most ~1000 characters. The
    # split consumed the inter-sentence whitespace, so re-insert one space
    # when appending — otherwise sentences get glued ("end.Next").
    chunks = []
    current_chunk = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        candidate = current_chunk + (" " if current_chunk else "") + sentence
        if len(candidate) <= 1000:
            current_chunk = candidate
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
        if len(chunks) >= 30:
            # Cap reached: clear the pending remainder so the trailing
            # append below cannot create a 31st chunk.
            current_chunk = ""
            break
    if current_chunk:
        chunks.append(current_chunk)

    # Threads suffice here: the transformers pipeline spends its time in
    # native code that releases the GIL.
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))

    for summary, info in results:
        summaries.append(summary)
        chunk_info.append(info)

    final_summary = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
    process_img = visualize_chunk_status(chunk_info)
    flow_img = create_summary_flowchart(summaries)
    return final_summary, process_img, flow_img
165
 
166
# Gradio UI: one PDF upload in; a markdown summary plus two status images out.
_pdf_input = gr.File(label="Upload PDF", type="binary")
_result_outputs = [
    gr.Textbox(label="Summary", lines=15),
    gr.Image(label="Chunk Status", type="pil"),
    gr.Image(label="Key Points Flowchart", type="pil"),
]
demo = gr.Interface(
    fn=summarize_file,
    inputs=_pdf_input,
    outputs=_result_outputs,
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart.",
)
177
 
178
if __name__ == "__main__":
    # Launch locally, falling back to the next port if the first one is
    # unavailable. A loop replaces the previous copy-pasted double launch.
    ports = (7860, 7861)
    for idx, port in enumerate(ports):
        try:
            logger.info(f"Starting Gradio on http://127.0.0.1:{port}")
            demo.launch(
                share=False,
                server_name="127.0.0.1",
                server_port=port,
                debug=False
            )
            break  # launch() returned normally; stop retrying
        except Exception as e:
            logger.error(f"Failed on port {port}: {str(e)}")
            if idx == len(ports) - 1:
                raise  # out of fallback ports — surface the last failure
            logger.info(f"Trying port {ports[idx + 1]}...")
200
 
201
 
202