# mbti_classify / app.py — author: ali (commit 77fa791, "update app.py3")
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import os
# Ensure compatibility with protobuf: force the pure-Python protobuf
# implementation to avoid C-extension version mismatches when loading
# transformers models.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Path to your model directory (local fine-tuned MBTI classifier checkpoint).
model_path = "./mbti_model_2"
# Load model and tokenizer with label mappings
@st.cache_resource
def load_pipeline_and_mapping():
    """Load the text-classification pipeline and a label -> MBTI-type mapping.

    Returns:
        tuple: (pipeline, mapping) on success, where mapping keys are label
        strings as reported by the pipeline; (None, {}) on failure, with the
        error surfaced in the Streamlit UI instead of crashing the app.
    """
    try:
        # config.id2label maps *int* ids -> label strings (e.g. {0: "INTJ"}).
        config = AutoConfig.from_pretrained(model_path)
        raw_mapping = config.id2label if hasattr(config, "id2label") else {}
        # BUG FIX: the text-classification pipeline reports *string* labels
        # (the id2label values), but id2label is keyed by int ids, so a direct
        # .get(pred["label"]) lookup downstream always missed and showed
        # "Unknown".  Normalize to a string-keyed dict and also map each
        # label value to itself so the pipeline's reported label resolves.
        label_to_mbti = {str(idx): name for idx, name in raw_mapping.items()}
        label_to_mbti.update({name: name for name in raw_mapping.values()})
        # Load the tokenizer and model from the local checkpoint directory.
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
        return pipe, label_to_mbti
    except Exception as e:
        # Report load failures in the UI; callers check `pipe` for None.
        st.error(f"Error loading the model: {e}")
        return None, {}
pipe, label_to_mbti = load_pipeline_and_mapping()

# Streamlit UI
st.title("MBTI Personality Prediction")
st.write("Enter text below to classify the MBTI personality type:")

# Input text box
user_input = st.text_area("Input Text", placeholder="Type something here...", height=200)

# Predict button
if st.button("Predict"):
    if not pipe:
        st.error("The model failed to load. Please check the setup.")
    elif user_input.strip():
        # Generate predictions (pipeline returns a list of
        # {"label": str, "score": float} dicts).
        predictions = pipe(user_input)
        st.subheader("Predictions:")
        for pred in predictions:
            # BUG FIX: fall back to the raw pipeline label rather than
            # "Unknown" — id2label in the model config is int-keyed, so the
            # string-label lookup could never hit, and the pipeline's label
            # is already the human-readable id2label value.
            mbti_type = label_to_mbti.get(pred["label"], pred["label"])
            st.write(f"**MBTI Type:** {mbti_type}, **Confidence:** {pred['score']:.4f}")
    else:
        st.warning("Please enter some text before clicking 'Predict'.")