Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import torch | |
| from sentence_transformers import SentenceTransformer, util | |
| import numpy as np | |
| from typing import Dict, List, Tuple, Optional | |
| import io | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from collections import defaultdict | |
| import json | |
| import traceback | |
| import spaces # Import the spaces library | |
| import tempfile | |
| from dotenv import load_dotenv | |
| import os | |
# Populate os.environ from a local .env file first; reading HF_TOKEN before
# load_dotenv() runs would miss a token that is defined only in .env.
load_dotenv()
# Hugging Face access token (None when unset), passed to SentenceTransformer
# whenever a model is downloaded.
token_hf = os.getenv('HF_TOKEN')
class MultiClientThemeClassifier:
    """Embedding-based theme classifier that keeps a separate theme set per client.

    Every theme string is embedded once with a SentenceTransformer model and
    stored as a prototype tensor. A text is assigned to the theme whose
    prototype has the highest cosine similarity with the text's embedding, or
    to ``"UNKNOWN_THEME"`` when that best similarity is below a threshold.
    """

    def __init__(self):
        self.model = None  # SentenceTransformer instance once a load succeeds
        # client_id -> {theme string: prototype embedding tensor}
        self.client_themes = {}
        self.model_loaded = False
        self.default_model = 'google/embeddinggemma-300m'
        # Model that is loaded, or that will be loaded lazily on first use.
        self.current_model_name = self.default_model

    def _is_loaded(self, model_name: str) -> bool:
        """Return True if *model_name* is the model currently loaded."""
        return self.model_loaded and self.current_model_name == model_name

    def load_model(self, model_name: str) -> str:
        """Load *model_name* onto the GPU and return a human-readable status.

        Loading a different model clears every client's themes, because those
        prototypes were embedded by the previous model. On failure no model is
        loaded and ``current_model_name`` keeps its previous value, so the next
        lazy load retries that model. Programmatic callers should inspect
        ``model_loaded`` / ``current_model_name`` instead of parsing the text.
        The module-level ``token_hf`` is used as the Hugging Face token.

        NOTE(review): ``device='cuda'`` is hard-coded, so this fails on hosts
        without a CUDA device.
        NOTE(review): ``trust_remote_code=True`` runs code shipped with the
        model repo, and *model_name* comes from a UI textbox — confirm that is
        acceptable for this deployment.
        """
        try:
            if self._is_loaded(model_name):
                return f"✅ Model '{model_name}' is already loaded."
            # Drop the old model and its prototypes before loading the new one.
            self.model = None
            self.client_themes = {}
            self.model_loaded = False
            print(f"Loading model: {model_name} onto CUDA device")
            self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True, token=token_hf)
            self.model_loaded = True
            self.current_model_name = model_name
            return f"✅ Model '{model_name}' loaded successfully onto GPU!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def _ensure_model_is_loaded(self) -> Optional[str]:
        """Lazily load ``current_model_name`` if no model is loaded yet.

        Returns:
            None when a model is ready, otherwise the status message from
            :meth:`load_model` describing the failure.
        """
        if self.model_loaded:
            return None
        print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
        status = self.load_model(self.current_model_name)
        # Judge success by state, not by searching the status text.
        return None if self.model_loaded else status

    def _embed_themes(self, themes: List[str]) -> Dict[str, torch.Tensor]:
        """Embed each distinct theme (first-occurrence order) as a prototype tensor."""
        return {theme: self.model.encode(theme, convert_to_tensor=True) for theme in dict.fromkeys(themes)}

    def add_client_themes(self, client_id: str, themes: List[str]) -> str:
        """Replace *client_id*'s themes with *themes* and return a status message.

        Duplicate themes are embedded and counted once. The client's previous
        themes are replaced only after every theme was embedded successfully.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return error_status
        try:
            prototypes = self._embed_themes(themes)
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"
        self.client_themes[client_id] = prototypes
        return f"✅ Added {len(prototypes)} themes for client '{client_id}'"

    @staticmethod
    def _pick_theme(similarities: Dict[str, float], confidence_threshold: float) -> Tuple[str, float, Dict[str, float]]:
        """Apply the decision rule to a non-empty ``{theme: cosine similarity}`` map.

        Returns ``(best_theme, best_score, similarities)`` where ``best_theme``
        is ``"UNKNOWN_THEME"`` if ``best_score < confidence_threshold``. Ties go
        to the theme that comes first in *similarities*.
        """
        best_theme = max(similarities, key=similarities.get)
        best_score = similarities[best_theme]
        if best_score < confidence_threshold:
            return "UNKNOWN_THEME", best_score, similarities
        return best_theme, best_score, similarities

    def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify a single text against *client_id*'s themes.

        Returns ``(predicted_theme, best_score, {theme: similarity})``. Problems
        are reported in-band as the predicted theme ("Client not found",
        "No themes for client" or an "Error: ..." string) with score 0.0 and
        an empty similarity map.
        """
        error_status = self._ensure_model_is_loaded()
        if error_status:
            return f"Error: {error_status}", 0.0, {}
        prototypes = self.client_themes.get(client_id)
        if prototypes is None:
            return "Client not found", 0.0, {}
        if not prototypes:
            return "No themes for client", 0.0, {}
        try:
            text_embedding = self.model.encode(text, convert_to_tensor=True)
            similarities = {theme: util.cos_sim(text_embedding, prototype).item()
                            for theme, prototype in prototypes.items()}
            return self._pick_theme(similarities, confidence_threshold)
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    def _classify_batch(self, texts: List[str], client_id: str, confidence_threshold: float = 0.3) -> List[Tuple[str, float, Dict[str, float]]]:
        """Classify many texts with one batched encode call.

        Uses the same decision rule as :meth:`classify_text`. Assumes the model
        is loaded, *texts* is non-empty and *client_id* has at least one theme.
        """
        prototypes = self.client_themes[client_id]
        theme_names = list(prototypes)
        theme_matrix = torch.stack([prototypes[name] for name in theme_names])
        text_embeddings = self.model.encode(texts, convert_to_tensor=True)
        # Row i holds the cosine similarity of texts[i] to every theme.
        score_rows = util.cos_sim(text_embeddings, theme_matrix).tolist()
        return [self._pick_theme(dict(zip(theme_names, row)), confidence_threshold) for row in score_rows]

    @staticmethod
    def _read_csv_with_fallback(csv_filepath: str, encodings=('utf-8-sig', 'utf-8', 'cp1256', 'latin1')) -> Optional[pd.DataFrame]:
        """Read *csv_filepath*, trying each encoding in turn; None if all fail.

        Only decode/parse errors move on to the next encoding; other errors
        (for example a missing file) propagate to the caller.
        """
        for encoding in encodings:
            try:
                df = pd.read_csv(csv_filepath, encoding=encoding)
            except (UnicodeDecodeError, pd.errors.ParserError):
                continue
            print(f"Successfully read CSV with encoding: {encoding}")
            return df
        return None

    def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
        """Benchmark *model_name* on a labelled CSV file.

        The CSV needs ``text`` and ``real_tag`` columns. The distinct
        ``real_tag`` values become *client_id*'s themes (replacing any it had),
        each text truncated to 500 characters is classified with the default
        0.3 threshold, and accuracy is the share of rows whose prediction
        equals ``real_tag`` (``UNKNOWN_THEME`` predictions count as wrong).

        Returns:
            ``(summary_markdown, results_csv_path, chart_html)``. On any
            failure the last two are None and the first explains the problem.
            The results CSV is left on disk (``delete=False``) so the path can
            be handed back to the caller.
        """
        load_status = self.load_model(model_name)
        # "Already loaded" is fine; anything that did not leave this model loaded is a failure.
        if not self._is_loaded(model_name):
            return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None
        try:
            df = self._read_csv_with_fallback(csv_filepath)
            if df is None:
                return "❌ Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
            df.dropna(subset=['text', 'real_tag'], inplace=True)
            if df.empty:
                return "❌ CSV has no rows with both 'text' and 'real_tag' filled in.", None, None
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)

            self.client_themes[client_id] = self._embed_themes(df['real_tag'].unique().tolist())
            texts = df['text'].str.slice(0, 500).tolist()
            results = self._classify_batch(texts, client_id)
            df['predicted_tag'] = [res[0] for res in results]
            df['confidence'] = [res[1] for res in results]

            correct = int((df['real_tag'] == df['predicted_tag']).sum())
            total = len(df)
            accuracy = correct / total
            results_summary = (
                f"📊 **Benchmarking Results for `{self.current_model_name}`**\n\n"
                f"**Accuracy: {accuracy:.2%}** ({correct}/{total})"
            )

            # Explicit Theme/Count columns work regardless of how the pandas
            # version names the value_counts index and values.
            theme_counts = df['real_tag'].value_counts().rename_axis('Theme').reset_index(name='Count')
            fig = px.bar(theme_counts, x='Theme', y='Count', title="Theme Distribution")
            visualization_html = fig.to_html()

            # Write through the handle so the utf-8-sig BOM is actually emitted;
            # newline='' as pandas requires for text handles.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False,
                                             encoding='utf-8-sig', newline='') as tmp:
                df.to_csv(tmp, index=False)
            return results_summary, tmp.name, visualization_html
        except Exception as e:
            error_details = traceback.format_exc()
            return f"❌ Error during benchmarking: {str(e)}\n\n{error_details}", None, None
# Initialize the classifier: one module-level instance used by every Gradio
# callback below, so the loaded model and registered client themes persist
# between UI interactions (and, presumably, are shared by all sessions).
classifier = MultiClientThemeClassifier()
def load_model_interface(model_name: str):
    """Gradio callback for "Load Model": load the named model, ignoring surrounding whitespace.

    Returns the status message produced by the shared classifier.
    """
    loader = classifier.load_model
    return loader(model_name.strip())
def add_themes_interface(client_id: str, themes_text: str):
    """Gradio callback for "Add Themes": register newline-separated themes for a client.

    Each line is trimmed and blank lines are skipped. Input that is entirely
    whitespace is rejected without touching the classifier.
    """
    if themes_text.strip():
        cleaned = []
        for raw_line in themes_text.split('\n'):
            stripped = raw_line.strip()
            if stripped:
                cleaned.append(stripped)
        return classifier.add_client_themes(client_id, cleaned)
    return "β Please enter themes!"
def format_classification_result(pred_theme: str, confidence: float, similarities: Dict[str, float]) -> str:
    """Render one classification result as Markdown.

    Sections are separated by blank lines because a single newline is a soft
    break in Markdown, which would run the predicted theme and the confidence
    together on one line. The similarity list is omitted when there are no
    scores (for example an unknown client or an error).
    """
    sections = [
        f"🎯 **Predicted Theme:** {pred_theme}",
        f"🔥 **Confidence:** {confidence:.3f}",
    ]
    if similarities:
        ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        score_lines = "\n".join(f"- {theme}: {score:.3f}" for theme, score in ranked)
        sections.append(f"**Similarity Scores:**\n{score_lines}")
    return "\n\n".join(sections)


def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """Gradio callback for "Classify".

    Returns ``(markdown_result, "")``; the empty second value presumably feeds
    a hidden output component wired to the button.
    """
    if not text.strip():
        return "Please enter text to classify!", ""
    pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)
    return format_classification_result(pred_theme, confidence, similarities), ""
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """Gradio callback for "Run Benchmark".

    Args:
        csv_file_obj: Value of the file-upload component. None when nothing
            was uploaded; otherwise either a path (``str`` / ``os.PathLike``)
            or an object carrying the path in ``.name``.
        client_id: Client under which the CSV's themes are registered.
        model_name: Model to load and benchmark; surrounding whitespace is ignored.

    Returns:
        ``(summary_markdown, results_csv_path, visualization_html)``; the last
        two are None when validation fails or an error occurs.
    """
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None
    if not model_name.strip():
        return "Please enter a model name for the benchmark!", None, None
    try:
        # Paths are used as-is (Path.name would be only the basename);
        # file-like wrappers expose the temp-file path via `.name`.
        if isinstance(csv_file_obj, (str, os.PathLike)):
            csv_filepath = csv_file_obj
        else:
            csv_filepath = csv_file_obj.name
        return classifier.benchmark_csv(csv_filepath, client_id, model_name.strip())
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
# --- Gradio Interface ---
# Three-tab UI wired to the *_interface callbacks above. Components appear in
# the order they are created inside each `with` layout context.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from a copy of this file that had lost its indentation — confirm it matches
# the intended layout.
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π― Custom Themes Classification - MVP")
    with gr.Tab("π Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
        gr.Markdown("A default model (`google/embeddinggemma-300m`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
        with gr.Row():
            # Also read by the "CSV Benchmarking" tab: its button passes this
            # value to benchmark_interface as the model to benchmark.
            model_input = gr.Textbox(label="HuggingFace Model Name", value="google/embeddinggemma-300m")
            load_btn = gr.Button("Load Model", variant="primary")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)
        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)
        add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)
    with gr.Tab("π Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Text to Classify", lines=3)
                # Looked up verbatim (no trimming) among the Client IDs added in the Setup tab.
                client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
                confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
                classify_btn = gr.Button("Classify", variant="primary")
            with gr.Column():
                classification_result = gr.Markdown(label="Results")
        # classify_interface returns two values; the second (always "") goes
        # to a throwaway hidden Textbox created inline here.
        classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
    with gr.Tab("π CSV Benchmarking"):
        gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
                benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")
            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")
        with gr.Row():
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            # NOTE(review): receives plotly's fig.to_html() output from
            # benchmark_csv; confirm gr.HTML actually runs the embedded
            # <script> tags, otherwise the chart will not render.
            visualization = gr.HTML(label="Visualization")
        # The model to benchmark comes from the Setup tab's model_input textbox.
        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client, model_input],
            outputs=[benchmark_results, results_csv, visualization]
        )
# Launch the app when run as a script; share=True requests a public Gradio share link.
if __name__ == "__main__":
    demo.launch(share=True)