File size: 3,318 Bytes
8a2fc94
 
 
4e27d7a
 
 
 
 
 
 
 
 
 
8a2fc94
 
 
 
4e27d7a
 
 
8a2fc94
 
4e27d7a
 
 
8a2fc94
 
 
 
4e27d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a2fc94
4e27d7a
8a2fc94
4e27d7a
8a2fc94
4e27d7a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import pandas as pd
import numpy as np
import joblib
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Application instance; exposes the single POST /classify endpoint below.
app = FastAPI()

class InputData(BaseModel):
    """Request body for /classify: a wavelength axis plus per-scan readings.

    NOTE(review): preprocess_input_data zero-fills SCAN1..SCAN15, but only
    SCAN1/SCAN2 are declared here — pydantic silently drops undeclared
    payload fields by default, so SCAN3..SCAN15 from a client would be
    ignored. Confirm the intended number of scan fields.
    """
    wavelengths: list  # spectral wavelength values, one per row
    SCAN1: list  # readings for scan 1, aligned element-wise with wavelengths
    SCAN2: list  # readings for scan 2
    # Add more fields as needed

# Load the fitted artifacts once at import time (paths are relative to the
# process working directory — the service fails fast if they are missing).
model = joblib.load('random_forest_model.joblib')
scaler = joblib.load('scaler.joblib')

def preprocess_input_data(input_data):
    """Assemble a wavelengths + SCAN1..SCAN15 frame from the payload and
    scale it with the fitted scaler.

    Scans absent from the payload remain zero-filled. Columns containing
    NaN are dropped before scaling.
    NOTE(review): dropping columns can change the feature count the scaler
    was fitted on — assumes clean, complete input; verify upstream.
    """
    # Anchor the frame on the wavelength axis, then zero-fill all 15 scans.
    frame = pd.DataFrame({'wavelengths': input_data['wavelengths']})
    for scan_idx in range(1, 16):  # Assuming you have 15 scans
        frame[f'SCAN{scan_idx}'] = 0
    # Overwrite the zero placeholders with whichever scans were supplied.
    for field, values in input_data.items():
        if field in frame.columns:
            frame[field] = values
    cleaned = frame.dropna(axis=1)
    return scaler.transform(cleaned)

@app.post("/classify")
def classify(input_data: InputData):
    """Classify the submitted scans as Resistant / Medium / Susceptible.

    Runs the scans through the scaler + random-forest model, takes the
    majority class over all per-row predictions, and maps it to a label.

    Returns:
        {"result": <label>} on success.
    Raises:
        HTTPException(500) with the underlying error message on any
        preprocessing or prediction failure, or on an unknown class label.
    """
    labels = {0: 'Resistant', 1: 'Medium', 2: 'Susceptible'}
    try:
        preprocessed_input = preprocess_input_data(input_data.dict())
        # NOTE(review): bincount/argmax assumes the model emits non-negative
        # integer labels (0/1/2) — confirm against the trained artifact.
        predictions = model.predict(preprocessed_input)
        most_common_class = np.argmax(np.bincount(predictions))
        if most_common_class not in labels:
            # Previously this fell through and returned HTTP 200 with a
            # null body; surface it as an explicit server error instead.
            raise ValueError(f"unexpected class label: {most_common_class}")
        return {"result": labels[int(most_common_class)]}
    except HTTPException:
        raise  # don't re-wrap our own HTTP errors
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# End of Code






# import streamlit as st
# import pandas as pd
# import numpy as np
# import joblib
# from collections import Counter

# model = joblib.load('random_forest_model.joblib')
# scaler = joblib.load('scaler.joblib')

# def preprocess_uploaded_file(uploaded_df, scaler, num_scans=15):
#     uploaded_df.columns = ['wavelengths'] + [f'SCAN{i}' for i in range(1, uploaded_df.shape[1])]
#     template_data = {'wavelengths': uploaded_df['wavelengths']}
#     for i in range(1, num_scans + 1):
#         template_data[f'SCAN{i}'] = 0
#     template_df = pd.DataFrame(template_data)
#     for column in uploaded_df.columns:
#         if column in template_df.columns:
#             template_df[column] = uploaded_df[column]
#     template_df_cleaned = template_df.dropna(axis=1)
#     preprocessed_df = scaler.transform(template_df_cleaned)
#     return preprocessed_df

# st.image('logo.png', caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")

# st.markdown("<h3 style='text-align: center; color: grey;'>Hyperspectral Based System For Identification Of Common Bean Genotypes Resistant To Foliar Diseases</h3>", unsafe_allow_html=True)

# uploaded_file = st.file_uploader('Choose a CSV file', type='csv')

# if uploaded_file is not None:
#     input_df = pd.read_csv(uploaded_file)
#     preprocessed_input = preprocess_uploaded_file(input_df, scaler)
#     predictions = model.predict(preprocessed_input)
#     most_common_class = Counter(predictions).most_common(1)[0][0]
#     if most_common_class == 'Resistant':
#         st.write('The Plant is resistant to foliar diseases.')
#     elif most_common_class == 'Medium':
#         st.write('The Plant shows medium resistance to foliar diseases.')
#     elif most_common_class == 'Susceptible':
#         st.write('The Plant is susceptible to foliar diseases.')