# Source: Hugging Face Space by BasselAhmed — commit 3048da6 ("Update app.py", verified)
import streamlit as st
import pandas as pd
from simpletransformers.classification import ClassificationModel
import torch
import random
import os
import numpy as np
import subprocess
import requests
# Seed every RNG the app touches (Python, NumPy, PyTorch) with one shared
# value so runs are reproducible. Each generator only needs seeding once.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Set the Streamlit app title
st.title("Molecule Toxicity Predictions")

# Local fine-tuned checkpoint path.
# NOTE(review): currently unused -- the model below is loaded from
# 'BasselAhmed/RobertaChemClinToxTuned' instead. Kept in case other code
# refers to it.
path = 'ToxicityPrediction/Models/transformers/checkpoint-149-epoch-1'


@st.cache_resource
def load_toxicity_model():
    """Build the fine-tuned RoBERTa toxicity classifier in inference mode.

    Streamlit re-executes this entire script on every widget interaction
    (including the "Predict Toxicity" click). Caching with
    ``st.cache_resource`` constructs the model once per server process
    instead of on every rerun.

    Returns:
        The ``ClassificationModel`` with its underlying torch module in
        eval mode, running on CPU (``use_cuda=False``).
    """
    model = ClassificationModel(
        'roberta',
        'BasselAhmed/RobertaChemClinToxTuned',
        use_cuda=False,
        args={
            'evaluate_each_epoch': True,
            'evaluate_during_training_verbose': True,
            'seed': 4,
        },
    )
    # Put the torch module in eval mode (e.g. disables dropout) for inference.
    model.model.eval()
    return model


rob_chem_model = load_toxicity_model()
# Multi-line input: one SMILES string per line.
target_name = st.text_area("Enter smiles (one per line):", "")
# Strip surrounding whitespace and drop blank lines so empty strings are
# never sent to the model.
target_name_list = [line.strip() for line in target_name.splitlines() if line.strip()]
predict_toxicity = st.button('Predict Toxicity')
if predict_toxicity:
    if not target_name_list:
        # Nothing to score; avoid calling the model with an empty batch.
        st.warning("Please enter at least one SMILES string.")
    else:
        predictions, raw_outputs = rob_chem_model.predict(target_name_list)
        df_pred = pd.DataFrame({'Smiles': target_name_list, 'Predictions': predictions})
        st.dataframe(df_pred)