tejovanth commited on
Commit
f9c46c3
·
verified ·
1 Parent(s): 5457157

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import fitz # PyMuPDF
3
+ import torch
4
+ from transformers import pipeline
5
+ import time, logging, re
6
+ import matplotlib
7
+ matplotlib.use('Agg')
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ from PIL import Image
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ import numpy as np
13
+
14
# Module-wide logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Select the inference device: first CUDA GPU when available, else CPU (-1).
device = 0 if torch.cuda.is_available() else -1
logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")

# Load the summarization model once at import time; abort the app if it fails.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        # fp16 halves memory/latency on GPU; CPU kernels want fp32.
        torch_dtype=torch.float16 if device == 0 else torch.float32
    )
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    # `exit()` is a site.py convenience intended for the interactive REPL;
    # raise SystemExit directly so this also works under `python -S` or
    # environments that do not install the site builtins.
    raise SystemExit(1)
33
+
34
def visualize_chunk_status(chunk_data):
    """Render a horizontal bar chart of per-chunk processing time, colored by status.

    Each entry of *chunk_data* is a dict with keys 'chunk', 'status' and
    (optionally) 'time'.  Returns the rendered chart as a PIL Image.
    """
    palette = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
    bar_labels = []
    bar_colors = []
    bar_times = []
    for entry in chunk_data:
        bar_labels.append(f"C{entry['chunk']}")
        # Unknown statuses fall back to gray.
        bar_colors.append(palette.get(entry['status'], 'gray'))
        # Tiny default so zero-time chunks still show a visible bar.
        bar_times.append(entry.get('time', 0.1))

    fig, ax = plt.subplots(figsize=(8, 2))
    ax.barh(bar_labels, bar_times, color=bar_colors, height=0.4)
    ax.set_xlabel("Time (s)")
    ax.set_title("Chunk Status")
    plt.tight_layout(pad=0.5)

    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', dpi=100)
    plt.close(fig)
    png_buffer.seek(0)
    return Image.open(png_buffer)
50
+
51
def create_summary_flowchart(summaries):
    """Draw a top-down flowchart with one key point per successful chunk summary.

    Skipped and errored chunk entries are ignored.  Returns a PIL Image, or
    None when no usable summaries remain.
    """
    usable = []
    for s in summaries:
        if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s:
            usable.append(s)
    if not usable:
        return None

    # One short phrase per chunk: the first sentence, capped at 50 characters.
    points = []
    for item in usable:
        body = item.split("**Chunk")[1].split("\n", 1)[-1].strip()
        lead = re.split(r'(?<=[.!?])\s+', body)[0][:50]
        suffix = "..." if len(lead) >= 50 else ""
        points.append(lead + suffix)

    height = max(1.5, len(points) * 0.6)
    fig, ax = plt.subplots(figsize=(6, height))
    ax.axis('off')

    # Evenly spaced node centers, laid out top to bottom.
    node_y = np.arange(len(points) * 1.2, 0, -1.2)
    node_style = dict(boxstyle="round,pad=0.3", facecolor="lightgreen",
                      edgecolor="black", alpha=0.9)

    last = len(points) - 1
    for idx, (y, label) in enumerate(zip(node_y, points)):
        ax.text(0.5, y, label, ha='center', va='center', fontsize=9, bbox=node_style)
        # Connect consecutive nodes with a downward arrow.
        if idx < last:
            ax.arrow(0.5, y - 0.3, 0, -0.9, head_width=0.02,
                     head_length=0.1, fc='blue', ec='blue')

    ax.text(0.5, node_y[0] + 0.8, "Key Points Summary",
            ha='center', va='center', fontsize=12, weight='bold')

    plt.tight_layout(pad=0.1)
    out = io.BytesIO()
    fig.savefig(out, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    out.seek(0)
    return Image.open(out)
93
+
94
def process_chunk(i, chunk):
    """Summarize a single text chunk.

    Returns a (markdown_text, info) pair where info records the 1-based chunk
    number, a status of 'summarized' / 'skipped' / 'error', and the elapsed
    wall-clock seconds.
    """
    began = time.time()
    info = {'chunk': i + 1, 'status': '', 'time': 0}

    # Skip blank chunks and chunks dominated by symbols (likely equations).
    # The strip() test short-circuits, so the division never sees len == 0.
    if not chunk.strip() or sum(1 for ch in chunk if not ch.isalnum()) / len(chunk) > 0.5:
        text = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
        info['status'] = 'skipped'
    else:
        try:
            summary_text = summarizer(chunk, max_length=60, min_length=10, do_sample=False)[0]['summary_text']
            text = f"**Chunk {i+1}**:\n{summary_text}"
            info['status'] = 'summarized'
        except Exception as e:
            text = f"**Chunk {i+1}**: Error: {str(e)}"
            info['status'] = 'error'

    info['time'] = time.time() - began
    return text, info
112
+
113
def summarize_file(file_bytes):
    """Summarize an uploaded PDF given as raw bytes.

    Returns (summary_markdown, status_chart_image, flowchart_image).  On
    extraction failure or empty input the images are None and the first
    element is an error message instead of a summary.
    """
    started = time.time()

    # --- Extract up to 30,000 characters of text from the PDF ---
    try:
        pdf = fitz.open(stream=file_bytes, filetype="pdf")
        raw = ""
        for page in pdf:
            content = page.get_text("text")
            if not content.strip():
                continue
            raw += content
            if len(raw) > 30000:
                raw = raw[:30000]
                break
        pdf.close()

        # Drop inline $...$ math, spell out \cap, collapse whitespace runs.
        raw = re.sub(r"\$\s*[^$]+\s*\$|\\cap|\s+", lambda m: "intersection" if m.group(0) == "\\cap" else " ", raw)
        # Keep ASCII only and re-apply the size cap.
        raw = "".join(c for c in raw if ord(c) < 128)[:30000]
    except Exception as e:
        return f"Text extraction failed: {str(e)}", None, None

    if not raw.strip():
        return "No text found", None, None

    # --- Split into sentence-aligned chunks of up to ~1000 characters ---
    chunks = []
    pending = ""
    for sentence in re.split(r'(?<=[.!?])\s+', raw):
        if len(pending) + len(sentence) <= 1000:
            pending += sentence
        else:
            if pending:
                chunks.append(pending)
            pending = sentence
        if len(chunks) >= 30:
            break
    if pending:
        chunks.append(pending)

    # --- Summarize chunks in parallel ---
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        outcomes = list(pool.map(lambda ic: process_chunk(*ic), enumerate(chunks)))

    summaries = [text for text, _ in outcomes]
    chunk_info = [meta for _, meta in outcomes]

    header = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - started:.2f}s\n\n"
    return (header + "\n\n".join(summaries),
            visualize_chunk_status(chunk_info),
            create_summary_flowchart(summaries))
165
+
166
# Gradio UI: one PDF upload in, three outputs out (summary text + two charts).
_output_widgets = [
    gr.Textbox(label="Summary", lines=15),
    gr.Image(label="Chunk Status", type="pil"),
    gr.Image(label="Key Points Flowchart", type="pil"),
]

demo = gr.Interface(
    fn=summarize_file,
    inputs=gr.File(label="Upload PDF", type="binary"),
    outputs=_output_widgets,
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart."
)
177
+
178
if __name__ == "__main__":
    # The original had two copy-pasted launch attempts; collapse them into a
    # port-fallback loop.  Try the default port first, then one fallback, and
    # re-raise the last failure if neither binds.
    last_error = None
    for port in (7860, 7861):
        try:
            logger.info(f"Starting Gradio on http://127.0.0.1:{port}")
            demo.launch(
                share=False,
                server_name="127.0.0.1",
                server_port=port,
                debug=False
            )
            break  # launch() returned normally; nothing left to do
        except Exception as e:
            logger.error(f"Failed on port {port}: {str(e)}")
            last_error = e
    else:
        # Both ports failed; surface the final exception to the caller.
        raise last_error