Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import torch | |
| import pandas as pd | |
| import os | |
# Hub identifier of the public fine-tuned checkpoint.
model_hub_path = "AJC1/ag_news_distilbert_finetuned"

# Display names for the classifier outputs, indexed by class position.
target_names = ["World", "Sports", "Business", "Sci/Tech"]

# Load the tokenizer and model once at import time. A failure is only
# reported on stdout; `tokenizer` and `model` are then left undefined, so
# later prediction calls will fail.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_hub_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_hub_path)
    model.eval()  # switch to inference mode
    print("Model loaded successfully.")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
# Prediction Functions
def predict_single_text(text):
    """Classify one piece of text and return a score per category.

    Args:
        text: The text to classify. Any falsy value (``""``, ``None``)
            short-circuits to an empty result.

    Returns:
        A dict mapping each name in ``target_names`` to its softmax
        probability, or ``{}`` when ``text`` is falsy.
    """
    if not text:
        return {}

    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Softmax over the class dimension; [0] selects the only row in the batch.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    scores = {}
    for class_index, probability in enumerate(probabilities):
        scores[target_names[class_index]] = probability
    return scores
def process_csv_file(file_obj):
    """Classify every row of an uploaded CSV and write the results to a new CSV.

    The text column is the first of ``text``, ``headline``, ``title`` or
    ``content`` found in the file; if none is present the first column is
    used. Two columns are added to the data: ``Predicted_Category`` (the
    highest-scoring label) and ``Confidence`` (its score formatted to two
    decimals). Rows whose text is missing or blank are left unclassified
    (both added cells empty) instead of being sent to the model.

    Args:
        file_obj: The uploaded file, either as a filesystem path
            (``str`` / ``os.PathLike``) or as an object exposing the path
            through a ``name`` attribute. ``None`` means nothing was uploaded.

    Returns:
        A ``(output_path, status_message)`` tuple. ``output_path`` is the
        path of the written CSV, or ``None`` when processing failed.
    """
    # Local import keeps this function self-contained; the module does not
    # otherwise need tempfile.
    import tempfile

    if file_obj is None:
        return None, "Error: No file uploaded."

    try:
        # Accept both a plain path and a file wrapper carrying `.name`.
        # (Which one the UI passes presumably depends on the Gradio version.)
        if isinstance(file_obj, (str, os.PathLike)):
            csv_path = file_obj
        else:
            csv_path = file_obj.name

        df = pd.read_csv(csv_path)

        if df.empty:
            return None, "Error: Uploaded file is empty."

        # Column detection: preferred names in priority order, otherwise the
        # first column of the file.
        target_col = next(
            (
                col
                for col in ("text", "headline", "title", "content")
                if col in df.columns
            ),
            df.columns[0],
        )

        predicted_labels = []
        confidence_scores = []
        skipped_rows = 0
        for value in df[target_col]:
            # Missing cells arrive as NaN; str(NaN) would be the literal
            # text "nan", so treat them as empty instead.
            text = "" if pd.isna(value) else str(value)
            scores = predict_single_text(text) if text.strip() else {}

            if not scores:
                # Nothing to classify: keep the row, leave the results blank.
                skipped_rows += 1
                predicted_labels.append("")
                confidence_scores.append("")
                continue

            top_label = max(scores, key=scores.get)
            predicted_labels.append(top_label)
            confidence_scores.append(f"{scores[top_label]:.2f}")

        df['Predicted_Category'] = predicted_labels
        df['Confidence'] = confidence_scores

        # Write into a fresh directory per call so simultaneous requests do
        # not overwrite each other's results; the file name stays the same.
        output_dir = tempfile.mkdtemp(prefix="classified_")
        output_path = os.path.join(output_dir, "classified_results.csv")
        df.to_csv(output_path, index=False)

        message = f"Success! Processed {len(df)} rows. Download your results below."
        if skipped_rows:
            message += (
                f" Note: {skipped_rows} row(s) had no text and were left unclassified."
            )
        return output_path, message
    except Exception as e:
        # UI boundary: report any failure in the status box instead of raising.
        return None, f"Error processing file: {e}"
# The Professional Tabbed Interface
# Two workflows share one page: a live single-text check and a bulk CSV run.
with gr.Blocks(title="AG News Enterprise Classifier") as demo:
    # A space after '#' is required for Markdown to render this as a heading.
    gr.Markdown("# Automated News Routing System")
    gr.Markdown("Select a workflow below: Single-item checking or Bulk file processing.")

    with gr.Tabs():
        # TAB 1: Single Input
        with gr.TabItem("Live Check"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(lines=4, label="Input News Headline", placeholder="Paste text here...")
                    submit_btn = gr.Button("Classify Content", variant="primary")
                with gr.Column():
                    label_output = gr.Label(num_top_classes=4, label="Category Prediction")

            # Link functionality: button click -> per-category scores
            submit_btn.click(fn=predict_single_text, inputs=text_input, outputs=label_output)

            # Examples
            gr.Examples(
                examples=[
                    ["Wall Street tumbles as tech stocks sell off."],
                    ["Manchester United signs new striker for record fee."],
                    ["NASA discovers water on Mars surface."],
                ],
                inputs=text_input,
            )

        # TAB 2: Batch Processing
        with gr.TabItem("Bulk Analysis (CSV)"):
            # The description names the columns that process_csv_file adds.
            gr.Markdown("Upload a CSV file containing news headlines. The system will append 'Predicted_Category' and 'Confidence' columns and return the file.")
            with gr.Row():
                file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
                file_output = gr.File(label="Download Classified Results")
            status_text = gr.Textbox(label="Status", interactive=False)
            process_btn = gr.Button("Process Batch", variant="primary")

            # Link functionality: uploaded file -> (result file, status message)
            process_btn.click(
                fn=process_csv_file,
                inputs=file_input,
                outputs=[file_output, status_text],
            )

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()