"""Gradio app that summarizes an uploaded PDF with a t5-small pipeline.

Flow: extract text with PyMuPDF -> normalize/strip math -> split into
sentence-aligned chunks (<=1000 chars, <=30 chunks) -> summarize chunks in
a thread pool -> render a chunk-status bar chart and a key-points flowchart.
"""
import gradio as gr
import fitz  # PyMuPDF
import torch
from transformers import pipeline
import time, logging, re
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are rendered to PNG buffers only
import matplotlib.pyplot as plt
import io
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device and optimize for speed (0 = first CUDA device, -1 = CPU per HF convention)
device = 0 if torch.cuda.is_available() else -1
logger.info(f"🔧 Using {'GPU' if device == 0 else 'CPU'}")

# Load model once at import time; half precision only on GPU.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        torch_dtype=torch.float16 if device == 0 else torch.float32
    )
except Exception as e:
    logger.error(f"❌ Model loading failed: {str(e)}")
    exit(1)


def visualize_chunk_status(chunk_data):
    """Render a horizontal bar chart of per-chunk status/time.

    Args:
        chunk_data: list of dicts with keys 'chunk' (1-based index),
            'status' ('summarized' | 'skipped' | 'error'), and 'time' (seconds).

    Returns:
        PIL.Image.Image of the rendered chart.
    """
    status_colors = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
    labels = [f"C{i['chunk']}" for i in chunk_data]
    colors = [status_colors.get(i['status'], 'gray') for i in chunk_data]
    times = [i.get('time', 0.1) for i in chunk_data]
    fig, ax = plt.subplots(figsize=(8, 2))
    ax.barh(labels, times, color=colors, height=0.4)
    ax.set_xlabel("Time (s)")
    ax.set_title("Chunk Status")
    plt.tight_layout(pad=0.5)
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100)
    plt.close(fig)  # free the figure; Agg figures are not GC'd automatically
    buf.seek(0)
    return Image.open(buf)


def create_summary_flowchart(summaries):
    """Build a vertical flowchart image from the first sentence of each summary.

    Args:
        summaries: list of markdown strings produced by process_chunk.

    Returns:
        PIL.Image.Image, or None if no chunk was successfully summarized.
    """
    # Keep only real summaries (drop skipped/errored chunks).
    filtered = [
        s for s in summaries
        if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s
    ]
    if not filtered:
        return None

    # Extract key points (first sentence, truncated to 50 characters).
    key_points = []
    for summary in filtered:
        summary_text = summary.split("**Chunk")[1].split("\n", 1)[-1].strip()
        first_sentence = re.split(r'(?<=[.!?])\s+', summary_text)[0][:50]
        # len == 50 after slicing means the sentence was (possibly) cut off.
        key_points.append(first_sentence + ("..." if len(first_sentence) >= 50 else ""))

    # Create flowchart: one rounded box per key point, arrows between them.
    fig_height = max(1.5, len(key_points) * 0.6)
    fig, ax = plt.subplots(figsize=(6, fig_height))
    ax.axis('off')

    # Node positions (top to bottom) and styling.
    ypos = np.arange(len(key_points) * 1.2, 0, -1.2)
    boxprops = dict(boxstyle="round,pad=0.3", facecolor="lightgreen",
                    edgecolor="black", alpha=0.9)
    for i, (y, point) in enumerate(zip(ypos, key_points)):
        ax.text(0.5, y, point, ha='center', va='center', fontsize=9, bbox=boxprops)
        # Arrow down to the next node (skip after the last one).
        if i < len(key_points) - 1:
            ax.arrow(0.5, y - 0.3, 0, -0.9, head_width=0.02, head_length=0.1,
                     fc='blue', ec='blue')

    ax.text(0.5, ypos[0] + 0.8, "Key Points Summary", ha='center', va='center',
            fontsize=12, weight='bold')
    plt.tight_layout(pad=0.1)
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def process_chunk(i, chunk):
    """Summarize one text chunk, or skip it if empty/equation-heavy.

    Args:
        i: 0-based chunk index (reported 1-based to the user).
        chunk: the chunk text.

    Returns:
        (markdown_result, info_dict) where info_dict has 'chunk', 'status', 'time'.
    """
    chunk_result = {'chunk': i + 1, 'status': '', 'time': 0}
    start_time = time.time()
    # Skip if blank, or if more than half the characters are non-alphanumeric
    # (a heuristic for equation/symbol-heavy text that t5 handles poorly).
    if not chunk.strip() or sum(1 for c in chunk if not c.isalnum()) / len(chunk) > 0.5:
        result = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
        chunk_result['status'] = 'skipped'
    else:
        try:
            summary = summarizer(chunk, max_length=60, min_length=10,
                                 do_sample=False)[0]['summary_text']
            result = f"**Chunk {i+1}**:\n{summary}"
            chunk_result['status'] = 'summarized'
        except Exception as e:
            result = f"**Chunk {i+1}**: Error: {str(e)}"
            chunk_result['status'] = 'error'
    chunk_result['time'] = time.time() - start_time
    return result, chunk_result


def summarize_file(file_bytes):
    """Extract, clean, chunk, and summarize a PDF given as raw bytes.

    Args:
        file_bytes: the uploaded PDF as bytes (gr.File type="binary").

    Returns:
        (summary_markdown, status_image, flowchart_image); the images are None
        when extraction fails or no text is found.
    """
    start = time.time()
    summaries = []
    chunk_info = []
    try:
        doc = fitz.open(stream=file_bytes, filetype="pdf")
        text = ""
        for page in doc:
            page_text = page.get_text("text")
            if not page_text.strip():
                continue
            text += page_text
            if len(text) > 30000:  # hard cap to keep latency bounded
                text = text[:30000]
                break
        doc.close()
        # Strip inline $...$ math, translate \cap, and collapse whitespace in one pass.
        text = re.sub(
            r"\$\s*[^$]+\s*\$|\\cap|\s+",
            lambda m: "intersection" if m.group(0) == "\\cap" else " ",
            text)
        # Drop non-ASCII characters and re-apply the length cap.
        text = "".join(c for c in text if ord(c) < 128)[:30000]
    except Exception as e:
        return f"Text extraction failed: {str(e)}", None, None

    if not text.strip():
        return "No text found", None, None

    # Greedily pack sentences into chunks of <=1000 chars, at most 30 chunks.
    # FIX: rejoin sentences with a space — re.split consumed the separator, and
    # the original `+=` fused sentences together ("end.Next"), hurting summaries.
    chunks = []
    current_chunk = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        if len(current_chunk) + len(sentence) + 1 <= 1000:
            current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
            if len(chunks) >= 30:
                # FIX: clear the carry-over so the trailing append below cannot
                # produce a 31st chunk past the cap.
                current_chunk = ""
                break
    if current_chunk:
        chunks.append(current_chunk)

    # NOTE(review): a single shared `summarizer` pipeline is called from multiple
    # threads here; transformers pipelines are not documented as thread-safe —
    # verify, or serialize calls / use one pipeline per worker if issues appear.
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
    for summary, info in results:
        summaries.append(summary)
        chunk_info.append(info)

    final_summary = (f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n"
                     + "\n\n".join(summaries))
    process_img = visualize_chunk_status(chunk_info)
    flow_img = create_summary_flowchart(summaries)
    return final_summary, process_img, flow_img


demo = gr.Interface(
    fn=summarize_file,
    inputs=gr.File(label="Upload PDF", type="binary"),
    outputs=[
        gr.Textbox(label="Summary", lines=15),
        gr.Image(label="Chunk Status", type="pil"),
        gr.Image(label="Key Points Flowchart", type="pil")
    ],
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart."
)

if __name__ == "__main__":
    try:
        logger.info("Starting Gradio on http://127.0.0.1:7860")
        demo.launch(
            share=False,
            server_name="127.0.0.1",
            server_port=7860,
            debug=False
        )
    except Exception as e:
        logger.error(f"Failed on port 7860: {str(e)}")
        logger.info("Trying port 7861...")
        # Fall back to the next port once before giving up.
        try:
            demo.launch(
                share=False,
                server_name="127.0.0.1",
                server_port=7861,
                debug=False
            )
        except Exception as e2:
            logger.error(f"Failed on port 7861: {str(e2)}")
            raise