# Gradio Space app: demo UI for a private Hugging Face text-classification model.
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import os
# --- 1. Load Model from Hugging Face Hub ---
# The token is read from the Space's secrets; it is required because the
# model repository is private.
HF_TOKEN = os.getenv("HF_TOKEN")

# Fail fast with an actionable message rather than a cryptic 401 later.
if HF_TOKEN is None:
    raise ValueError(
        "Hugging Face token not found. Please set the HF_TOKEN secret in your Space settings."
    )

# The ID of the private model on the Hub.
MODEL_ID = "breadlicker45/bilingual-large-gender-v4-test"

print(f"Loading model: {MODEL_ID}...")
try:
    # Load the tokenizer and model explicitly so the auth token and
    # trust_remote_code can be passed through.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        trust_remote_code=True,  # IMPORTANT for models that ship custom code
    )
    # Build the inference pipeline from the pre-loaded objects.
    classifier = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
    )
    print("Model loaded successfully!")
except Exception as e:
    # Log a clear message, then re-raise the ORIGINAL exception with its
    # traceback intact (bare `raise`, not `raise e`) so the Space crashes
    # with useful diagnostics in the logs.
    print(f"Error loading model: {e}")
    raise
# --- 2. Define the Prediction Function ---
def classify_gender(text: str, top_k: int = 3) -> "dict | None":
    """Classify *text* and return scores for Gradio's Label component.

    Args:
        text: Input sentence (English or Spanish).
        top_k: Number of labels to request scores for. Defaults to 3 so
            every class of the model is reported.

    Returns:
        A ``{label: confidence}`` dict for ``gr.Label``, or ``None`` for
        empty/whitespace-only input (``None`` clears the Label display).
    """
    if not text or not text.strip():
        # Handle empty or whitespace-only input gracefully.
        return None
    # Run the text through the model; top_k controls how many labels
    # come back with scores.
    predictions = classifier(text, top_k=top_k)
    # Format the predictions into the {label: confidence} mapping that
    # the gr.Label component expects.
    return {p["label"]: p["score"] for p in predictions}
# --- 3. Create the Gradio Interface ---
# Markdown rendered above the input/output widgets.
DESCRIPTION = """
## Bilingual Gender Classifier
This is a demo for the private model `breadlicker45/bilingual-large-gender-v4-test`.
Enter a sentence in **English or Spanish**, and the model will predict whether the text has a male, female, or neutral connotation.
**Disclaimer:** This model, like any AI, can have biases and may not always be accurate. It is intended for demonstration purposes.
"""
# Footer HTML rendered below the interface (model attribution links).
ARTICLE = """
<div style='text-align: center;'>
<p>Model based on <a href='https://huggingface.co/xlm-roberta-large' target='_blank'>XLM-RoBERTa-Large</a>, fine-tuned for gender classification.</p>
<p>This is a private model, but you can find more public models on the <a href='https://huggingface.co/models' target='_blank'>Hugging Face Hub</a>.</p>
</div>
"""
# Define some examples for users to try
# (a mix of English and Spanish sentences; each inner list is one input row).
examples = [
["He went to the store to buy a new hammer."],
["La doctora le recetó un medicamento a su paciente."],
["The development team will present their findings tomorrow."],
["My sister is the best programmer I know."],
["El futbolista marcó el gol decisivo."],
["The flight crew is preparing for takeoff."]
]
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: text entry plus the submit button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                lines=5,
                label="Input Text",
                placeholder="Enter a sentence in English or Spanish here...",
            )
            submit_btn = gr.Button("Classify Text", variant="primary")

        # Right column: the three-class confidence readout.
        with gr.Column(scale=1):
            output_label = gr.Label(
                num_top_classes=3,
                label="Classification Results",
            )

    # Clickable example sentences; cache_examples pre-computes their results.
    gr.Examples(
        examples=examples,
        inputs=text_input,
        outputs=output_label,
        fn=classify_gender,
        cache_examples=True,
    )

    gr.Markdown(ARTICLE)

    # Wire the button to the classifier; api_name exposes the endpoint
    # for programmatic access.
    submit_btn.click(
        fn=classify_gender,
        inputs=text_input,
        outputs=output_label,
        api_name="classify",
    )

# --- 4. Launch the App ---
if __name__ == "__main__":
    demo.launch()