tejovanth committed on
Commit
25abcb8
·
verified ·
1 Parent(s): 5934ae3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -156
app.py CHANGED
@@ -1,187 +1,91 @@
1
  import gradio as gr
2
- import fitz # PyMuPDF
3
  import torch
4
  from transformers import pipeline
5
- import time, logging, re
6
- import matplotlib
7
- matplotlib.use('Agg')
8
- import matplotlib.pyplot as plt
9
- import io
10
  from PIL import Image
 
 
11
  from concurrent.futures import ThreadPoolExecutor
12
- import numpy as np
13
 
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger(__name__)
16
-
17
- # Set device and optimize for speed
18
  device = 0 if torch.cuda.is_available() else -1
19
- logger.info(f"πŸ”§ Using {'GPU' if device == 0 else 'CPU'}")
20
-
21
- # Load a lighter model for faster inference
22
- try:
23
- summarizer = pipeline(
24
- "summarization",
25
- model="t5-small", # Lightweight model
26
- device=device,
27
- framework="pt",
28
- torch_dtype=torch.float16 if device == 0 else torch.float32 # Half-precision on GPU
29
- )
30
- except Exception as e:
31
- logger.error(f"❌ Model loading failed: {str(e)}")
32
- exit(1)
33
-
34
- def visualize_chunk_status(chunk_data):
35
- status_colors = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
36
- labels = [f"C{i['chunk']}" for i in chunk_data]
37
- colors = [status_colors.get(i['status'], 'gray') for i in chunk_data]
38
- times = [i.get('time', 0.1) for i in chunk_data]
39
-
40
- fig, ax = plt.subplots(figsize=(8, 2)) # Smaller figure size
41
- ax.barh(labels, times, color=colors, height=0.4) # Reduced bar height
42
- ax.set_xlabel("Time (s)")
43
- ax.set_title("Chunk Status")
44
- plt.tight_layout(pad=0.5) # Minimal padding
45
- buf = io.BytesIO()
46
- plt.savefig(buf, format='png', dpi=100) # Lower DPI for speed
47
- plt.close(fig)
48
- buf.seek(0)
49
- return Image.open(buf)
50
 
51
  def create_summary_flowchart(summaries):
52
- filtered = [s for s in summaries if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s]
53
  if not filtered:
54
- return None
55
 
56
- fig_height = max(1.5, len(filtered) * 0.5) # Reduced height scaling
57
- fig, ax = plt.subplots(figsize=(5, fig_height))
58
  ax.axis('off')
59
 
60
- ypos = np.arange(len(filtered) * 1.5, 0, -1.5) # Tighter spacing
61
- for i, (y, summary) in enumerate(zip(ypos, filtered)):
62
- summary_text = summary.split("**Chunk")[1].split("\n", 1)[-1].strip()[:100] + ("..." if len(summary.split("**Chunk")[1].split("\n", 1)[-1].strip()) > 100 else "")
63
- ax.text(0.5, y, summary_text, ha='center', va='center', fontsize=8, bbox=dict(facecolor="lightblue", edgecolor="black", pad=0.2))
 
 
 
64
  if i < len(filtered) - 1:
65
- ax.arrow(0.5, y - 0.2, 0, -1.1, head_width=0.02, head_length=0.1, fc='black', ec='black')
 
66
 
67
- plt.tight_layout(pad=0.1)
68
  buf = io.BytesIO()
69
- fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
70
  plt.close(fig)
71
  buf.seek(0)
72
- return Image.open(buf)
73
 
74
- def process_chunk(i, chunk):
75
- chunk_result = {'chunk': i + 1, 'status': '', 'time': 0}
76
- start_time = time.time()
77
-
78
- if not chunk.strip() or sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
79
- result = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
80
- chunk_result['status'] = 'skipped'
81
- else:
82
- try:
83
- summary = summarizer(chunk, max_length=60, min_length=10, do_sample=False)[0]['summary_text']
84
- result = f"**Chunk {i+1}**:\n{summary}"
85
- chunk_result['status'] = 'summarized'
86
- except Exception as e:
87
- result = f"**Chunk {i+1}**: Error: {str(e)}"
88
- chunk_result['status'] = 'error'
89
-
90
- chunk_result['time'] = time.time() - start_time
91
- return result, chunk_result
92
-
93
- def summarize_file(file_bytes):
94
- start = time.time()
95
- summaries = []
96
- chunk_info = []
97
-
98
- # Stream text extraction
99
  try:
100
  doc = fitz.open(stream=file_bytes, filetype="pdf")
101
- text = ""
102
- for page in doc:
103
- page_text = page.get_text("text")
104
- if not page_text.strip():
105
- continue
106
- text += page_text
107
- if len(text) > 30000: # Early cutoff
108
- text = text[:30000]
109
- break
110
- doc.close()
111
-
112
- # Fast text cleaning
113
- text = re.sub(r"\$\s*[^$]+\s*\$|\\cap|\s+", lambda m: "intersection" if m.group(0) == "\\cap" else " ", text)
114
- text = "".join(c for c in text if ord(c) < 128)[:30000]
115
  except Exception as e:
116
- return f"Text extraction failed: {str(e)}", None, None
117
-
118
- if not text.strip():
119
- return "No text found", None, None
120
-
121
- # Smaller, sentence-aware chunks
122
- chunks = []
123
- current_chunk = ""
124
- for sentence in re.split(r'(?<=[.!?])\s+', text):
125
- if len(current_chunk) + len(sentence) <= 1000:
126
- current_chunk += sentence
127
- else:
128
- if current_chunk:
129
- chunks.append(current_chunk)
130
- current_chunk = sentence
131
- if len(chunks) >= 30: # Limit chunks
132
- break
133
- if current_chunk:
134
- chunks.append(current_chunk)
135
-
136
- # Dynamic worker count based on CPU/GPU
137
- max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
138
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
139
- results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
140
-
141
- for summary, info in results:
142
- summaries.append(summary)
143
- chunk_info.append(info)
144
-
145
- final_summary = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
146
- process_img = visualize_chunk_status(chunk_info)
147
- flow_img = create_summary_flowchart(summaries)
148
- return final_summary, process_img, flow_img
149
 
150
  demo = gr.Interface(
151
- fn=summarize_file,
152
- inputs=gr.File(label="Upload PDF", type="binary"),
153
  outputs=[
154
- gr.Textbox(label="Summary", lines=15),
155
- gr.Image(label="Chunk Status", type="pil"),
156
- gr.Image(label="Flow Summary", type="pil")
157
  ],
158
- title="PDF Summarizer",
159
- description="Summarizes PDFs up to 30,000 characters with visualizations."
160
  )
161
 
162
  if __name__ == "__main__":
163
- try:
164
- logger.info("Starting Gradio on http://127.0.0.1:7860")
165
- demo.launch(
166
- share=False,
167
- server_name="127.0.0.1",
168
- server_port=7860,
169
- debug=False # Disable debug for speed
170
- )
171
- except Exception as e:
172
- logger.error(f"Failed on port 7860: {str(e)}")
173
- logger.info("Trying port 7861...")
174
- try:
175
- demo.launch(
176
- share=False,
177
- server_name="127.0.0.1",
178
- server_port=7861,
179
- debug=False
180
- )
181
- except Exception as e2:
182
- logger.error(f"Failed on port 7861: {str(e2)}")
183
- raise
184
-
185
-
186
 
187
 
 
1
# Select the non-interactive Agg backend BEFORE pyplot is imported.
# The original called matplotlib.use('Agg') without ever importing
# `matplotlib` (NameError at module import time) and only after
# `matplotlib.pyplot` had been imported, which is too late for the
# backend choice to be reliable.
import matplotlib
matplotlib.use('Agg')

# Standard library
import io
import os
import re
import time
import uuid
from concurrent.futures import ThreadPoolExecutor

# Third-party
import fitz  # PyMuPDF
import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import pipeline
11
# Run on GPU 0 when CUDA is available, otherwise fall back to CPU (-1,
# the transformers convention).
device = 0 if torch.cuda.is_available() else -1

# Distilled BART summarizer (CNN/DailyMail fine-tune): a good
# speed/quality trade-off for chunk-level summaries.
summarizer = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=device,
)
def process_chunk(i, chunk):
    """Summarize one text chunk, returning the summary string or None.

    Args:
        i: Chunk index. Unused, but kept so ``executor.map`` can pass
           ``enumerate`` pairs unchanged.
        chunk: Raw text slice to summarize.

    Returns:
        The summary text, or None when the chunk is empty, mostly
        non-alphanumeric (likely equations/markup noise), or the
        summarizer fails on it.
    """
    # Guard first: the density ratio below would divide by zero on an
    # empty chunk.
    if not chunk:
        return None
    # Skip chunks dominated by symbols/whitespace (equation-heavy text
    # produces garbage summaries).
    if sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
        return None
    try:
        result = summarizer(chunk, max_length=80, min_length=15, do_sample=False)
        return result[0]['summary_text']
    except Exception:
        # Best effort: drop a failing chunk instead of aborting the whole
        # document. (The original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit.)
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
def create_summary_flowchart(summaries):
    """Render the chunk summaries as a vertical flowchart image.

    Args:
        summaries: Iterable of summary strings; falsy entries
            (None / "") are skipped.

    Returns:
        Tuple ``(image, filename)`` where ``image`` is a PIL image of
        the flowchart and ``filename`` is a PNG saved to the platform
        temp directory (for Gradio's download component), or
        ``(None, None)`` when there is nothing to draw.
    """
    filtered = [s for s in summaries if s]
    if not filtered:
        return None, None

    # Scale figure height with the number of boxes so text stays legible.
    fig_height = max(2, len(filtered) * 0.8 + 1)
    fig, ax = plt.subplots(figsize=(6, fig_height))
    ax.axis('off')

    # Top-to-bottom y positions, two units apart per box.
    ypos = list(range(len(filtered) * 2, 0, -2))
    boxprops = dict(boxstyle="round,pad=0.5", facecolor="lightblue", edgecolor="black")

    for i, (y, summary_text) in enumerate(zip(ypos, filtered)):
        # Truncate long summaries so boxes do not overflow the figure.
        if len(summary_text) > 120:
            summary_text = summary_text[:120] + "..."
        ax.text(0.5, y, summary_text, ha='center', va='center', bbox=boxprops, fontsize=9)
        if i < len(filtered) - 1:
            # Arrow from the bottom of this box toward the next one.
            ax.annotate('', xy=(0.5, y - 1.2), xytext=(0.5, y - 0.3),
                        arrowprops=dict(arrowstyle="->", lw=1.5))

    # Save to an in-memory buffer (for on-screen display).
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)

    # Save to disk for the download component. Use the platform temp
    # directory instead of the hard-coded "/tmp", which does not exist
    # on Windows hosts.
    import tempfile  # local import keeps this fix self-contained
    filename = os.path.join(tempfile.gettempdir(), f"flowchart_{uuid.uuid4().hex}.png")
    image.save(filename)

    return image, filename
55
+
56
def generate_flowchart(file_bytes):
    """Full pipeline: PDF bytes -> cleaned text -> chunk summaries -> flowchart.

    Args:
        file_bytes: Raw bytes of the uploaded PDF (Gradio ``type="binary"``).

    Returns:
        ``(image, png_path)`` on success, or ``(error_message, None)``
        when the PDF cannot be read or no chunk produced a summary
        (Gradio shows the string in place of the image output).
    """
    try:
        doc = fitz.open(stream=file_bytes, filetype="pdf")
        try:
            text = "".join(page.get_text("text") for page in doc)
        finally:
            # The original never closed the document; always release the
            # PyMuPDF handle, even if text extraction raises.
            doc.close()
        # Strip inline LaTeX delimiters but keep their contents.
        text = re.sub(r"\$\s*([^$]+)\s*\$", r"\1", text)
        text = re.sub(r"\\cap", "intersection", text)
        # Collapse whitespace and drop non-ASCII characters the
        # summarizer handles poorly.
        text = re.sub(r"\s+", " ", text).strip()
        text = "".join(c for c in text if ord(c) < 128)
    except Exception as e:
        return f"❌ Error reading PDF: {str(e)}", None

    # Fixed-size 1500-char chunks, capped at 30,000 chars (≀20 chunks).
    chunks = [text[i:i+1500] for i in range(0, min(len(text), 30000), 1500)]

    # Summarize chunks concurrently; failed/skipped chunks come back as None.
    with ThreadPoolExecutor(max_workers=4) as executor:
        summaries = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))

    img, filepath = create_summary_flowchart(summaries)
    if img is None:
        return "❌ No valid summaries to display.", None
    return img, filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
# Gradio UI: one PDF upload in, flowchart image plus downloadable PNG out.
_pdf_input = gr.File(label="πŸ“„ Upload PDF", type="binary")
_flowchart_outputs = [
    gr.Image(label="πŸ” Flowchart", type="pil"),
    gr.File(label="πŸ“₯ Download Flowchart (PNG)"),
]

demo = gr.Interface(
    fn=generate_flowchart,
    inputs=_pdf_input,
    outputs=_flowchart_outputs,
    title="πŸ“˜ Summary Flowchart Generator",
    description="Uploads a PDF, generates summary chunks, and provides a downloadable visual flowchart.",
)

if __name__ == "__main__":
    # Local-only server on Gradio's default port; no public share link.
    demo.launch(share=False, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91