"""Gradio app for classifying news text into AG News categories.

The app exposes two workflows in a tabbed interface:

* "Live Check": classify a single piece of text and show per-class scores.
* "Bulk Analysis (CSV)": classify every row of an uploaded CSV and return a
  copy with the predicted category and its confidence appended.
"""
import os
import tempfile

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Public model path
model_hub_path = "AJC1/ag_news_distilbert_finetuned"
# Class names; position i labels the i-th probability produced by the model.
target_names = ["World", "Sports", "Business", "Sci/Tech"]

# Column names probed, in priority order, to locate the text column of a CSV.
_TEXT_COLUMN_CANDIDATES = ("text", "headline", "title", "content")

# Load Model
# Both names are defined up front so that a failed load leaves them as None
# (checked in predict_single_text) instead of undefined, which would only
# surface later as an opaque NameError.
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(model_hub_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_hub_path)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")


# Prediction Functions
def predict_single_text(text):
    """Classify one piece of text.

    Args:
        text: The text to classify.

    Returns:
        A dict mapping each entry of ``target_names`` to its softmax
        probability, or an empty dict when ``text`` is empty, ``None`` or
        whitespace-only.

    Raises:
        RuntimeError: If the model or tokenizer could not be loaded at startup.
    """
    if not text or not text.strip():
        return {}
    if tokenizer is None or model is None:
        raise RuntimeError(
            "Model is not available; check the startup log for the loading error."
        )

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs).logits
    # A single text is passed in, so row 0 holds its class probabilities.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {target_names[i]: v for i, v in enumerate(probs)}


def process_csv_file(file_obj):
    """Classify every row of an uploaded CSV file.

    The text column is the first of 'text', 'headline', 'title', 'content'
    found among the CSV's columns, falling back to the first column. Two
    columns are appended: 'Predicted_Category' (top-scoring label) and
    'Confidence' (its probability formatted with two decimals). Rows whose
    text is missing or blank get empty strings in both columns.

    Args:
        file_obj: The uploaded file as provided by the ``gr.File`` input;
            either an object exposing the path as ``.name`` or a path string.
            ``None`` when nothing was uploaded.

    Returns:
        A ``(output_path, status_message)`` tuple. ``output_path`` is the path
        of the classified CSV, or ``None`` when processing failed, in which
        case ``status_message`` describes the error.
    """
    if file_obj is None:
        return None, "Error: No file uploaded."

    try:
        # Load the uploaded CSV
        csv_path = getattr(file_obj, "name", file_obj)
        df = pd.read_csv(csv_path)

        # Validation: Check if it has text
        if df.empty:
            return None, "Error: Uploaded file is empty."

        # Smart column detection: use a known column name, else the first column
        target_col = next(
            (col for col in _TEXT_COLUMN_CANDIDATES if col in df.columns),
            df.columns[0],
        )

        # Run predictions
        predicted_labels = []
        confidence_scores = []
        for value in df[target_col]:
            # Missing cells must not be classified as the literal text "nan".
            text = "" if pd.isna(value) else str(value)
            scores = predict_single_text(text)
            if not scores:
                # Nothing to classify; keep the row but leave the results blank.
                predicted_labels.append("")
                confidence_scores.append("")
                continue
            # Get the top label
            top_label = max(scores, key=scores.get)
            predicted_labels.append(top_label)
            confidence_scores.append(f"{scores[top_label]:.2f}")

        # Add results to dataframe
        df['Predicted_Category'] = predicted_labels
        df['Confidence'] = confidence_scores

        # Save to a unique temporary file so that concurrent requests do not
        # overwrite each other's results.
        fd, output_path = tempfile.mkstemp(prefix="classified_results_", suffix=".csv")
        os.close(fd)  # pandas reopens the file by path
        df.to_csv(output_path, index=False)

        return output_path, f"Success! Processed {len(df)} rows. Download your results below."
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        return None, f"Error processing file: {e}"


# The Professional Tabbed Interface
with gr.Blocks(title="AG News Enterprise Classifier") as demo:
    gr.Markdown("# Automated News Routing System")
    gr.Markdown("Select a workflow below: Single-item checking or Bulk file processing.")

    with gr.Tabs():
        # TAB 1: Single Input
        with gr.TabItem("Live Check"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        lines=4,
                        label="Input News Headline",
                        placeholder="Paste text here...",
                    )
                    submit_btn = gr.Button("Classify Content", variant="primary")
                with gr.Column():
                    label_output = gr.Label(num_top_classes=4, label="Category Prediction")

            # Link functionality
            submit_btn.click(fn=predict_single_text, inputs=text_input, outputs=label_output)

            # Examples
            gr.Examples(
                examples=[
                    ["Wall Street tumbles as tech stocks sell off."],
                    ["Manchester United signs new striker for record fee."],
                    ["NASA discovers water on Mars surface."],
                ],
                inputs=text_input,
            )

        # TAB 2: Batch Processing
        with gr.TabItem("Bulk Analysis (CSV)"):
            gr.Markdown(
                "Upload a CSV file containing news headlines. The system will append "
                "'Predicted_Category' and 'Confidence' columns and return the file."
            )
            with gr.Row():
                file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
                file_output = gr.File(label="Download Classified Results")

            status_text = gr.Textbox(label="Status", interactive=False)
            process_btn = gr.Button("Process Batch", variant="primary")

            # Link functionality
            process_btn.click(
                fn=process_csv_file,
                inputs=file_input,
                outputs=[file_output, status_text],
            )

if __name__ == "__main__":
    demo.launch()