# Milestone2_Deployment / prediction.py
# Author: notbeekay -- "Upload 9 files" (commit 047a1eb, verified)
import os
import pickle

import pandas as pd
import streamlit as st
# Load the pre-trained SVR model once, when this module is first executed.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load a model_svr.pkl produced by a trusted training run.
_MODEL_FILENAME = 'model_svr.pkl'


def _resolve_model_path(filename=_MODEL_FILENAME):
    """Return a path to *filename* that works regardless of the working directory.

    The original behaviour -- opening *filename* relative to the current
    working directory -- is tried first, so existing deployments are
    unaffected. If no such file exists there, fall back to the directory
    containing this script, so the app still starts when launched from a
    different directory. If neither location has the file, the bare
    *filename* is returned and ``open`` raises the usual FileNotFoundError.
    """
    if os.path.exists(filename):
        return filename
    # __file__ can be missing when the code is exec'd from a string; in that
    # case keep the original CWD-relative behaviour.
    script_path = globals().get('__file__')
    if script_path is not None:
        beside_script = os.path.join(
            os.path.dirname(os.path.abspath(script_path)), filename
        )
        if os.path.exists(beside_script):
            return beside_script
    return filename


with open(_resolve_model_path(), 'rb') as file_1:
    model_svr = pickle.load(file_1)
def run():
    """Render the IMDb score prediction page.

    Shows a form that collects the movie attributes the model consumes,
    echoes the collected values back as a one-row DataFrame and, once the
    form has been submitted, displays the SVR model's predicted score
    rounded to two decimals.
    """
    st.title('IMDb Movie Score Prediction')
    st.subheader('Calculate IMDb Score of Movies')

    rating_choices = ('G', 'PG', 'PG-13', 'R', 'NC-17')
    genre_choices = ('Action', 'Adventure', 'Comedy', 'Drama', 'History', 'Sci-Fi', 'Thriller')

    # Widgets are created in a fixed order; that order is the on-page layout.
    with st.form('form_movie_prediction'):
        # Free-text attributes (all start empty).
        name = st.text_input('Movie Name: ', value='')
        director = st.text_input('Director: ', value='')
        writer = st.text_input('Writer: ', value='')
        star = st.text_input('Star: ', value='')
        country = st.text_input('Country: ', value='')
        company = st.text_input('Production Company: ', value='')
        released = st.text_input('Date Released: ', value='')

        # Numeric attributes with their bounds and starting values.
        year = st.number_input('Release Year: ', value=2022, min_value=1900, max_value=2100)
        budget = st.number_input('Budget ($): ', value=500_000_000, min_value=0)
        gross = st.number_input('Gross Revenue ($): ', value=958_000_000, min_value=0)
        runtime = st.number_input('Runtime (minutes): ', value=189, min_value=1)
        votes = st.number_input('Votes: ', value=500_000, min_value=0)

        # Fixed-choice attributes; defaults are 'R' and 'History'.
        rating = st.selectbox('Rating: ', rating_choices, index=3)
        genre = st.selectbox('Genre: ', genre_choices, index=4)

        submitted = st.form_submit_button('Predict IMDb Score')

    # One-row frame holding every input. The column order presumably has to
    # match the model's training features -- TODO confirm against training code.
    features = pd.DataFrame([{
        'name': name,
        'rating': rating,
        'genre': genre,
        'year': year,
        'released': released,
        'votes': votes,
        'director': director,
        'writer': writer,
        'star': star,
        'country': country,
        'budget': budget,
        'gross': gross,
        'company': company,
        'runtime': runtime,
    }])
    st.dataframe(features)

    # Only query the model after the user has pressed the submit button.
    if not submitted:
        return
    score = model_svr.predict(features)[0]
    st.write('## Predicted IMDb Score: ', str(round(score, 2)))
if __name__ == '__main__':
    # Render the page when this file is executed directly; when it is
    # imported instead (presumably by an app that calls run() itself),
    # nothing is drawn at import time.
    run()