# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces:" / "Runtime error" x2) pasted along with the source; kept as
# comments so the file remains valid Python.
# Spaces: Runtime error / Runtime error
| import gradio as gr | |
| import fitz # PyMuPDF | |
| import torch | |
| from transformers import pipeline | |
| import time, logging, re | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import io | |
| from PIL import Image | |
| from concurrent.futures import ThreadPoolExecutor | |
| import numpy as np | |
# Module-level logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Select the compute device once at import time:
# transformers' pipeline convention is 0 for the first CUDA GPU, -1 for CPU.
device = 0 if torch.cuda.is_available() else -1
logger.info("🔧 Using %s", "GPU" if device == 0 else "CPU")

# Load the summarization model at import time so every request reuses it.
# fp16 on GPU halves memory; CPU inference requires fp32.
try:
    summarizer = pipeline(
        "summarization",
        model="t5-small",
        device=device,
        framework="pt",
        torch_dtype=torch.float16 if device == 0 else torch.float32,
    )
except Exception as e:
    logger.error("❌ Model loading failed: %s", e)
    # BUG FIX: `exit()` is a site-module interactive convenience and may not
    # exist under all launchers; raising SystemExit is the supported way to
    # terminate a program from code.
    raise SystemExit(1)
def visualize_chunk_status(chunk_data):
    """Render a horizontal bar chart of per-chunk processing time, colored by status.

    chunk_data: list of dicts with keys 'chunk' (1-based number), 'status'
    ('summarized' / 'skipped' / 'error'), and optionally 'time' (seconds).
    Returns a PIL.Image of the rendered chart.
    """
    palette = {'summarized': 'green', 'skipped': 'orange', 'error': 'red'}
    names, bar_colors, durations = [], [], []
    for entry in chunk_data:
        names.append(f"C{entry['chunk']}")
        # Unknown statuses fall back to gray so the chart never fails to render.
        bar_colors.append(palette.get(entry['status'], 'gray'))
        # 0.1s floor keeps instantaneous chunks visible as a sliver.
        durations.append(entry.get('time', 0.1))
    figure, axis = plt.subplots(figsize=(8, 2))
    axis.barh(names, durations, color=bar_colors, height=0.4)
    axis.set_xlabel("Time (s)")
    axis.set_title("Chunk Status")
    plt.tight_layout(pad=0.5)
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', dpi=100)
    plt.close(figure)
    png_buffer.seek(0)
    return Image.open(png_buffer)
def create_summary_flowchart(summaries):
    """Build a vertical flowchart image with one key point per successful summary.

    summaries: markdown strings produced by process_chunk.
    Returns a PIL.Image, or None when no chunk was successfully summarized.
    """
    # Keep only real summaries; drop the "Skipped"/"Error" placeholder entries.
    usable = []
    for s in summaries:
        if s.startswith("**Chunk") and "Skipped" not in s and "Error" not in s:
            usable.append(s)
    if not usable:
        return None
    # One key point per summary: its first sentence, capped at 50 characters.
    key_points = []
    for entry in usable:
        # Strip the "**Chunk N**:" header line, keeping only the summary body.
        body = entry.split("**Chunk")[1].split("\n", 1)[-1].strip()
        lead = re.split(r'(?<=[.!?])\s+', body)[0][:50]
        key_points.append(lead + ("..." if len(lead) >= 50 else ""))
    # Size the figure to the node count so long lists do not overlap.
    fig_h = max(1.5, len(key_points) * 0.6)
    fig, ax = plt.subplots(figsize=(6, fig_h))
    ax.axis('off')
    # Nodes stacked top-to-bottom, 1.2 units apart.
    node_y = np.arange(len(key_points) * 1.2, 0, -1.2)
    node_style = dict(boxstyle="round,pad=0.3", facecolor="lightgreen",
                      edgecolor="black", alpha=0.9)
    last = len(key_points) - 1
    for idx, (y, caption) in enumerate(zip(node_y, key_points)):
        ax.text(0.5, y, caption, ha='center', va='center', fontsize=9,
                bbox=node_style)
        # Connect each node to the next with a downward arrow.
        if idx < last:
            ax.arrow(0.5, y - 0.3, 0, -0.9, head_width=0.02, head_length=0.1,
                     fc='blue', ec='blue')
    ax.text(0.5, node_y[0] + 0.8, "Key Points Summary", ha='center',
            va='center', fontsize=12, weight='bold')
    plt.tight_layout(pad=0.1)
    out = io.BytesIO()
    fig.savefig(out, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    out.seek(0)
    return Image.open(out)
def process_chunk(i, chunk):
    """Summarize one text chunk, never raising.

    i: zero-based chunk index (reported one-based in the output).
    chunk: the text to summarize.
    Returns (markdown_result, info) where info records the 1-based chunk
    number, a status of 'summarized'/'skipped'/'error', and elapsed seconds.
    """
    info = {'chunk': i + 1, 'status': '', 'time': 0}
    started = time.time()
    # Skip blank chunks and symbol-dominated ones (likely equations):
    # more than half the characters non-alphanumeric.  The `or` short-circuit
    # also protects the division when chunk is empty.
    stripped = chunk.strip()
    skip = (not stripped) or sum(not c.isalnum() for c in chunk) / len(chunk) > 0.5
    if skip:
        text = f"**Chunk {i+1}**: Skipped (empty or equation-heavy)"
        info['status'] = 'skipped'
    else:
        try:
            piece = summarizer(chunk, max_length=60, min_length=10,
                               do_sample=False)[0]['summary_text']
            text = f"**Chunk {i+1}**:\n{piece}"
            info['status'] = 'summarized'
        except Exception as e:
            # Model failure on one chunk must not abort the whole document.
            text = f"**Chunk {i+1}**: Error: {str(e)}"
            info['status'] = 'error'
    info['time'] = time.time() - started
    return text, info
def summarize_file(file_bytes):
    """Summarize a PDF supplied as raw bytes.

    Extracts up to 30,000 ASCII characters of text, splits it into
    sentence-aligned chunks of at most 1000 characters (capped at 30 chunks),
    summarizes the chunks in parallel, and renders two diagnostic images.

    Returns (summary_markdown, chunk_status_image, flowchart_image); the
    images are None when extraction fails or yields no text.
    """
    start = time.time()
    summaries = []
    chunk_info = []
    # --- Text extraction --------------------------------------------------
    try:
        doc = fitz.open(stream=file_bytes, filetype="pdf")
        try:
            text = ""
            for page in doc:
                page_text = page.get_text("text")
                if not page_text.strip():
                    continue
                text += page_text
                if len(text) > 30000:  # hard cap to bound runtime
                    text = text[:30000]
                    break
        finally:
            # BUG FIX: close the document even when a page read raises,
            # instead of leaking the handle into the generic except below.
            doc.close()
        # Drop inline $...$ math, spell out \cap, collapse all whitespace.
        text = re.sub(
            r"\$\s*[^$]+\s*\$|\\cap|\s+",
            lambda m: "intersection" if m.group(0) == "\\cap" else " ",
            text,
        )
        # t5-small copes poorly with non-ASCII; strip it and re-apply the cap.
        text = "".join(c for c in text if ord(c) < 128)[:30000]
    except Exception as e:
        return f"Text extraction failed: {str(e)}", None, None
    if not text.strip():
        return "No text found", None, None
    # --- Sentence-aligned chunking (<=1000 chars, <=30 chunks) ------------
    chunks = []
    current_chunk = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        # +1 accounts for the separator re-inserted between joined sentences.
        if len(current_chunk) + len(sentence) + 1 <= 1000:
            # BUG FIX: re-insert the space the sentence split consumed, so
            # adjacent sentences no longer run together ("end.Start").
            current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
            if len(chunks) >= 30:
                current_chunk = ""  # at capacity: discard the remainder
                break
    # BUG FIX: respect the cap here too — the trailing append used to be able
    # to produce a 31st chunk after the loop broke at 30.
    if current_chunk and len(chunks) < 30:
        chunks.append(current_chunk)
    # --- Parallel summarization -------------------------------------------
    # Threads (not processes): the pipeline releases the GIL in native code.
    max_workers = min(8, max(2, torch.cuda.device_count() * 4 if device == 0 else 4))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(lambda ic: process_chunk(*ic), enumerate(chunks)))
    for summary, info in results:
        summaries.append(summary)
        chunk_info.append(info)
    final_summary = (
        f"**Chunks**: {len(chunks)}\n**Time**: {time.time() - start:.2f}s\n\n"
        + "\n\n".join(summaries)
    )
    process_img = visualize_chunk_status(chunk_info)
    flow_img = create_summary_flowchart(summaries)
    return final_summary, process_img, flow_img
# Gradio UI: one PDF upload in; summary markdown plus two diagnostic images out.
demo = gr.Interface(
    fn=summarize_file,
    # type="binary" hands summarize_file the raw file bytes.
    inputs=gr.File(label="Upload PDF", type="binary"),
    # Outputs mirror summarize_file's (text, PIL image, PIL image) return.
    outputs=[
        gr.Textbox(label="Summary", lines=15),
        gr.Image(label="Chunk Status", type="pil"),
        gr.Image(label="Key Points Flowchart", type="pil")
    ],
    title="PDF Summarizer",
    description="Summarizes PDFs up to 30,000 characters with key point flowchart."
)
if __name__ == "__main__":
    # Launch locally on port 7860; if that fails (e.g. port in use), retry
    # once on 7861 and re-raise if that also fails.
    launch_opts = dict(share=False, server_name="127.0.0.1", debug=False)
    try:
        logger.info("Starting Gradio on http://127.0.0.1:7860")
        demo.launch(server_port=7860, **launch_opts)
    except Exception as e:
        logger.error(f"Failed on port 7860: {str(e)}")
        logger.info("Trying port 7861...")
        try:
            demo.launch(server_port=7861, **launch_opts)
        except Exception as e2:
            logger.error(f"Failed on port 7861: {str(e2)}")
            raise