# app.py — Hugging Face Space (commit 3fe399d)
# This app lets users upload PDF files.
# It compares the speed and output of two versions of a machine learning model (original vs. optimized/quantized) when classifying the content of the PDFs.
# Standard library
import os
import time

# Third-party
import fitz  # PyMuPDF for PDF text extraction
import gradio as gr
import pandas as pd
import torch
from torch.quantization import quantize_dynamic
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Load tokenizer and models
# Shared WordPiece tokenizer used for both model variants.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Baseline fp32 model. NOTE(review): this is the untuned base checkpoint, so
# the sequence-classification head is not task-trained — the app compares
# speed and output agreement, not label quality. Confirm this is intended.
original_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Dynamic quantization: every nn.Linear layer gets int8 weights (activations
# quantized on the fly at inference) — smaller model, faster CPU inference.
quantized_model = quantize_dynamic(original_model, {torch.nn.Linear}, dtype=torch.qint8)
def extract_text_from_pdf(pdf_file):
    """Extract and return the concatenated text of every page in a PDF.

    Parameters
    ----------
    pdf_file : gradio file object
        Upload wrapper whose ``.name`` attribute is the temp-file path.

    Returns
    -------
    str
        All page text joined together, stripped of surrounding whitespace.
        Empty string for image-only/scanned PDFs with no text layer.
    """
    # Context manager closes the document handle even if a page raises —
    # the previous version never closed it (resource leak).
    with fitz.open(pdf_file.name) as doc:
        text = "".join(page.get_text() for page in doc)
    return text.strip()
def classify_text(text, model):
    """Classify a single text sequence and return the argmax label id.

    The module-level ``tokenizer`` truncates/pads the input to at most
    512 tokens; the forward pass runs with gradients disabled.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.argmax(logits, dim=1).item()
def analyze_pdfs(pdf_files):
    """Classify each uploaded PDF with both models and time each run.

    Parameters
    ----------
    pdf_files : list of gradio file objects or None
        Files selected in the UI; None/empty when nothing was uploaded.

    Returns
    -------
    tuple
        ``(message, table, csv_path)`` — one value per Gradio output
        component. ``table`` (pandas.DataFrame) and ``csv_path`` are None
        when there is nothing to report.
    """
    if not pdf_files:
        return "⚠️ Please upload at least one PDF file.", None, None

    results = []
    skipped_files = []
    for pdf in pdf_files:
        # Report just the file name, not the full Gradio temp-dir path.
        filename = os.path.basename(pdf.name)
        text = extract_text_from_pdf(pdf)
        # extract_text_from_pdf already strips, so an empty string means the
        # PDF has no extractable text layer (e.g. a pure scan).
        if not text:
            skipped_files.append(filename)
            continue

        # Time the original (fp32) model.
        start1 = time.time()
        pred_orig = classify_text(text, original_model)
        time1 = round(time.time() - start1, 3)

        # Time the dynamically-quantized (int8 Linear) model.
        start2 = time.time()
        pred_quant = classify_text(text, quantized_model)
        time2 = round(time.time() - start2, 3)

        results.append({
            "File": filename,
            "Original_Model_Prediction": pred_orig,
            "Original_Time(s)": time1,
            "Quantized_Model_Prediction": pred_quant,
            "Quantized_Time(s)": time2,
        })

    if not results:
        return "⚠️ No valid text found in any PDF.", None, None

    df = pd.DataFrame(results)
    csv_path = "model_comparison.csv"
    df.to_csv(csv_path, index=False)

    # "✅" repairs the mojibake "βœ…" in the original (mis-encoded UTF-8).
    message = f"✅ Processed {len(results)} file(s)."
    if skipped_files:
        message += f" Skipped (no text): {', '.join(skipped_files)}"
    return message, df, csv_path
# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("## πŸ“Š Compare Original vs Quantized DistilBERT Model on PDFs")

    pdf_input = gr.File(
        label="Upload PDF(s)",
        file_types=[".pdf"],
        file_count="multiple",
    )
    run_button = gr.Button("Run Analysis")
    status = gr.Textbox(label="Status", interactive=False)
    output_table = gr.Dataframe(label="Results")
    download_link = gr.File(label="Download CSV")

    # analyze_pdfs returns (message, dataframe, csv_path) — one value per
    # output component, in order.
    run_button.click(
        fn=analyze_pdfs,
        inputs=[pdf_input],
        outputs=[status, output_table, download_link],
    )

demo.launch()
# Below: earlier version that accepted PDF uploads but returned only the table and CSV (no status message).
#import gradio as gr
#import torch
#import fitz # PyMuPDF
#import pandas as pd
#import time
#import os
#
#from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
#from torch.quantization import quantize_dynamic
#
## Load tokenizer and models
#tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#original_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
#
## Quantize model to reduce size and improve speed
#quantized_model = quantize_dynamic(original_model, {torch.nn.Linear}, dtype=torch.qint8)
#
## PDF text extraction
#def extract_text_from_pdf(pdf_path):
# doc = fitz.open(pdf_path)
# text = ""
# for page in doc:
# text += page.get_text()
# return text.strip()
#
## Classify text
#def classify_text(text, model):
# inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
# with torch.no_grad():
# outputs = model(**inputs)
# prediction = torch.argmax(outputs.logits, dim=1).item()
# return prediction # 0 or 1
#
## Analyze uploaded PDFs
#def analyze_pdfs(pdf_files):
# results = []
#
# for pdf in pdf_files:
# filename = os.path.basename(pdf.name)
# text = extract_text_from_pdf(pdf.name)
# if not text:
# continue
#
# # Original model
# start1 = time.time()
# pred_orig = classify_text(text, original_model)
# time1 = round(time.time() - start1, 3)
#
# # Quantized model
# start2 = time.time()
# pred_quant = classify_text(text, quantized_model)
# time2 = round(time.time() - start2, 3)
#
# results.append({
# "File": filename,
# "Original_Model_Prediction": pred_orig,
# "Original_Time(s)": time1,
# "Quantized_Model_Prediction": pred_quant,
# "Quantized_Time(s)": time2
# })
#
# df = pd.DataFrame(results)
# csv_path = "model_comparison.csv" # Saved in current working directory
# df.to_csv(csv_path, index=False)
# return df, csv_path
#
## Gradio UI
#with gr.Blocks() as demo:
# gr.Markdown("## πŸ“Š Compare Original vs Quantized BERT Model on PDFs")
# with gr.Row():
# pdf_input = gr.File(label="Upload PDF(s)", file_types=[".pdf"], file_count="multiple")
# run_button = gr.Button("Run Analysis")
# output_table = gr.Dataframe(label="Results")
# download_link = gr.File(label="Download CSV")
#
# run_button.click(fn=analyze_pdfs, inputs=[pdf_input], outputs=[output_table, download_link])
#
#demo.launch()