Spaces:
Runtime error
Runtime error
File size: 6,788 Bytes
fa7e538 f57bdb5 97db6a8 f57bdb5 f45e94d 25abcb8 f57bdb5 97db6a8 f57bdb5 fa7e538 f57bdb5 25abcb8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 25abcb8 97db6a8 f57bdb5 25abcb8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 25abcb8 f57bdb5 25abcb8 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 f57bdb5 97db6a8 fa7e538 f57bdb5 ba5944b 2d57a07 f45e94d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import gradio as gr
import fitz # PyMuPDF
import torch
from transformers import pipeline
import time, logging, re
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
import numpy as np
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device and optimize for speed: the transformers pipeline takes device=0 for
# the first CUDA GPU and device=-1 for CPU.
device = 0 if torch.cuda.is_available() else -1
logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")

# Load the model once at import time; process_chunk reuses this pipeline for
# every chunk of every request.
# NOTE(review): T5 checkpoints are reported to be prone to fp16 overflow; if GPU
# summaries come out empty or garbled, try torch.float32 here — confirm for t5-small.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        torch_dtype=torch.float16 if device == 0 else torch.float32,
    )
except Exception as e:
    # logger.exception also records the traceback, which logger.error discarded.
    logger.exception(f"❌ Model loading failed: {str(e)}")
    # The bare `exit` builtin is injected by the `site` module and is not
    # guaranteed to exist (e.g. `python -S`, embedded interpreters);
    # SystemExit is always available and gives the same exit status.
    raise SystemExit(1) from e
def visualize_chunk_status(chunk_data):
    """Render per-chunk processing time as a horizontal bar chart coloured by status.

    Args:
        chunk_data: list of dicts with keys ``'chunk'`` (chunk number),
            ``'status'`` (``'summarized'`` -> green, ``'skipped'`` -> orange,
            ``'error'`` -> red, anything else -> gray) and optionally ``'time'``
            (seconds; 0.1 is used when the key is missing).

    Returns:
        A PIL image containing the rendered PNG chart.
    """
    status_colors = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
    labels = [f"C{item['chunk']}" for item in chunk_data]
    colors = [status_colors.get(item['status'], 'gray') for item in chunk_data]
    times = [item.get('time', 0.1) for item in chunk_data]

    # Grow the figure with the number of bars: summarize_file can produce ~30
    # chunks, and a fixed 2-inch height made the y-axis labels overlap.
    fig_height = max(2.0, 0.25 * len(chunk_data))
    fig, ax = plt.subplots(figsize=(8, fig_height))
    try:
        ax.barh(labels, times, color=colors, height=0.4)
        ax.set_xlabel("Time (s)")
        ax.set_title("Chunk Status")
        # Work on this figure explicitly instead of pyplot's process-global
        # "current figure", which could point at a different figure if other
        # plotting code (e.g. another request) runs at the same time.
        fig.tight_layout(pad=0.5)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100)
    finally:
        # Release the figure even if rendering fails, so figures don't accumulate.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def create_summary_flowchart(summaries, max_chars=50):
    """Draw the first sentence of each successful chunk summary as a vertical flowchart.

    Args:
        summaries: markdown entries as produced by ``process_chunk``. Only entries
            shaped ``"**Chunk N**:\\n<summary>"`` (successful summaries) are used;
            skipped and error entries keep their message on the header line and
            are ignored, as are successful entries whose summary text is empty.
        max_chars: each key point is cut to this many characters, with "..."
            appended only when text was actually cut.

    Returns:
        A PIL image of the flowchart, or ``None`` when there is no usable summary.
    """
    key_points = []
    for summary in summaries:
        # Match the header shape instead of searching for "Skipped"/"Error"
        # anywhere in the entry: that substring test also discarded genuine
        # summaries that merely contained one of those words.
        match = re.match(r"\*\*Chunk \d+\*\*:\n(.*)", summary, re.DOTALL)
        if match is None:
            continue
        summary_text = match.group(1).strip()
        if not summary_text:
            continue
        # Key point = first sentence of the summary, truncated.
        first_sentence = re.split(r'(?<=[.!?])\s+', summary_text)[0]
        # Compare against the untruncated length: a sentence of exactly
        # max_chars characters must not get an ellipsis.
        ellipsis = "..." if len(first_sentence) > max_chars else ""
        key_points.append(first_sentence[:max_chars] + ellipsis)

    if not key_points:
        return None

    n = len(key_points)
    # Node centres from top to bottom: n*1.2, (n-1)*1.2, ..., 1.2. An integer
    # arange sidesteps np.arange's floating-point length caveat with float steps.
    ypos = 1.2 * np.arange(n, 0, -1)
    fig, ax = plt.subplots(figsize=(6, max(1.5, n * 0.6)))
    try:
        ax.axis('off')
        # Text artists don't take part in autoscaling, so without explicit
        # limits the view was fitted to the arrows alone (or to nothing when
        # there is a single node).
        ax.set_xlim(0, 1)
        ax.set_ylim(ypos[-1] - 0.6, ypos[0] + 1.2)
        boxprops = dict(boxstyle="round,pad=0.3", facecolor="lightgreen", edgecolor="black", alpha=0.9)
        for i, (y, point) in enumerate(zip(ypos, key_points)):
            ax.text(0.5, y, point, ha='center', va='center', fontsize=9, bbox=boxprops)
            if i < n - 1:
                # Arrow from just below this node to just above the next one
                # (1.2 lower); counting the head in the length keeps the tip
                # out of the next node's box.
                ax.arrow(0.5, y - 0.3, 0, -0.6, length_includes_head=True,
                         head_width=0.02, head_length=0.1, fc='blue', ec='blue')
        ax.text(0.5, ypos[0] + 0.8, "Key Points Summary", ha='center', va='center', fontsize=12, weight='bold')
        # Use this figure explicitly rather than pyplot's global current figure.
        fig.tight_layout(pad=0.1)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def process_chunk(i, chunk):
    """Summarize one text chunk and report how it went.

    A chunk is skipped when it is blank or when more than half of its characters
    are non-alphanumeric (treated as "equation-heavy"). Otherwise it is passed to
    the module-level ``summarizer`` pipeline; any exception from that call is
    reported in the returned text rather than raised.

    Args:
        i: zero-based chunk index (shown to the user 1-based).
        chunk: the chunk's text.

    Returns:
        ``(markdown, info)``: the user-facing entry for this chunk, and a dict
        ``{'chunk': i + 1, 'status': ..., 'time': seconds}`` whose status is
        ``'summarized'``, ``'skipped'`` or ``'error'``.
    """
    number = i + 1
    info = {'chunk': number, 'status': '', 'time': 0}
    started = time.time()

    # Blank text short-circuits, so the ratio below never divides by zero.
    if chunk.strip():
        non_alnum = sum(not ch.isalnum() for ch in chunk)
        should_skip = non_alnum / len(chunk) > 0.5
    else:
        should_skip = True

    if should_skip:
        info['status'] = 'skipped'
        result = f"**Chunk {number}**: Skipped (empty or equation-heavy)"
    else:
        try:
            output = summarizer(chunk, max_length=60, min_length=10, do_sample=False)
            summary = output[0]['summary_text']
        except Exception as err:
            info['status'] = 'error'
            result = f"**Chunk {number}**: Error: {str(err)}"
        else:
            info['status'] = 'summarized'
            result = f"**Chunk {number}**:\n{summary}"

    info['time'] = time.time() - started
    return result, info
def _extract_pdf_text(file_bytes, max_chars=30000):
    """Extract and clean up to ``max_chars`` characters of text from PDF bytes.

    Pages whose text is blank are skipped. Inline ``$...$`` math is dropped,
    ``\\cap`` becomes "intersection", whitespace runs collapse to one space,
    and the result is reduced to ASCII.
    """
    import unicodedata

    parts = []
    total = 0
    # The context manager closes the document even if a page fails to extract.
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        for page in doc:
            page_text = page.get_text("text")
            if not page_text.strip():
                continue
            parts.append(page_text)
            total += len(page_text)
            if total > max_chars:
                break
    text = "".join(parts)[:max_chars]
    # NFKD splits ligatures (e.g. U+FB01 "fi") and accented letters into plain
    # ASCII base characters plus combining marks, so the ASCII filter below
    # drops only the marks instead of deleting the letters ("ﬁnd" became "nd").
    text = unicodedata.normalize("NFKD", text)
    text = re.sub(
        r"\$\s*[^$]+\s*\$|\\cap|\s+",
        lambda m: "intersection" if m.group(0) == "\\cap" else " ",
        text,
    )
    return "".join(c for c in text if ord(c) < 128)[:max_chars]


def _split_into_chunks(text, max_chars=1000, max_chunks=30):
    """Group sentences into at most ``max_chunks`` chunks of at most ``max_chars`` characters.

    Sentences are re-joined with a single space. A "sentence" longer than
    ``max_chars`` (e.g. text without sentence punctuation) is cut into
    ``max_chars``-sized pieces so no chunk exceeds the limit.
    """
    chunks = []
    current = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
        if not sentence:
            continue
        for start in range(0, len(sentence), max_chars):
            piece = sentence[start:start + max_chars]
            # The split consumed the whitespace between sentences; restore one space.
            candidate = f"{current} {piece}" if current else piece
            if len(candidate) <= max_chars:
                current = candidate
                continue
            # `current` is non-empty here: an empty one always accepts a piece.
            chunks.append(current)
            if len(chunks) >= max_chunks:
                return chunks
            current = piece
    if current and len(chunks) < max_chunks:
        chunks.append(current)
    return chunks


def summarize_file(file_bytes):
    """Summarize an uploaded PDF chunk by chunk.

    Args:
        file_bytes: raw PDF bytes from the Gradio upload (``None`` when nothing
            was uploaded).

    Returns:
        ``(summary_markdown, chunk_status_image, flowchart_image)``. On failure
        or empty input the two images are ``None`` and the text explains why.
    """
    start = time.time()
    if not file_bytes:
        return "No file uploaded", None, None

    try:
        text = _extract_pdf_text(file_bytes)
    except Exception as e:
        logger.exception("PDF text extraction failed")
        return f"Text extraction failed: {str(e)}", None, None
    if not text.strip():
        return "No text found", None, None

    chunks = _split_into_chunks(text)

    # NOTE(review): the single `summarizer` pipeline is shared by all worker
    # threads here; confirm concurrent calls are safe for this model/tokenizer.
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map() yields results in submission order, so chunk numbering is kept.
        results = list(executor.map(process_chunk, range(len(chunks)), chunks))
    summaries = [summary for summary, _ in results]
    chunk_info = [info for _, info in results]

    final_summary = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
    process_img = visualize_chunk_status(chunk_info)
    flow_img = create_summary_flowchart(summaries)
    return final_summary, process_img, flow_img
# Single-input / three-output web UI around summarize_file. The output
# components are listed in the same order summarize_file returns its values:
# summary text, chunk-status chart, key-points flowchart.
_pdf_upload = gr.File(label="Upload PDF", type="binary")
_result_components = [
    gr.Textbox(label="Summary", lines=15),
    gr.Image(label="Chunk Status", type="pil"),
    gr.Image(label="Key Points Flowchart", type="pil"),
]
demo = gr.Interface(
    fn=summarize_file,
    inputs=_pdf_upload,
    outputs=_result_components,
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart.",
)
if __name__ == "__main__":
    import os

    # Honour Gradio's own environment variables so the app also works in
    # containers (e.g. Hugging Face Spaces / Docker), which set
    # GRADIO_SERVER_NAME=0.0.0.0; a hard-coded "127.0.0.1" overrides that and
    # leaves the server unreachable from outside. Without the variables the
    # behaviour is unchanged: 127.0.0.1, port 7860, falling back to 7861.
    host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    first_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    ports = [first_port, first_port + 1]

    for attempt, port in enumerate(ports):
        logger.info(f"Starting Gradio on http://{host}:{port}")
        try:
            demo.launch(
                share=False,
                server_name=host,
                server_port=port,
                debug=False,
            )
            break  # launch() normally blocks until the server stops.
        except OSError as e:
            # Gradio reports a port it cannot bind as OSError; other failures
            # would not be fixed by switching ports, so they propagate directly.
            logger.error(f"Failed on port {port}: {str(e)}")
            if attempt == len(ports) - 1:
                raise
            logger.info(f"Trying port {ports[attempt + 1]}...")
|