Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
# Display name (as offered in the UI) -> Hugging Face Hub model id.
# Insertion order matters: list(models.keys()) feeds the model dropdown.
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender",
}

# Raw classifier label -> user-friendly label. Each class index is registered
# under both its "LABEL_<n>" and bare "<n>" spelling, in that order.
label_map = {
    raw_label: friendly_label
    for class_index, friendly_label in enumerate(("Male (0)", "Female (1)"))
    for raw_label in (f"LABEL_{class_index}", str(class_index))
}

# Loaded models/pipelines keyed by model id, so that subsequent requests for
# the same model do not pay the loading cost again.
model_cache = {}

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
def _predict_with_gte(model_id, text):
    """Score ``text`` with the GTE checkpoint, loading it manually.

    The model and tokenizer are loaded once and cached together in
    ``model_cache`` under ``model_id``.

    Returns:
        A dict with the friendly labels for class 0 and class 1 mapped to
        their softmax probabilities.
    """
    if model_id not in model_cache:
        print(f"Loading GTE model and tokenizer manually: {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE: trust_remote_code=True allows modeling code shipped in the
        # model repository to be executed; keep it only for trusted repos.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id, trust_remote_code=True
        ).to(device)
        model_cache[model_id] = (model, tokenizer)  # Cache both
    model, tokenizer = model_cache[model_id]

    # Tokenize the input text and move the tensors to the model's device.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension; [0] selects the single input row.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Same {label: score} shape as the pipeline-based path produces.
    return {
        label_map["LABEL_0"]: probabilities[0].item(),
        label_map["LABEL_1"]: probabilities[1].item(),
    }


def _predict_with_pipeline(model_id, text):
    """Score ``text`` with a cached ``text-classification`` pipeline.

    The pipeline is created on first use and stored in ``model_cache`` under
    ``model_id``.

    Returns:
        A dict mapping friendly labels to scores. Raw labels without an entry
        in ``label_map`` are kept as-is. Empty when the pipeline returns
        nothing usable.
    """
    if model_id not in model_cache:
        print(f"Loading pipeline for model: {model_id}...")
        model_cache[model_id] = pipeline(
            "text-classification",
            model=model_id,
            top_k=None,
            device=device,  # Use the determined device
        )
    classifier = model_cache[model_id]
    predictions = classifier(text)

    if not isinstance(predictions, list) or not predictions:
        return {}

    # For a single input string the pipeline may answer with a nested list
    # ([[{...}, ...]]) or a flat one ([{...}, ...]); accept both so a flat
    # answer is not mistaken for a failure.
    if isinstance(predictions[0], list):
        predictions = predictions[0]

    return {
        label_map.get(pred["label"], pred["label"]): pred["score"]
        for pred in predictions
    }


def classify_text(model_name, text):
    """Classify ``text`` with the model selected by ``model_name``.

    Args:
        model_name: A key of ``models`` (the display name of the model).
        text: The text to classify.

    Returns:
        A dict mapping friendly labels to scores. This function does not let
        ``Exception`` subclasses propagate: on any such failure (including an
        unknown ``model_name``) it prints the error and returns
        ``{"Error": "Failed to process: <exception message>"}``.
    """
    try:
        model_id = models[model_name]
        # The GTE checkpoint is handled manually instead of via pipeline().
        if "gte-gender" in model_id:
            return _predict_with_gte(model_id, text)
        return _predict_with_pipeline(model_id, text)
    except Exception as e:
        print(f"Error: {e}")
        # NOTE(review): the value below is a message string, not a score;
        # confirm the output component displays it as intended.
        return {"Error": f"Failed to process: {e}"}
# Input components: which model to use, and the text to classify.
model_dropdown = gr.Dropdown(
    list(models),
    label="Select Model",
    value="gte base (gender v3.1)",  # Default model
)
text_box = gr.Textbox(
    lines=2,
    placeholder="Enter text to classify for perceived gender...",
    value="This is an example sentence.",
)

# classify_text returns a {label: score} dictionary, which gr.Label displays.
results_label = gr.Label(num_top_classes=2, label="Classification Results")

# Wire the components and the classification function into one interface.
interface = gr.Interface(
    fn=classify_text,
    inputs=[model_dropdown, text_box],
    outputs=results_label,
    title="ModernBERT & GTE Gender Classifier",
    description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.",
    allow_flagging="never",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()