# app.py — Smart Summarizer (Gradio demo)
# Model: ziaulkarim245/bart-large-cnn-Text-Summarizer — revision 31eb42c
import gradio as gr
import torch
import nltk
import numpy as np
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
# --- SETUP & NLTK ---

# Keep NLTK data inside the working directory so downloads succeed even in
# environments with a read-only home directory (e.g. Hugging Face Spaces).
nltk_data_dir = os.path.join(os.getcwd(), "nltk_data")
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

# NLTK >= 3.8.2 needs 'punkt_tab' (not just 'punkt') for sent_tokenize;
# fetch both so the app works across NLTK versions. Failures are non-fatal:
# hybrid_selector falls back to the raw text if tokenization is unavailable.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.download(_resource, download_dir=nltk_data_dir, quiet=True)
    except Exception as e:
        print(f"NLTK Warning: {e}")

# Prefer GPU when available; all models are placed on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(" Loading Fine-Tuned Models...")
model_name = "ziaulkarim245/bart-large-cnn-Text-Summarizer"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Sentence encoder used by the extractive pre-selection stage.
selector_model = SentenceTransformer('all-mpnet-base-v2', device=device)
# HYBRID SELECTOR
def hybrid_selector(article_text, max_sentences=25):
    """Extractively pre-select up to ``max_sentences`` representative sentences.

    Embeds every sentence, clusters the embeddings with KMeans, and keeps the
    sentence closest to each cluster centroid — preserving the original
    document order. This shrinks long articles to fit the abstractive model's
    1024-token input window while retaining topical coverage.

    Args:
        article_text: Raw article text to reduce.
        max_sentences: Upper bound on the number of sentences kept.

    Returns:
        The reduced text, or the unmodified input when it is already short
        enough or when any step fails (best-effort fallback).
    """
    try:
        sentences = nltk.sent_tokenize(article_text)
        # Short inputs need no reduction.
        if len(sentences) <= max_sentences:
            return article_text

        embeddings = selector_model.encode(sentences)
        n_clusters = min(max_sentences, len(sentences))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
        kmeans.fit(embeddings)

        # For each cluster, keep the sentence nearest its centroid.
        selected_indices = []
        for center in kmeans.cluster_centers_:
            distances = np.linalg.norm(embeddings - center, axis=1)
            selected_indices.append(int(np.argmin(distances)))

        # Restore document order so the reduced text reads coherently.
        selected_indices.sort()
        return " ".join(sentences[i] for i in selected_indices)
    except Exception as e:
        # BUG FIX: was a bare `except:` that silently swallowed everything
        # (including KeyboardInterrupt). Narrow the catch and surface the
        # reason, but keep the best-effort fallback behavior.
        print(f"Selector Warning: {e}")
        return article_text
# SMART GENERATION
def smart_summary(text, mode):
    """Generate an abstractive summary, tuned by the selected UI mode.

    Runs the extractive pre-selector first, then generates with BART using
    mode-specific length/beam settings.

    Args:
        text: Article text pasted by the user.
        mode: One of the Radio labels from the UI (Quick Scan /
            Professional Brief / Deep Dive).

    Returns:
        The decoded summary string.
    """
    refined_text = hybrid_selector(text)
    inputs = tokenizer(refined_text, return_tensors="pt", max_length=1024, truncation=True).to(device)
    input_len = inputs["input_ids"].shape[1]

    # Per-mode generation presets:
    # (max_length, min_length, num_beams, length_penalty)
    presets = {
        "⚑ Quick Scan": (150, 40, 2, 2.5),
        "πŸ“ Professional Brief": (256, 80, 4, 2.0),
        "🧐 Deep Dive": (500, 200, 6, 1.0),
    }
    # BUG FIX: the original if/elif chain left the settings unbound (NameError)
    # for any unrecognized mode; fall back to the Professional Brief preset.
    target_max, target_min, beams, penalty = presets.get(
        mode, presets["πŸ“ Professional Brief"]
    )

    # Never request a summary longer than the input itself; for very short
    # inputs, aim for roughly half the input length.
    if input_len < target_min:
        final_min = int(input_len * 0.5)
        final_max = input_len
    else:
        final_min = target_min
        final_max = target_max

    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=final_max,
        min_length=final_min,
        num_beams=beams,
        length_penalty=penalty,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# --- INTERFACE ---
# Two-column layout: input + controls on the left, summary output on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Smart Summarizer")
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(lines=12, label="Input Article", placeholder="Paste text here...")
            # Radio labels must match the preset keys in smart_summary exactly.
            mode_select = gr.Radio(
                ["⚑ Quick Scan", "πŸ“ Professional Brief", "🧐 Deep Dive"],
                label="Summary Style",
                value="πŸ“ Professional Brief"
            )
            btn = gr.Button("✨ Summarize", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(lines=15, label="AI Summary")
    # Wire the button to the generation pipeline.
    btn.click(smart_summary, inputs=[input_text, mode_select], outputs=output_text)

demo.launch()