|
|
import gradio as gr |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
# Access token for the Hugging Face Hub, read from the process environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Abort at import time when the variable is absent, so the failure carries
# an actionable message instead of surfacing later during the model download.
if HF_TOKEN is None:
    _missing_token_message = "Hugging Face token not found. Please set the HF_TOKEN secret in your Space settings."
    raise ValueError(_missing_token_message)
|
|
|
|
|
# Hub repository id of the checkpoint served by this app.
MODEL_ID = "breadlicker45/multilingual-bert-gender-classification"

# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")
|
|
|
|
|
# Load the tokenizer and the classification model once, at import time, so
# that every request reuses the same objects.
print(f"Loading model: {MODEL_ID}...")

try:
    # NOTE(review): trust_remote_code=True allows Python code shipped in the
    # model repository to execute in this process, which also holds HF_TOKEN.
    # Confirm the checkpoint really needs custom code; if it loads with the
    # stock classes, drop the flag.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        trust_remote_code=True,
    )

    # Move the weights to the device selected above (nn.Module.to works in place).
    model.to(device)
    print("Model loaded successfully!")
except Exception as e:
    # Top-level boundary: report the failure, then let it propagate so the
    # app does not start without a model. A bare ``raise`` re-raises the
    # active exception with its original traceback.
    print(f"Error loading model: {e}")
    raise
|
|
|
|
|
|
|
|
|
|
|
def classify_gender(text: str) -> dict:
    """Classify ``text`` and return a probability for every model label.

    Args:
        text: The text to classify. Input is truncated to at most 512 tokens.

    Returns:
        A dict mapping each label name taken from ``model.config.id2label``
        to its softmax probability, or ``None`` when ``text`` is empty or
        contains only whitespace.
    """
    # Nothing to classify: return None instead of running the model.
    if not text or not text.strip():
        return None

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # The input tensors must live on the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only, so gradient tracking is unnecessary.
    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Select the single batch row explicitly rather than calling squeeze():
    # an unqualified squeeze() would also remove the label dimension when the
    # model has exactly one label, making tolist() return a bare float.
    # Assumes logits are (1, num_labels) for a single input string.
    scores = probabilities[0].tolist()

    return {model.config.id2label[index]: score for index, score in enumerate(scores)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Markdown shown as the page header. Paragraphs are separated by blank lines
# so that Markdown renders them as distinct paragraphs.
DESCRIPTION = """
## Bilingual Gender Classifier

This is a demo for the model `breadlicker45/multilingual-bert-gender-classification`.

Enter a sentence and the model will classify it as male, female, or non-binary.

**Disclaimer:** This model, like any AI, can have biases and may not always be accurate.
"""

# Markdown shown near the bottom of the page; intentionally blank for now.
ARTICLE = """
"""
|
|
|
|
|
# Sample inputs offered in the UI. Each example is wrapped in its own list.
_EXAMPLE_SENTENCES = ("this is a test.",)
examples = [[sentence] for sentence in _EXAMPLE_SENTENCES]
|
|
|
|
|
# Page layout and event wiring.
# NOTE(review): the nesting below is reconstructed from the order of the
# calls (indentation was not available) - confirm it matches the intended
# layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Wider column: text entry and the submit button.
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence in here...",
                lines=5,
            )
            submit_btn = gr.Button("Classify Text", variant="primary")

        # Narrower column: per-label scores.
        with gr.Column(scale=1):
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=3,
            )

    # Clickable sample inputs, with example caching enabled.
    gr.Examples(
        examples=examples,
        fn=classify_gender,
        inputs=text_input,
        outputs=output_label,
        cache_examples=True,
    )

    gr.Markdown(ARTICLE)

    # Run the classifier on button click; api_name sets the event's API name
    # to "classify".
    submit_btn.click(
        fn=classify_gender,
        inputs=text_input,
        outputs=output_label,
        api_name="classify",
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()