# app.py — Hugging Face Space by breadlicker45 (commit e60e9e6, verified)
import gradio as gr
# BE EXPLICIT: Import the specific model class we need
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os

# --- 1. Load the model and tokenizer ---
# Read the Hugging Face access token from the Space's secrets; needed in case
# the model repo is gated or private.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret in your Space settings.")

MODEL_ID = "breadlicker45/multilingual-bert-gender-classification"

# Prefer GPU when the Space has one; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Loading model: {MODEL_ID}...")
try:
    # NOTE(review): trust_remote_code=True executes code shipped inside the
    # model repo — acceptable only because this specific repo is trusted.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    # Move the model to the selected device
    model.to(device)
    print("Model loaded successfully!")
except Exception as e:
    # Log for the Space logs, then re-raise. A bare `raise` (rather than
    # `raise e`) preserves the original exception context untouched.
    print(f"Error loading model: {e}")
    raise
# --- 2. Define the Manual Prediction Function ---
# (This function is already correct and does not need changes)
def classify_gender(text: str) -> "dict | None":
    """Classify *text* and return a ``{label: probability}`` mapping.

    Returns None for empty or whitespace-only input so the Gradio Label
    component clears instead of showing an error.
    """
    if not text or not text.strip():
        return None
    # Tokenize and move every tensor to the same device as the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # squeeze(0) — not squeeze() — so a single-label model's (1, 1) logits
    # still yield a 1-D tensor we can enumerate, instead of a 0-D scalar.
    probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    # Map each class index to its human-readable label from the model config.
    return {
        model.config.id2label[i]: score
        for i, score in enumerate(probabilities.tolist())
    }
# --- UI copy and example inputs ---
DESCRIPTION = """
## Bilingual Gender Classifier
This is a demo for the model `breadlicker45/multilingual-bert-gender-classification`.
Enter a sentence and the model will predict whether the text was written by a male, female, or non-binary author.
**Disclaimer:** This model, like any AI, can have biases and may not always be accurate.
"""

# Footer markdown — intentionally empty, kept as a placeholder.
ARTICLE = """
"""

# Pre-filled example inputs shown beneath the textbox.
examples = [
    ["this is a test."]
]
# --- 3. Create the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: text entry plus the trigger button.
        with gr.Column(scale=2):
            input_box = gr.Textbox(
                lines=5,
                label="Input Text",
                placeholder="Enter a sentence in here...",
            )
            classify_button = gr.Button("Classify Text", variant="primary")
        # Right column: probability breakdown of the prediction.
        with gr.Column(scale=1):
            results_label = gr.Label(
                num_top_classes=3,
                label="Classification Results",
            )

    # Clickable examples; their outputs are computed and cached at build time.
    gr.Examples(
        examples=examples,
        inputs=input_box,
        outputs=results_label,
        fn=classify_gender,
        cache_examples=True,
    )

    gr.Markdown(ARTICLE)

    # Wire the button to the classifier; also exposed via the /classify API.
    classify_button.click(
        fn=classify_gender,
        inputs=input_box,
        outputs=results_label,
        api_name="classify",
    )

# --- 4. Launch the App ---
if __name__ == "__main__":
    demo.launch()