Spaces:
Runtime error
Runtime error
| import torch | |
| import streamlit as st | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
# Load the pre-trained model and tokenizer.
# `from_pretrained` expects a Hub repository id ("namespace/name") or a local
# directory, not the URL of the repository's web page.
model_path = "jonaskoenig/topic_classification_04"  # Replace with the path to your saved model
# NOTE(review): the tokenizer is loaded from "bert-base-uncased" rather than
# from `model_path`; confirm its vocabulary matches the fine-tuned model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(model_path)
# Streamlit page layout: a title followed by the free-text input box.
st.title("Topic Classification App")
# Text the user wants classified; the box starts out empty.
user_input = st.text_area("Enter text for topic classification:", value="")
# Function to make predictions
def predict_topic(text):
    """Return the index of the highest-scoring class for ``text``.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: A single string to classify.

    Returns:
        The integer index of the largest logit produced by the model.
    """
    # Truncate to the tokenizer's configured maximum length so that very long
    # input does not overflow the model's position embeddings.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # `.item()` relies on a batch of exactly one example (a single string).
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class
# Make predictions and display result
if st.button("Predict"):
    # Whitespace-only input carries no text to classify, so treat it the same
    # as an empty box instead of sending it to the model.
    if user_input and user_input.strip():
        st.info("Making Prediction...")
        prediction = predict_topic(user_input)
        st.success(f"Predicted Topic: {prediction}")
    else:
        st.warning("Please enter some text for prediction.")