# File size: 3,318 Bytes
# 8a2fc94 4e27d7a 8a2fc94 4e27d7a 8a2fc94 4e27d7a 8a2fc94 4e27d7a 8a2fc94 4e27d7a 8a2fc94 4e27d7a 8a2fc94 4e27d7a |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import os
from typing import List

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# FastAPI application instance; the /classify endpoint in this file is registered on it.
app = FastAPI()
class InputData(BaseModel):
    """Request body for ``POST /classify``.

    Every field is a list of numbers, and all lists must have the same
    length, because they become columns of a single DataFrame during
    preprocessing. Declaring the element type as ``float`` lets pydantic
    reject non-numeric or null entries with a validation error. Without it,
    that input would fail later inside the scaler.

    Only the fields declared here are included in ``.dict()``. With the
    default model config, undeclared keys in the request body (e.g. a
    ``SCAN3`` the client sends) are ignored. Any additional scan the model
    should receive therefore has to be declared below.
    """

    wavelengths: List[float]
    SCAN1: List[float]
    SCAN2: List[float]
    # Add more fields as needed
# Locations of the fitted artifacts. Defaults are the original relative
# paths, which resolve against the process's current working directory.
# Set MODEL_PATH / SCALER_PATH in the environment to load them from elsewhere.
MODEL_PATH = os.environ.get('MODEL_PATH', 'random_forest_model.joblib')
SCALER_PATH = os.environ.get('SCALER_PATH', 'scaler.joblib')

# Loaded once at import time; the request handler uses these module-level
# objects on every call.
# Security: joblib.load unpickles arbitrary Python objects. Only point these
# paths at artifacts you trust.
model = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
def preprocess_input_data(input_data: dict, *, num_scans: int = 15, fitted_scaler=None):
    """Build and scale the feature matrix for one request payload.

    The intermediate DataFrame always has the columns ``wavelengths``,
    ``SCAN1`` .. ``SCAN{num_scans}``, in that order. Any scan that is absent
    from *input_data* is filled with 0. Keys of *input_data* that are not
    among these columns are ignored. There is one row per entry of
    ``input_data['wavelengths']``.

    Args:
        input_data: Mapping with a ``'wavelengths'`` list and any number of
            ``'SCAN<i>'`` lists of the same length (in this service, the
            output of ``InputData.dict()``).
        num_scans: Number of ``SCAN`` columns to emit. Defaults to 15, the
            value this function previously hard-coded. It presumably has to
            match what the scaler was fitted on (TODO confirm).
        fitted_scaler: Object exposing ``transform(DataFrame)``. Defaults to
            the module-level ``scaler``.

    Returns:
        The result of ``fitted_scaler.transform`` on the frame.

    Raises:
        KeyError: if ``'wavelengths'`` is missing from *input_data*.
        ValueError: raised by pandas if a scan list's length differs from
            the wavelengths list.
    """
    if fitted_scaler is None:
        fitted_scaler = scaler

    # Column order matters to the scaler: wavelengths first, then SCAN1..N.
    columns = {'wavelengths': input_data['wavelengths']}
    for i in range(1, num_scans + 1):
        name = f'SCAN{i}'
        # Scans the caller did not supply become a constant 0 column.
        columns[name] = input_data.get(name, 0)
    frame = pd.DataFrame(columns)

    # Drops any column containing NaN (e.g. a scan sent with null entries).
    # NOTE(review): a dropped column changes the feature count, so transform
    # will then most likely fail. This is kept to preserve existing behavior.
    frame = frame.dropna(axis=1)
    return fitted_scaler.transform(frame)
# Response label for each integer class index the model may predict.
_CLASS_LABELS = {0: 'Resistant', 1: 'Medium', 2: 'Susceptible'}


def _majority_label(predictions):
    """Return the response label for the most frequent entry of *predictions*.

    Ties go to the smallest value: ``np.unique`` sorts and ``argmax`` takes
    the first maximum. This matches the previous
    ``np.argmax(np.bincount(...))``. Integer predictions are mapped through
    ``_CLASS_LABELS``. String predictions that already equal one of those
    labels are returned unchanged. The commented-out Streamlit client in
    this file compares predictions against these strings, so the model may
    emit them directly; ``np.bincount`` would raise on such labels.

    Raises:
        ValueError: if *predictions* is empty, or if its majority value is
            neither a known class index nor a known label.
    """
    values, counts = np.unique(np.asarray(predictions), return_counts=True)
    if counts.size == 0:
        raise ValueError('model returned no predictions')
    top = values[np.argmax(counts)]
    if isinstance(top, str):  # np.str_ is a str subclass
        label = str(top) if top in _CLASS_LABELS.values() else None
    else:
        label = _CLASS_LABELS.get(top)
    if label is None:
        # Previously an unknown class made the endpoint return a null body.
        raise ValueError(f'unrecognised predicted class: {top!r}')
    return label


@app.post("/classify")
def classify(input_data: InputData):
    """Classify one spectral sample and return ``{"result": <label>}``.

    The model yields one prediction per row of the preprocessed input (one
    row per entry of ``wavelengths``). The majority prediction is reported
    as 'Resistant', 'Medium' or 'Susceptible'.

    Raises:
        HTTPException: status 500, with the underlying error message as
            ``detail``, if preprocessing, prediction or label mapping fails.
    """
    try:
        preprocessed_input = preprocess_input_data(input_data.dict())
        predictions = model.predict(preprocessed_input)
        return {"result": _majority_label(predictions)}
    except Exception as e:
        # Chain the cause so the server-side traceback keeps the real error.
        raise HTTPException(status_code=500, detail=str(e)) from e
# End of Code
# import streamlit as st
# import pandas as pd
# import numpy as np
# import joblib
# from collections import Counter
# model = joblib.load('random_forest_model.joblib')
# scaler = joblib.load('scaler.joblib')
# def preprocess_uploaded_file(uploaded_df, scaler, num_scans=15):
# uploaded_df.columns = ['wavelengths'] + [f'SCAN{i}' for i in range(1, uploaded_df.shape[1])]
# template_data = {'wavelengths': uploaded_df['wavelengths']}
# for i in range(1, num_scans + 1):
# template_data[f'SCAN{i}'] = 0
# template_df = pd.DataFrame(template_data)
# for column in uploaded_df.columns:
# if column in template_df.columns:
# template_df[column] = uploaded_df[column]
# template_df_cleaned = template_df.dropna(axis=1)
# preprocessed_df = scaler.transform(template_df_cleaned)
# return preprocessed_df
# st.image('logo.png', caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")
# st.markdown("<h3 style='text-align: center; color: grey;'>Hyperspectral Based System For Identification Of Common Bean Genotypes Resistant To Foliar Diseases</h3>", unsafe_allow_html=True)
# uploaded_file = st.file_uploader('Choose a CSV file', type='csv')
# if uploaded_file is not None:
# input_df = pd.read_csv(uploaded_file)
# preprocessed_input = preprocess_uploaded_file(input_df, scaler)
# predictions = model.predict(preprocessed_input)
# most_common_class = Counter(predictions).most_common(1)[0][0]
# if most_common_class == 'Resistant':
# st.write('The Plant is resistant to foliar diseases.')
# elif most_common_class == 'Medium':
# st.write('The Plant shows medium resistance to foliar diseases.')
# elif most_common_class == 'Susceptible':
# st.write('The Plant is susceptible to foliar diseases.')
# |