# Source page metadata (captured from the Hugging Face file viewer):
#   author: AJC1 · last commit 67e85a7 "Update app.py" (verified) · size 4.55 kB
import os
import tempfile

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Public model path on the Hugging Face Hub.
model_hub_path = "AJC1/ag_news_distilbert_finetuned"
# Display names for the classifier outputs: index i names logit i.
# Presumably this is the AG News label order used during fine-tuning — confirm if the model changes.
target_names = ["World", "Sports", "Business", "Sci/Tech"]

# Load the model once at import time. This is deliberately best-effort: on
# failure the error is printed and the app still starts.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_hub_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_hub_path)
    # Evaluation mode (e.g. disables dropout) for deterministic inference.
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Keep both names bound (and consistent, even if only the tokenizer had
    # loaded) so later code never hits a NameError on a half-initialised state.
    tokenizer = None
    model = None
# Prediction Functions
def predict_single_text(text):
    """Classify one piece of text into the categories in ``target_names``.

    Args:
        text: The headline/article text. ``None``, ``""`` and
            whitespace-only strings are treated as "nothing to classify".

    Returns:
        A dict mapping each category name to its softmax probability (the
        format fed to ``gr.Label`` by the UI), or ``{}`` for blank input.
    """
    # Blank input would otherwise be classified from special tokens alone,
    # yielding a confident-looking but meaningless prediction.
    if not text or not str(text).strip():
        return {}
    # A single string is a batch of one, so no padding is needed.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the label dimension, then take the only row in the batch.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {target_names[i]: p for i, p in enumerate(probs)}
def _find_text_column(df):
    """Return the column to classify: a known text-like name, else the first column."""
    candidates = ("text", "headline", "title", "content")
    # Exact matches first, in the original priority order.
    for name in candidates:
        if name in df.columns:
            return name
    # Case/whitespace-insensitive fallback, e.g. "Headline" or " Text ".
    normalized = {}
    for col in df.columns:
        normalized.setdefault(str(col).strip().lower(), col)
    for name in candidates:
        if name in normalized:
            return normalized[name]
    return df.columns[0]


def process_csv_file(file_obj):
    """Classify every row of an uploaded CSV and return a downloadable copy.

    Args:
        file_obj: The value Gradio passes for the uploaded file. Depending on
            the Gradio version / ``gr.File`` settings this may be a path string
            or a wrapper exposing the path as ``.name``; both are accepted.
            ``None`` means nothing was uploaded.

    Returns:
        ``(output_path, status_message)`` on success, where the CSV at
        ``output_path`` has ``Predicted_Category`` and ``Confidence`` columns
        appended; ``(None, error_message)`` on any failure.
    """
    if file_obj is None:
        return None, "Error: No file uploaded. Please upload a CSV file first."
    try:
        df = pd.read_csv(getattr(file_obj, "name", file_obj))
        if df.empty:
            return None, "Error: Uploaded file is empty."
        target_col = _find_text_column(df)

        predicted_labels = []
        confidence_scores = []
        for value in df[target_col]:
            # Missing cells arrive as NaN; classifying them as the text "nan"
            # would be garbage, so blank rows get blank predictions instead.
            text = "" if pd.isna(value) else str(value)
            scores = predict_single_text(text) if text.strip() else {}
            if not scores:
                predicted_labels.append("")
                confidence_scores.append("")
                continue
            top_label = max(scores, key=scores.get)
            predicted_labels.append(top_label)
            confidence_scores.append(f"{scores[top_label]:.2f}")

        df['Predicted_Category'] = predicted_labels
        df['Confidence'] = confidence_scores

        # A fresh directory per request keeps concurrent users from
        # overwriting each other's results while keeping the download name.
        output_dir = tempfile.mkdtemp(prefix="classified_")
        output_path = os.path.join(output_dir, "classified_results.csv")
        df.to_csv(output_path, index=False)
        return output_path, f"Success! Processed {len(df)} rows. Download your results below."
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return None, f"Error processing file: {str(e)}"
# The Professional Tabbed Interface
with gr.Blocks(title="AG News Enterprise Classifier") as demo:
    # Markdown needs a space after "#" to render a heading.
    gr.Markdown("# Automated News Routing System")
    gr.Markdown("Select a workflow below: Single-item checking or Bulk file processing.")
    with gr.Tabs():
        # TAB 1: Single Input -- classify one headline interactively.
        with gr.TabItem("Live Check"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(lines=4, label="Input News Headline", placeholder="Paste text here...")
                    submit_btn = gr.Button("Classify Content", variant="primary")
                with gr.Column():
                    # Show every category the model can predict.
                    label_output = gr.Label(num_top_classes=len(target_names), label="Category Prediction")
            # Link functionality
            submit_btn.click(fn=predict_single_text, inputs=text_input, outputs=label_output)
            # Examples
            gr.Examples(
                examples=[
                    ["Wall Street tumbles as tech stocks sell off."],
                    ["Manchester United signs new striker for record fee."],
                    ["NASA discovers water on Mars surface."],
                ],
                inputs=text_input,
            )
        # TAB 2: Batch Processing -- classify every row of an uploaded CSV.
        with gr.TabItem("Bulk Analysis (CSV)"):
            # Describe the columns process_csv_file actually writes.
            gr.Markdown(
                "Upload a CSV file containing news headlines (read from a 'text', "
                "'headline', 'title', or 'content' column, otherwise the first column). "
                "The system will append 'Predicted_Category' and 'Confidence' columns "
                "and return the file."
            )
            with gr.Row():
                file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
                file_output = gr.File(label="Download Classified Results")
            status_text = gr.Textbox(label="Status", interactive=False)
            process_btn = gr.Button("Process Batch", variant="primary")
            # Link functionality
            process_btn.click(
                fn=process_csv_file,
                inputs=file_input,
                outputs=[file_output, status_text],
            )

if __name__ == "__main__":
    demo.launch()