import streamlit as st
import pandas as pd
import joblib
import plotly.express as px

models = ["Linear Regression", "XGBoost", "Random Forests Regressor"]

# Serialized regressor for each name in ``models``.
# NOTE(review): the XGBoost entry has no ``.pkl`` extension, unlike the other
# two -- confirm it matches the file name actually on disk.
MODEL_PATHS = {
    models[0]: "./models/lr.pkl",
    models[1]: "./models/xgb_model",
    models[2]: "./models/rf.pkl",
}

# Feature scaler shared by every model.
SCALER_PATH = "./models/scaler.pkl"

# (image path, caption) of the performance plot for each name in ``models``.
MODEL_PLOTS = {
    models[0]: ("./plots/lr_plot.png", "Linear Regression Plot"),
    models[1]: ("./plots/xgb_plot.png", "XG Boost Plot"),
    models[2]: ("./plots/rf_plot.png", "Random Forests Regressor Plot"),
}

# Glossary captions for the "About the Data" tab; rendered alternately into
# the left and right column, in this order.
COLUMN_DESCRIPTIONS = [
    "**Area**: Geographic region or country where the crop is cultivated, serving as a key factor in yield variations due to climate, soil, and regional practices.",
    "**Item**: Type of crop grown (e.g., wheat, rice), essential for modeling yield patterns and crop-specific responses to environmental factors.",
    "**Year**: Time of harvest, helping analyze yield trends, seasonal patterns, and the impact of climate change over time.",
    "**hg/ha_yield**: Crop yield per hectare (hectograms per hectare), the target variable indicating agricultural productivity for each crop and region.",
    "**average_rain_fall_mm_per_year**: Annual rainfall measured in millimeters, a critical environmental factor influencing crop growth and yield.",
    "**pesticides_tonnes**: Total pesticides applied (in tonnes), providing insight into pest control measures and their impact on crop yield.",
    "**avg_temp**: Average annual temperature (°C), a vital climate factor affecting crop growth cycles, maturity rates, and overall yield.",
]


@st.cache_data
def load_dataset(df_path: str) -> pd.DataFrame:
    """Load a ``.csv`` dataset for the Streamlit app.

    Cached with ``st.cache_data`` (Streamlit's cache for data such as
    DataFrames), so each call gets its own copy rather than a single mutable
    object shared across sessions.

    Args:
        df_path (str): Path of the CSV file to read.

    Returns:
        pd.DataFrame: The parsed dataset.
    """
    return pd.read_csv(df_path)


@st.cache_resource
def load_model(chosen_model: str) -> tuple:
    """Load the requested regressor together with the shared feature scaler.

    Both objects are deserialized with ``joblib`` and cached with
    ``st.cache_resource`` so they are loaded once per process.

    Args:
        chosen_model (str): One of the names in ``models``.

    Returns:
        tuple: ``(model, scaler)``.

    Raises:
        ValueError: If ``chosen_model`` is not one of ``models``.
    """
    try:
        model_path = MODEL_PATHS[chosen_model]
    except KeyError:
        raise ValueError(
            f"Unknown model {chosen_model!r}; expected one of {models}"
        ) from None
    # NOTE(review): joblib.load unpickles the file -- only load trusted files.
    return joblib.load(model_path), joblib.load(SCALER_PATH)


def get_input_data(df: pd.DataFrame):
    """Render the input widgets and build a single-row feature frame.

    ``Area`` and ``Item`` are integer-encoded by the order in which each
    category first appears in ``df``.

    Args:
        df (pd.DataFrame): Dataset whose ``Item``/``Area`` categories and
            numeric column ranges bound what the user can choose.

    Returns:
        tuple: A one-row DataFrame with columns ``Area``, ``Item``, ``Year``,
        ``average_rain_fall_mm_per_year``, ``pesticides_tonnes`` and
        ``avg_temp``, followed by an ``(item, area)`` tuple holding the raw
        selections.
    """
    items = df['Item'].unique()
    areas = df['Area'].unique()
    # NOTE(review): this first-appearance encoding must match the encoding
    # used when the models were trained -- confirm against the training code.
    item_mapping = {category: code for code, category in enumerate(items)}
    area_mapping = {category: code for code, category in enumerate(areas)}

    item = st.selectbox("Which plant do you wish to yield?", tuple(items),
                        placeholder="Select a plant")
    area = st.selectbox("Which area do you want your crops to be planted?", tuple(areas),
                        placeholder="Select an area for your crop.")
    year = st.number_input("When do you wish to plant?", min_value=2025, max_value=2030,
                           value=2025, placeholder="Choose a year")
    avg_rainfall = st.slider("What is the estimated average rainfall for that year?",
                             min_value=df['average_rain_fall_mm_per_year'].min(),
                             max_value=df['average_rain_fall_mm_per_year'].max())
    pesticides = st.slider("How much pesticides do you intend to use? (in tonnes)",
                           max_value=int(df['pesticides_tonnes'].max()))
    avg_temp = st.slider("How hot or cold is the area around you? (Average Temperature)",
                         min_value=df['avg_temp'].min(),
                         max_value=df['avg_temp'].max())

    features = pd.DataFrame({
        "Area": [area_mapping[area]],
        "Item": [item_mapping[item]],
        "Year": [year],
        "average_rain_fall_mm_per_year": [avg_rainfall],
        "pesticides_tonnes": [pesticides],
        "avg_temp": [avg_temp],
    })
    return features, (item, area)


def plot_map(countries: pd.Series) -> None:
    """Draw a world choropleth shaded by how often each country occurs.

    Args:
        countries (pd.Series): Country names, matched with plotly's
            ``locationmode='country names'``.
    """
    country_counts = countries.value_counts().reset_index()
    # Normalise the column names, which differ between pandas versions.
    country_counts.columns = ['country', 'count']

    fig = px.choropleth(
        country_counts,
        locations='country',
        locationmode='country names',
        color='count',
        hover_name='country',
        color_continuous_scale='Blues',
        title='Countries in the Dataset'
    )
    st.plotly_chart(fig)


def _render_about_tab(df: pd.DataFrame) -> None:
    """Render the dataset introduction, preview table and column glossary."""
    st.caption(
        "The science of training machines to learn and produce models for future predictions "
        "is widely used, and not for nothing. Agriculture plays a critical role in the global "
        "economy. With the continuing expansion of the human population understanding worldwide "
        "crop yield is central to addressing food security challenges and reducing the impacts "
        "of climate change."
    )
    st.caption(
        " Crop yield prediction is an important agricultural problem. The Agricultural yield "
        "primarily depends on weather conditions (rain, temperature, etc), pesticides and "
        "accurate information about history of crop yield is an important thing for making "
        "decisions related to agricultural risk management and future predictions."
    )
    st.dataframe(df, height=300, width=900)

    col1, col2 = st.columns(2)
    for index, description in enumerate(COLUMN_DESCRIPTIONS):
        # Alternate left/right so the glossary fills both columns.
        (col1 if index % 2 == 0 else col2).caption(description)
    st.divider()


def _render_viz_tab(df: pd.DataFrame) -> None:
    """Render the country-frequency map and a user-configured X/Y chart."""
    plot_map(df['Area'])
    st.caption("The world map plot above showcases each country and its frequency in the dataset.")
    st.divider()

    x = st.selectbox("Choose X for plotting.", tuple(df.columns))
    # Y may be any column except the one already chosen for X.
    y = st.selectbox("Choose Y for plotting.", tuple(df.drop(x, axis=1).columns))
    plot = st.selectbox("Select type of plot.", ("Scatter", "Bar", "Line"))

    if st.button("Plot X and Y!"):
        if plot == "Scatter":
            # Marker size encodes the yield column.
            st.scatter_chart(data=df, x=x, y=y, size='hg/ha_yield')
        elif plot == "Bar":
            st.bar_chart(data=df, x=x, y=y)
        elif plot == "Line":
            st.line_chart(data=df, x=x, y=y)


def _render_inference_tab(df: pd.DataFrame) -> None:
    """Collect model inputs, show the chosen model's plot and predict yield."""
    input_data, (item, area) = get_input_data(df)
    selected_model = st.selectbox("Which model do you want to use?", tuple(models),
                                  placeholder="Select your model")

    plot_path, plot_caption = MODEL_PLOTS[selected_model]
    with st.expander(f"Click to see performance of {selected_model}"):
        st.image(plot_path, caption=plot_caption)

    if st.button("Predict yield!"):
        col1, col2 = st.columns(2)
        col1.metric("Area", area, border=True)
        col2.metric("Item", item, border=True)

        # input_data has exactly one row; .item() yields a plain Python scalar,
        # which is what st.metric expects as its value.
        col3, col4, col5 = st.columns(3)
        col3.metric("Average Rainfall",
                    input_data['average_rain_fall_mm_per_year'].item(), border=True)
        col4.metric("Pesticide Usage (Tonne)",
                    input_data['pesticides_tonnes'].item(), border=True)
        col5.metric("Average Temperature (Celsius)",
                    input_data['avg_temp'].item(), border=True)

        model, scaler = load_model(selected_model)
        input_scaled = scaler.transform(input_data)
        pred = model.predict(input_scaled)
        st.header(f"Predicted Crop Yield: **{int(pred[0])}**")


def main() -> None:
    """Build the three-tab Streamlit crop-yield app."""
    st.title("Crop Yield Predictor")
    tab1, tab2, tab3 = st.tabs(["About the Data", "Data Viz", "Model Inference"])

    df = load_dataset('./yield_df.csv')
    # "Unnamed: 0" is how pandas names a header-less column (presumably a saved
    # index); drop it when present instead of failing when it is not.
    df = df.drop(columns="Unnamed: 0", errors="ignore")

    with tab1:
        _render_about_tab(df)
    with tab2:
        _render_viz_tab(df)
    with tab3:
        _render_inference_tab(df)


if __name__ == "__main__":
    main()