# examplefour / app.py
# Author: tejovanth
# Commit f57bdb5 (verified): "Update app.py"
import gradio as gr
import fitz # PyMuPDF
import torch
from transformers import pipeline
import time, logging, re
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
import numpy as np
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device index handed to the pipeline: 0 when CUDA is available, otherwise -1
# (logged below as GPU and CPU respectively).
device = 0 if torch.cuda.is_available() else -1
logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")

# Load the summarization model once at import time so that every request
# reuses the same pipeline object.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        # Half precision is requested only on the GPU; the CPU stays in float32.
        torch_dtype=torch.float16 if device == 0 else torch.float32,
    )
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    # `exit()` is a helper injected by the `site` module for interactive use
    # and is missing under `python -S`; raising SystemExit is the portable
    # way to stop the program with a non-zero status.
    raise SystemExit(1)
def visualize_chunk_status(chunk_data):
    """Render a horizontal bar chart of per-chunk processing time.

    Args:
        chunk_data: Iterable of dicts, one per chunk. Each dict must have a
            ``'chunk'`` entry (used for the bar label ``C<chunk>``) and a
            ``'status'`` entry; ``'time'`` is optional and defaults to 0.1.
            Bars are green for ``'summarized'``, orange for ``'skipped'``,
            red for ``'error'`` and gray for any other status.

    Returns:
        A PIL image holding the rendered PNG chart.
    """
    status_colors = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
    entries = list(chunk_data)
    labels = [f"C{entry['chunk']}" for entry in entries]
    colors = [status_colors.get(entry['status'], 'gray') for entry in entries]
    times = [entry.get('time', 0.1) for entry in entries]

    # Keep the original 2-inch height for short inputs, but grow with the
    # number of bars so labels do not overlap when there are many chunks.
    fig_height = max(2, 0.3 * len(entries))
    fig, ax = plt.subplots(figsize=(8, fig_height))
    buf = io.BytesIO()
    try:
        ax.barh(labels, times, color=colors, height=0.4)
        ax.set_xlabel("Time (s)")
        ax.set_title("Chunk Status")
        # Operate on `fig` explicitly: the pyplot-level functions act on the
        # "current figure", which is global state and may not be this figure.
        fig.tight_layout(pad=0.5)
        fig.savefig(buf, format='png', dpi=100)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def create_summary_flowchart(summaries):
    """Draw a top-to-bottom flowchart of one key point per summarized chunk.

    Only entries shaped like a successful result, ``"**Chunk N**:\\n<text>"``,
    contribute a node. Skipped and failed chunks keep their message on the
    header line and are therefore ignored. The key point is the first
    sentence of the summary, cut to 50 characters (with a trailing ``"..."``
    when it was actually cut).

    Args:
        summaries: Iterable of per-chunk result strings.

    Returns:
        A PIL image with the rendered PNG flowchart, or ``None`` when no entry
        yields a key point.
    """
    # Match on the header shape rather than searching the whole string for
    # "Skipped"/"Error": a genuine summary may contain those words.
    header = re.compile(r"\*\*Chunk \d+\*\*:\n")
    max_chars = 50

    key_points = []
    for entry in summaries:
        match = header.match(entry)
        if match is None:
            continue
        summary_text = entry[match.end():].strip()
        if not summary_text:
            # An empty node carries no information.
            continue
        first_sentence = re.split(r'(?<=[.!?])\s+', summary_text)[0]
        if len(first_sentence) > max_chars:
            # Mark with an ellipsis only when text was really dropped.
            first_sentence = first_sentence[:max_chars] + "..."
        key_points.append(first_sentence)

    if not key_points:
        return None

    node_count = len(key_points)
    spacing = 1.2
    # One y coordinate per node, top node first. Computed by multiplication so
    # the number of positions always equals the number of nodes (a float-step
    # arange does not guarantee that).
    ypos = [(node_count - index) * spacing for index in range(node_count)]
    title_y = ypos[0] + 0.8

    fig_height = max(1.5, node_count * 0.6)
    fig, ax = plt.subplots(figsize=(6, fig_height))
    buf = io.BytesIO()
    try:
        ax.axis('off')
        # Fix the data limits so nodes, arrows and title share one known
        # coordinate system instead of depending on autoscaling.
        ax.set_xlim(0, 1)
        ax.set_ylim(0, title_y + 0.4)

        boxprops = dict(
            boxstyle="round,pad=0.3",
            facecolor="lightgreen",
            edgecolor="black",
            alpha=0.9,
        )
        for index, (y, point) in enumerate(zip(ypos, key_points)):
            ax.text(0.5, y, point, ha='center', va='center', fontsize=9, bbox=boxprops)
            # Connect every node except the last one to its successor.
            if index < node_count - 1:
                ax.arrow(
                    0.5, y - 0.3, 0, -0.9,
                    head_width=0.02, head_length=0.1, fc='blue', ec='blue',
                )

        ax.text(
            0.5, title_y, "Key Points Summary",
            ha='center', va='center', fontsize=12, weight='bold',
        )
        fig.tight_layout(pad=0.1)
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def process_chunk(i, chunk):
    """Summarize one text chunk and report how it went.

    A chunk is skipped when it is blank or when more than half of its
    characters are non-alphanumeric (whitespace counts as non-alphanumeric).

    Args:
        i: Zero-based chunk index; reported one-based in the outputs.
        chunk: The chunk text.

    Returns:
        A ``(result, chunk_result)`` tuple. ``result`` is a string starting
        with ``"**Chunk N**"`` followed by the summary, a skip notice or an
        error message. ``chunk_result`` is a dict with ``'chunk'`` (one-based
        index), ``'status'`` (``'summarized'``, ``'skipped'`` or ``'error'``)
        and ``'time'`` (elapsed seconds).
    """
    chunk_result = {'chunk': i + 1, 'status': '', 'time': 0}
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()

    # The blank check must come first: it short-circuits before the division,
    # which would otherwise fail on an empty chunk.
    if not chunk.strip() or sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
        result = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
        chunk_result['status'] = 'skipped'
    else:
        try:
            outputs = summarizer(
                chunk,
                max_length=60,
                min_length=10,
                do_sample=False,
                # Cut inputs that exceed the model's maximum length instead of
                # feeding an arbitrarily long sequence to the model.
                truncation=True,
            )
            summary = outputs[0]['summary_text']
            result = f"**Chunk {i+1}**:\n{summary}"
            chunk_result['status'] = 'summarized'
        except Exception as e:
            # Best effort: one failing chunk must not abort the whole run, but
            # the failure is logged with its traceback.
            logging.getLogger(__name__).exception("Summarization failed for chunk %d", i + 1)
            result = f"**Chunk {i+1}**: Error: {str(e)}"
            chunk_result['status'] = 'error'

    chunk_result['time'] = time.perf_counter() - start_time
    return result, chunk_result
def _extract_pdf_text(file_bytes, max_chars=30000):
    """Return cleaned, ASCII-only text from PDF bytes, capped at ``max_chars``."""
    parts = []
    total = 0
    # The context manager closes the document even if a page fails to parse.
    with fitz.open(stream=file_bytes, filetype="pdf") as doc:
        for page in doc:
            page_text = page.get_text("text")
            if not page_text.strip():
                continue
            parts.append(page_text)
            total += len(page_text)
            if total > max_chars:
                # Enough text collected; later pages would be cut off anyway.
                break
    text = "".join(parts)[:max_chars]
    # Replace "$...$" spans and whitespace runs with a single space, and a
    # literal "\cap" with the word "intersection".
    text = re.sub(
        r"\$\s*[^$]+\s*\$|\\cap|\s+",
        lambda m: "intersection" if m.group(0) == "\\cap" else " ",
        text,
    )
    return "".join(c for c in text if ord(c) < 128)[:max_chars]


def _split_into_chunks(text, max_chunk_chars=1000, max_chunks=30):
    """Group sentences into at most ``max_chunks`` chunks of about ``max_chunk_chars``."""
    chunks = []
    current = []
    current_len = 0
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        if not sentence:
            continue
        # Account for the space that re-joins this sentence to the previous one.
        added = len(sentence) + (1 if current else 0)
        if current_len + added <= max_chunk_chars:
            current.append(sentence)
            current_len += added
            continue
        if current:
            chunks.append(" ".join(current))
            if len(chunks) >= max_chunks:
                # Cap reached: stop here so the limit is never exceeded.
                return chunks
        # A single sentence longer than the limit becomes its own chunk.
        current = [sentence]
        current_len = len(sentence)
    if current and len(chunks) < max_chunks:
        chunks.append(" ".join(current))
    return chunks


def summarize_file(file_bytes):
    """Summarize an uploaded PDF and build the two diagnostic images.

    Args:
        file_bytes: Raw bytes of the PDF, or a falsy value when nothing was
            uploaded.

    Returns:
        A ``(summary, status_image, flowchart_image)`` tuple. ``summary`` is a
        string with the chunk count, elapsed time and per-chunk results. On
        failure (no upload, extraction error, or no text) the tuple is
        ``(message, None, None)``. ``flowchart_image`` is whatever
        ``create_summary_flowchart`` returns for the per-chunk results.
    """
    start = time.time()
    if not file_bytes:
        return "No file uploaded", None, None

    try:
        text = _extract_pdf_text(file_bytes)
    except Exception as e:
        return f"Text extraction failed: {str(e)}", None, None
    if not text.strip():
        return "No text found", None, None

    chunks = _split_into_chunks(text)

    # Fan the chunks out over a small thread pool; map() keeps chunk order.
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(process_chunk, range(len(chunks)), chunks))

    summaries = [summary for summary, _ in results]
    chunk_info = [info for _, info in results]

    final_summary = f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n" + "\n\n".join(summaries)
    process_img = visualize_chunk_status(chunk_info)
    flow_img = create_summary_flowchart(summaries)
    return final_summary, process_img, flow_img
# Gradio UI: one PDF upload in; the summary text and two images out, in the
# same order as the tuple returned by summarize_file.
demo = gr.Interface(
    fn=summarize_file,
    # type="binary" delivers the upload as bytes; file_types restricts the
    # upload to PDFs, the only format summarize_file opens.
    inputs=gr.File(label="Upload PDF", type="binary", file_types=[".pdf"]),
    outputs=[
        gr.Textbox(label="Summary", lines=15),
        gr.Image(label="Chunk Status", type="pil"),
        gr.Image(label="Key Points Flowchart", type="pil"),
    ],
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart.",
)
if __name__ == "__main__":
    # Settings shared by both launch attempts; only the port differs.
    launch_options = dict(share=False, server_name="127.0.0.1", debug=False)
    try:
        logger.info("Starting Gradio on http://127.0.0.1:7860")
        demo.launch(server_port=7860, **launch_options)
    except Exception as first_error:
        logger.error(f"Failed on port 7860: {str(first_error)}")
        logger.info("Trying port 7861...")
        # Second and last attempt on the next port; a failure here is logged
        # and propagated to the caller.
        try:
            demo.launch(server_port=7861, **launch_options)
        except Exception as second_error:
            logger.error(f"Failed on port 7861: {str(second_error)}")
            raise