| import streamlit as st |
| from simpletransformers.classification import MultiLabelClassificationModel |
| import torch |
|
|
| |
def predict(model, text):
    """Classify a single piece of text with ``model``.

    Args:
        model: Object exposing ``predict(list_of_texts)`` that returns a
            2-tuple. Per the simpletransformers API the first element is the
            list of predictions and the second the raw model outputs
            (NOTE(review): confirm against the installed version).
        text: The text to classify.

    Returns:
        The first element of the tuple returned by ``model.predict`` for the
        one-item batch ``[text]``.
    """
    # model.predict works on batches, so wrap the single text in a list.
    predictions, _ = model.predict([text])
    return predictions
|
|
| |
@st.cache_resource
def _load_model(model_path):
    """Build the classifier for ``model_path`` once and reuse it across reruns."""
    # Streamlit re-executes the script on every widget interaction; without
    # caching the model would be rebuilt each time.
    return MultiLabelClassificationModel('xlm', model_path, use_cuda=False)


def main():
    """Render the sentiment-prediction page.

    Lets the user pick a language model, enter text, and press "Predict".
    The prediction is shown as positive for ``[[1, 0, 0]]``, negative for
    ``[[0, 1, 0]]`` and mixed for any other result.
    """
    st.title("Dravidian-English Code Mixed TextSentiment Prediction App")

    # Display name -> model identifier/path handed to the classifier.
    model_paths = {
        "Kannada": "Diya-Roshan/xlm-code-mixed-kannada-sentiment-classifier",
        "Malayalam": "MalModel1",
        "Tamil": "TamModel1",
    }

    # Options come straight from the mapping, so the selection is always a
    # valid key and the two can never drift apart.
    selected_language = st.selectbox("Select Language Model", list(model_paths))
    model = _load_model(model_paths[selected_language])

    # User input for text
    text_input = st.text_area("Enter text for prediction", "")

    # Make predictions only when the user clicks the button
    if not st.button("Predict"):
        return

    if not text_input.strip():
        # Tell the user why nothing happened instead of silently ignoring the click.
        st.warning("Please enter some text before predicting.")
        return

    predictions = predict(model, text_input)

    # Display the predictions
    if predictions == [[1, 0, 0]]:
        st.success('Positive Sentiment')
    elif predictions == [[0, 1, 0]]:
        st.error('Negative Sentiment')
    else:
        st.warning('Mixed Sentiment')
| |
# Entry point: start the app only when this file is executed as the main module.
if __name__ == "__main__":
    main()
| |