carmel26 commited on
Commit
4e27d7a
·
1 Parent(s): bf7628c

Updating of the file and adding API logic on 30Dec2023 at 23h19

Browse files
Files changed (1) hide show
  1. app.py +76 -23
app.py CHANGED
@@ -1,39 +1,92 @@
1
- import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import joblib
5
- from collections import Counter
 
 
 
 
 
 
 
 
 
6
 
7
  model = joblib.load('random_forest_model.joblib')
8
  scaler = joblib.load('scaler.joblib')
9
 
10
- def preprocess_uploaded_file(uploaded_df, scaler, num_scans=15):
11
- uploaded_df.columns = ['wavelengths'] + [f'SCAN{i}' for i in range(1, uploaded_df.shape[1])]
12
- template_data = {'wavelengths': uploaded_df['wavelengths']}
13
- for i in range(1, num_scans + 1):
14
  template_data[f'SCAN{i}'] = 0
15
  template_df = pd.DataFrame(template_data)
16
- for column in uploaded_df.columns:
17
- if column in template_df.columns:
18
- template_df[column] = uploaded_df[column]
19
  template_df_cleaned = template_df.dropna(axis=1)
20
  preprocessed_df = scaler.transform(template_df_cleaned)
21
  return preprocessed_df
22
 
23
- st.image('logo.png', caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- st.markdown("<h3 style='text-align: center; color: grey;'>Hyperspectral Based System For Identification Of Common Bean Genotypes Resistant To Foliar Diseases</h2>", unsafe_allow_html=True)
26
 
27
- uploaded_file = st.file_uploader('Choose a CSV file', type='csv')
28
 
29
- if uploaded_file is not None:
30
- input_df = pd.read_csv(uploaded_file)
31
- preprocessed_input = preprocess_uploaded_file(input_df, scaler)
32
- predictions = model.predict(preprocessed_input)
33
- most_common_class = Counter(predictions).most_common(1)[0][0]
34
- if most_common_class == 'Resistant':
35
- st.write('The Plant is resistant to foliar diseases.')
36
- elif most_common_class == 'Medium':
37
- st.write('The Plant shows medium resistance to foliar diseases.')
38
- elif most_common_class == 'Susceptible':
39
- st.write('The Plant is susceptible to foliar diseases.')
 
 
1
  import pandas as pd
2
  import numpy as np
3
  import joblib
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel
6
+
7
+ app = FastAPI()
8
+
9
+ class InputData(BaseModel):
10
+ wavelengths: list
11
+ SCAN1: list
12
+ SCAN2: list
13
+ # Add more fields as needed
14
 
15
  model = joblib.load('random_forest_model.joblib')
16
  scaler = joblib.load('scaler.joblib')
17
 
18
+ def preprocess_input_data(input_data):
19
+ template_data = {'wavelengths': input_data['wavelengths']}
20
+ for i in range(1, 16): # Assuming you have 15 scans
 
21
  template_data[f'SCAN{i}'] = 0
22
  template_df = pd.DataFrame(template_data)
23
+ for key, value in input_data.items():
24
+ if key in template_df.columns:
25
+ template_df[key] = value
26
  template_df_cleaned = template_df.dropna(axis=1)
27
  preprocessed_df = scaler.transform(template_df_cleaned)
28
  return preprocessed_df
29
 
30
+ @app.post("/classify")
31
+ def classify(input_data: InputData):
32
+ try:
33
+ preprocessed_input = preprocess_input_data(input_data.dict())
34
+ predictions = model.predict(preprocessed_input)
35
+ most_common_class = np.argmax(np.bincount(predictions))
36
+
37
+ if most_common_class == 0:
38
+ return {"result": 'Resistant'}
39
+ elif most_common_class == 1:
40
+ return {"result": 'Medium'}
41
+ elif most_common_class == 2:
42
+ return {"result": 'Susceptible'}
43
+ except Exception as e:
44
+ raise HTTPException(status_code=500, detail=str(e))
45
+
46
+
47
+ # End of Code
48
+
49
+
50
+
51
+
52
+
53
+
54
+ # import streamlit as st
55
+ # import pandas as pd
56
+ # import numpy as np
57
+ # import joblib
58
+ # from collections import Counter
59
+
60
+ # model = joblib.load('random_forest_model.joblib')
61
+ # scaler = joblib.load('scaler.joblib')
62
+
63
+ # def preprocess_uploaded_file(uploaded_df, scaler, num_scans=15):
64
+ # uploaded_df.columns = ['wavelengths'] + [f'SCAN{i}' for i in range(1, uploaded_df.shape[1])]
65
+ # template_data = {'wavelengths': uploaded_df['wavelengths']}
66
+ # for i in range(1, num_scans + 1):
67
+ # template_data[f'SCAN{i}'] = 0
68
+ # template_df = pd.DataFrame(template_data)
69
+ # for column in uploaded_df.columns:
70
+ # if column in template_df.columns:
71
+ # template_df[column] = uploaded_df[column]
72
+ # template_df_cleaned = template_df.dropna(axis=1)
73
+ # preprocessed_df = scaler.transform(template_df_cleaned)
74
+ # return preprocessed_df
75
+
76
+ # st.image('logo.png', caption=None, width=None, use_column_width=None, clamp=False, channels="RGB", output_format="auto")
77
 
78
+ # st.markdown("<h3 style='text-align: center; color: grey;'>Hyperspectral Based System For Identification Of Common Bean Genotypes Resistant To Foliar Diseases</h2>", unsafe_allow_html=True)
79
 
80
+ # uploaded_file = st.file_uploader('Choose a CSV file', type='csv')
81
 
82
+ # if uploaded_file is not None:
83
+ # input_df = pd.read_csv(uploaded_file)
84
+ # preprocessed_input = preprocess_uploaded_file(input_df, scaler)
85
+ # predictions = model.predict(preprocessed_input)
86
+ # most_common_class = Counter(predictions).most_common(1)[0][0]
87
+ # if most_common_class == 'Resistant':
88
+ # st.write('The Plant is resistant to foliar diseases.')
89
+ # elif most_common_class == 'Medium':
90
+ # st.write('The Plant shows medium resistance to foliar diseases.')
91
+ # elif most_common_class == 'Susceptible':
92
+ # st.write('The Plant is susceptible to foliar diseases.')