tejovanth committed on
Commit
5934ae3
·
verified ·
1 Parent(s): 97db6a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -55
app.py CHANGED
@@ -9,17 +9,24 @@ import matplotlib.pyplot as plt
9
  import io
10
  from PIL import Image
11
  from concurrent.futures import ThreadPoolExecutor
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # Set device (CPU or GPU)
17
  device = 0 if torch.cuda.is_available() else -1
18
  logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")
19
 
20
- # Load model
21
  try:
22
- summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
 
 
 
 
 
 
23
  except Exception as e:
24
  logger.error(f"❌ Model loading failed: {str(e)}")
25
  exit(1)
@@ -30,46 +37,36 @@ def visualize_chunk_status(chunk_data):
30
  colors = [status_colors.get(i['status'], 'gray') for i in chunk_data]
31
  times = [i.get('time', 0.1) for i in chunk_data]
32
 
33
- fig, ax = plt.subplots(figsize=(10, 2.5))
34
- ax.barh(labels, times, color=colors)
35
  ax.set_xlabel("Time (s)")
36
- ax.set_title("📊 Chunk Processing Status")
37
- plt.tight_layout()
38
  buf = io.BytesIO()
39
- plt.savefig(buf, format='png')
40
  plt.close(fig)
41
  buf.seek(0)
42
  return Image.open(buf)
43
 
44
  def create_summary_flowchart(summaries):
45
- filtered = [
46
- s for s in summaries
47
- if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s
48
- ]
49
  if not filtered:
50
  return None
51
 
52
- fig_height = max(2, len(filtered) * 0.8 + 1)
53
- fig, ax = plt.subplots(figsize=(6, fig_height))
54
  ax.axis('off')
55
 
56
- ypos = list(range(len(filtered) * 2, 0, -2))
57
- boxprops = dict(boxstyle="round,pad=0.5", facecolor="lightblue", edgecolor="black")
58
-
59
  for i, (y, summary) in enumerate(zip(ypos, filtered)):
60
- summary_text = summary.split("**Chunk")[1]
61
- summary_text = summary_text.replace("**:", ":").split("\n", 1)[-1].strip()
62
- if len(summary_text) > 120:
63
- summary_text = summary_text[:120] + "..."
64
- ax.text(0.5, y, summary_text, ha='center', va='center', bbox=boxprops, fontsize=9)
65
-
66
  if i < len(filtered) - 1:
67
- ax.annotate('', xy=(0.5, y - 1.2), xytext=(0.5, y - 0.3),
68
- arrowprops=dict(arrowstyle="->", lw=1.5))
69
 
70
- plt.tight_layout()
71
  buf = io.BytesIO()
72
- fig.savefig(buf, format='png', bbox_inches='tight')
73
  plt.close(fig)
74
  buf.seek(0)
75
  return Image.open(buf)
@@ -78,16 +75,16 @@ def process_chunk(i, chunk):
78
  chunk_result = {'chunk': i + 1, 'status': '', 'time': 0}
79
  start_time = time.time()
80
 
81
- if sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
82
- result = f"**Chunk {i+1}**: Skipped (equation-heavy)"
83
  chunk_result['status'] = 'skipped'
84
  else:
85
  try:
86
- summary = summarizer(chunk, max_length=80, min_length=15, do_sample=False)[0]['summary_text']
87
  result = f"**Chunk {i+1}**:\n{summary}"
88
  chunk_result['status'] = 'summarized'
89
  except Exception as e:
90
- result = f"**Chunk {i+1}**: Error: {str(e)}"
91
  chunk_result['status'] = 'error'
92
 
93
  chunk_result['time'] = time.time() - start_time
@@ -98,66 +95,91 @@ def summarize_file(file_bytes):
98
  summaries = []
99
  chunk_info = []
100
 
 
101
  try:
102
  doc = fitz.open(stream=file_bytes, filetype="pdf")
103
- text = "".join(page.get_text("text") for page in doc)
104
- text = re.sub(r"\$\s*([^$]+)\s*\$", r"\1", text)
105
- text = re.sub(r"\\cap", "intersection", text)
106
- text = re.sub(r"\s+", " ", text).strip()
107
- text = "".join(c for c in text if ord(c) < 128)
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- return f"Text extraction failed: {str(e)}", None, None
110
 
111
  if not text.strip():
112
- return "No text found", None, None
113
-
114
- chunks = [text[i:i+1500] for i in range(0, min(len(text), 30000), 1500)]
115
-
116
- with ThreadPoolExecutor(max_workers=4) as executor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
118
 
119
  for summary, info in results:
120
  summaries.append(summary)
121
  chunk_info.append(info)
122
 
123
- final_summary = f"**Processed chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
124
  process_img = visualize_chunk_status(chunk_info)
125
  flow_img = create_summary_flowchart(summaries)
126
  return final_summary, process_img, flow_img
127
 
128
  demo = gr.Interface(
129
  fn=summarize_file,
130
- inputs=gr.File(label="📄 Upload PDF", type="binary"),
131
  outputs=[
132
- gr.Textbox(label="📝 Summary", lines=20),
133
- gr.Image(label="📊 Chunk Status", type="pil"),
134
- gr.Image(label="🔁 Flow Summary", type="pil")
135
  ],
136
- title="📘 PDF Summarizer with Visual Flow",
137
- description="Summarizes up to 30,000 characters from a PDF. Includes chunk status and flowchart visualizations."
138
  )
139
 
140
  if __name__ == "__main__":
141
  try:
142
- logger.info("Starting Gradio application on http://127.0.0.1:7860")
143
  demo.launch(
144
  share=False,
145
  server_name="127.0.0.1",
146
  server_port=7860,
147
- debug=True
148
  )
149
  except Exception as e:
150
- logger.error(f"Gradio launch failed on port 7860: {str(e)}")
151
- logger.info("Trying alternative port 7861...")
152
  try:
153
  demo.launch(
154
  share=False,
155
  server_name="127.0.0.1",
156
  server_port=7861,
157
- debug=True
158
  )
159
  except Exception as e2:
160
- logger.error(f"Gradio launch failed on port 7861: {str(e2)}")
161
  raise
162
 
163
 
 
9
  import io
10
  from PIL import Image
11
  from concurrent.futures import ThreadPoolExecutor
12
+ import numpy as np
13
 
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # Set device and optimize for speed
18
  device = 0 if torch.cuda.is_available() else -1
19
  logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")
20
 
21
# Load a lighter model for faster inference.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",  # lightweight model
        device=device,
        framework="pt",
        # Half-precision halves memory/compute on GPU; CPU stays float32.
        torch_dtype=torch.float16 if device == 0 else torch.float32
    )
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    # `exit()` is only injected by the `site` module and is not guaranteed
    # under `python -S` or in frozen apps; raising SystemExit is always safe.
    raise SystemExit(1)
 
37
  colors = [status_colors.get(i['status'], 'gray') for i in chunk_data]
38
  times = [i.get('time', 0.1) for i in chunk_data]
39
 
40
+ fig, ax = plt.subplots(figsize=(8, 2)) # Smaller figure size
41
+ ax.barh(labels, times, color=colors, height=0.4) # Reduced bar height
42
  ax.set_xlabel("Time (s)")
43
+ ax.set_title("Chunk Status")
44
+ plt.tight_layout(pad=0.5) # Minimal padding
45
  buf = io.BytesIO()
46
+ plt.savefig(buf, format='png', dpi=100) # Lower DPI for speed
47
  plt.close(fig)
48
  buf.seek(0)
49
  return Image.open(buf)
50
 
51
def create_summary_flowchart(summaries):
    """Render the successfully summarized chunks as a vertical flowchart.

    Parameters:
        summaries: list of per-chunk result strings (each starting with
            ``**Chunk``; skipped/errored chunks contain "Skipped"/"Error").

    Returns:
        A PIL.Image with one box per summary connected by arrows, or
        ``None`` when no chunk was summarized successfully.
    """
    # Keep only chunks that were actually summarized (not skipped / errored).
    filtered = [
        s for s in summaries
        if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s
    ]
    if not filtered:
        return None

    fig_height = max(1.5, len(filtered) * 0.5)  # scale height with box count
    fig, ax = plt.subplots(figsize=(5, fig_height))
    ax.axis('off')

    # One y position per box, top to bottom, 1.5 data units apart.
    ypos = np.arange(len(filtered) * 1.5, 0, -1.5)

    for i, (y, summary) in enumerate(zip(ypos, filtered)):
        # Drop the "**Chunk N**:" header line, then truncate to 100 chars.
        # (Computed once here instead of twice per iteration.)
        body = summary.split("**Chunk")[1].split("\n", 1)[-1].strip()
        summary_text = body[:100] + ("..." if len(body) > 100 else "")
        ax.text(0.5, y, summary_text, ha='center', va='center', fontsize=8,
                bbox=dict(facecolor="lightblue", edgecolor="black", pad=0.2))
        # Arrow connecting this box to the next one (none after the last).
        if i < len(filtered) - 1:
            ax.arrow(0.5, y - 0.2, 0, -1.1, head_width=0.02, head_length=0.1,
                     fc='black', ec='black')

    plt.tight_layout(pad=0.1)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
 
75
  chunk_result = {'chunk': i + 1, 'status': '', 'time': 0}
76
  start_time = time.time()
77
 
78
+ if not chunk.strip() or sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
79
+ result = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
80
  chunk_result['status'] = 'skipped'
81
  else:
82
  try:
83
+ summary = summarizer(chunk, max_length=60, min_length=10, do_sample=False)[0]['summary_text']
84
  result = f"**Chunk {i+1}**:\n{summary}"
85
  chunk_result['status'] = 'summarized'
86
  except Exception as e:
87
+ result = f"**Chunk {i+1}**: Error: {str(e)}"
88
  chunk_result['status'] = 'error'
89
 
90
  chunk_result['time'] = time.time() - start_time
 
95
  summaries = []
96
  chunk_info = []
97
 
98
+ # Stream text extraction
99
  try:
100
  doc = fitz.open(stream=file_bytes, filetype="pdf")
101
+ text = ""
102
+ for page in doc:
103
+ page_text = page.get_text("text")
104
+ if not page_text.strip():
105
+ continue
106
+ text += page_text
107
+ if len(text) > 30000: # Early cutoff
108
+ text = text[:30000]
109
+ break
110
+ doc.close()
111
+
112
+ # Fast text cleaning
113
+ text = re.sub(r"\$\s*[^$]+\s*\$|\\cap|\s+", lambda m: "intersection" if m.group(0) == "\\cap" else " ", text)
114
+ text = "".join(c for c in text if ord(c) < 128)[:30000]
115
  except Exception as e:
116
+ return f"Text extraction failed: {str(e)}", None, None
117
 
118
  if not text.strip():
119
+ return "No text found", None, None
120
+
121
+ # Smaller, sentence-aware chunks
122
+ chunks = []
123
+ current_chunk = ""
124
+ for sentence in re.split(r'(?<=[.!?])\s+', text):
125
+ if len(current_chunk) + len(sentence) <= 1000:
126
+ current_chunk += sentence
127
+ else:
128
+ if current_chunk:
129
+ chunks.append(current_chunk)
130
+ current_chunk = sentence
131
+ if len(chunks) >= 30: # Limit chunks
132
+ break
133
+ if current_chunk:
134
+ chunks.append(current_chunk)
135
+
136
+ # Dynamic worker count based on CPU/GPU
137
+ max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
138
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
139
  results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
140
 
141
  for summary, info in results:
142
  summaries.append(summary)
143
  chunk_info.append(info)
144
 
145
+ final_summary = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
146
  process_img = visualize_chunk_status(chunk_info)
147
  flow_img = create_summary_flowchart(summaries)
148
  return final_summary, process_img, flow_img
149
 
150
# Gradio UI wiring: one binary PDF upload in, a text summary plus two
# diagnostic images (chunk status bar chart and summary flowchart) out.
_output_components = [
    gr.Textbox(label="Summary", lines=15),
    gr.Image(label="Chunk Status", type="pil"),
    gr.Image(label="Flow Summary", type="pil"),
]

demo = gr.Interface(
    fn=summarize_file,
    inputs=gr.File(label="Upload PDF", type="binary"),
    outputs=_output_components,
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with visualizations.",
)
161
 
162
if __name__ == "__main__":
    # Launch on the primary port, falling back once to the next port.
    # The duplicated launch call of the original is folded into one loop;
    # if every port fails, the last error is re-raised (as before).
    last_error = None
    for port in (7860, 7861):
        try:
            logger.info(f"Starting Gradio on http://127.0.0.1:{port}")
            demo.launch(
                share=False,
                server_name="127.0.0.1",
                server_port=port,
                debug=False  # Disable debug for speed
            )
            last_error = None
            break
        except Exception as e:
            logger.error(f"Failed on port {port}: {str(e)}")
            last_error = e
    if last_error is not None:
        raise last_error
184
 
185