| import gradio as gr |
| import torch |
| import re |
| from transformers import BertTokenizer, BertForSequenceClassification |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
# Load the fine-tuned checkpoint and tokenizer from the local ./model
# directory once, at import time, so every request reuses the same objects.
model_name = "./model"

# Prefer the GPU when one is visible to PyTorch; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name).to(device)

# This script only runs inference, so keep the module in evaluation mode.
model.eval()
|
|
| |
| |
| |
|
|
| |
# Lookup table used to turn the classifier's argmax index into a label.
# NOTE(review): the order presumably mirrors the label encoding used when the
# checkpoint was trained -- confirm before reordering.
MBTI_CLASSES = (
    "ISTJ ISFJ INFJ INTJ "
    "ISTP ISFP INFP INTP "
    "ESTP ESFP ENFP ENTP "
    "ESTJ ESFJ ENFJ ENTJ"
).split()
|
|
| |
# The dot after "www" is escaped on purpose: an unescaped "." matches any
# character, so words such as "awww" followed by a space were being treated
# as the start of a URL and the next word was deleted with them.
_URL_RE = re.compile(r"http\S+|www\.\S+")
_NON_ALPHA_RE = re.compile(r"[^a-z\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def preprocess_text(text):
    """Normalize raw user text before it is handed to the tokenizer.

    The text is lower-cased, URLs are removed, every character that is not
    an ASCII letter or whitespace is dropped, and runs of whitespace are
    collapsed to single spaces.

    Args:
        text: Raw input string. ``None`` or an empty string is accepted.

    Returns:
        The cleaned string; ``""`` when there is nothing left to keep.
    """
    if not text:
        return ""

    text = text.lower()
    text = _URL_RE.sub("", text)
    # The text is already lower-cased, so only a-z needs to be kept.
    text = _NON_ALPHA_RE.sub("", text)
    # Removing URLs and punctuation leaves ragged gaps; tidy them up.
    text = _WHITESPACE_RE.sub(" ", text).strip()
    return text
|
|
| |
_EMPTY_INPUT_MESSAGE = "Please enter some text to analyze."


def predict_mbti(text):
    """Predict the MBTI type for a block of free-form text.

    Args:
        text: Raw user text from the UI. ``None``, empty, or whitespace-only
            input is rejected before any model work is done.

    Returns:
        One of the labels in ``MBTI_CLASSES``, or a short prompt asking for
        input when ``text`` is blank.
    """
    # Without this guard a blank submission would still be scored and shown
    # to the user as if it were a real personality prediction.
    if text is None or not text.strip():
        return _EMPTY_INPUT_MESSAGE

    cleaned = preprocess_text(text)

    # Exactly one sequence is scored per call, so there is nothing to pad it
    # against. Padded positions are masked out by the attention mask, so
    # padding to 512 only added compute. Truncation still caps long inputs
    # at 512 tokens.
    inputs = tokenizer(
        cleaned,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    pred_idx = torch.argmax(logits, dim=1).item()
    return MBTI_CLASSES[pred_idx]
|
|
| |
# Gradio UI: one multi-line text input and one text output, wired to
# predict_mbti.
_answers_input = gr.Textbox(
    lines=12,
    label="Enter Combined Answers (Q1 A1 Q2 A2 ...)",
)
_prediction_output = gr.Textbox(label="Predicted MBTI Type")

interface = gr.Interface(
    fn=predict_mbti,
    inputs=_answers_input,
    outputs=_prediction_output,
    title="MBTI Personality Predictor (BERT)",
    description="Paste your combined answers to get your MBTI personality type. Powered by Sid26Roy/mbti",
)


# Start the web app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    interface.launch()
|
|