File size: 3,426 Bytes
39b12ed
65c4e1b
 
eb1d7b9
39b12ed
 
eb1d7b9
39b12ed
 
 
 
eb1d7b9
39b12ed
19ab706
eb1d7b9
 
 
39b12ed
 
65c4e1b
 
 
 
 
 
aa4b662
39b12ed
65c4e1b
eb1d7b9
39b12ed
 
 
 
 
 
eb1d7b9
65c4e1b
39b12ed
 
65c4e1b
eb1d7b9
 
 
 
 
 
 
 
65c4e1b
39b12ed
eb1d7b9
 
 
 
 
 
39b12ed
 
65c4e1b
39b12ed
 
 
19ab706
65dc5cb
 
39b12ed
 
 
 
 
 
9e7792e
39b12ed
 
 
 
 
 
 
 
 
 
65dc5cb
39b12ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb1d7b9
39b12ed
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
# BE EXPLICIT: Import the specific model class we need
from transformers import AutoTokenizer, XLMRobertaForSequenceClassification
import torch
import os

# --- 1. Setup: Load Model and Define Device ---

# The Hugging Face access token is read from the environment (the Space's
# secrets). Startup fails immediately when it is missing so the cause is
# obvious in the logs instead of surfacing later as a download error.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret in your Space settings.")

MODEL_ID = "breadlicker45/bilingual-base-gender-v4.1-test"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print(f"Loading model: {MODEL_ID}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # Load through the explicit XLM-RoBERTa class rather than
    # AutoModelForSequenceClassification so that the 'auto_map' entry in the
    # repository's config.json is never consulted. `trust_remote_code` only
    # applies to the Auto* classes (Transformers ignores it here and logs a
    # warning), so it is deliberately not passed: no code from the model
    # repository has to be executed to load this architecture.
    model = XLMRobertaForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN)

    # Inputs are moved to the same device inside classify_gender().
    model.to(device)
    print("Model loaded successfully!")

except Exception as e:
    # Log a readable message, then re-raise the active exception unchanged so
    # the app does not start without a model.
    print(f"Error loading model: {e}")
    raise

# --- 2. Define the Manual Prediction Function ---
# Tokenizes the input, runs the model once, and maps softmax scores to label names.
def classify_gender(text: str) -> dict:
    """Classify *text* and return a probability for every model label.

    Args:
        text: The raw text to classify.

    Returns:
        A dict mapping each label name (looked up in
        ``model.config.id2label``) to its softmax probability. Returns
        ``None`` instead when *text* is empty or contains only whitespace.
    """
    if not text or not text.strip():
        return None

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Tensors must live on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string is encoded, so the batch holds one row; select it by
    # index. An argument-less squeeze() would also collapse the label
    # dimension when the model has exactly one label, turning the scores into
    # a bare float that cannot be enumerated.
    scores = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    id2label = model.config.id2label
    return {id2label[i]: score for i, score in enumerate(scores)}

# --- 3. Create the Gradio Interface ---
# Static page text and example inputs, followed by the Blocks layout and event wiring.

# Markdown rendered at the top of the page.
DESCRIPTION = """
## Bilingual Gender Classifier 
This is a demo for the model `bilingual-base-gender-v4.1-test`.
Enter a sentence and the model will predict a gender label for the text: male, female, or non-binary.
**Disclaimer:** This model, like any AI, can have biases and may not always be accurate.
"""

# Markdown rendered below the examples; intentionally left blank for now.
ARTICLE = """
"""

# Example inputs offered in the UI: one inner list per example, holding the
# value for the single text input.
examples = [
    ["this is a test."] 
]

# Page layout: the description, a row with two columns (text input and button
# in the first, results label in the second), the examples, then ARTICLE.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Input Text", lines=5, placeholder="Enter a sentence in here...")
            submit_btn = gr.Button("Classify Text", variant="primary")

        with gr.Column(scale=1):
            output_label = gr.Label(label="Classification Results", num_top_classes=3)

    # The examples and the button call the same function with the same
    # input/output components, so that wiring is written once and reused.
    _classify_io = dict(fn=classify_gender, inputs=text_input, outputs=output_label)

    gr.Examples(examples=examples, cache_examples=True, **_classify_io)

    gr.Markdown(ARTICLE)

    submit_btn.click(api_name="classify", **_classify_io)

# --- 4. Launch the App ---
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()