Vishalpy12 committed on
Commit
6f23def
·
verified ·
1 Parent(s): 6fac19f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +170 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,172 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ from pymongo import MongoClient
4
+ import os
5
+ from dotenv import load_dotenv
6
+ from sklearn.ensemble import RandomForestRegressor
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+ from langchain_groq import ChatGroq
10
+ from langchain.chains import LLMChain
11
+ from langchain.prompts import PromptTemplate
12
+ from io import BytesIO
13
+ from streamlit_autorefresh import st_autorefresh
14
+
15
+
16
+
17
+
18
+
19
+
20
# ---- Configuration ----
# Pull connection settings and the LLM API key from the environment;
# a local .env file is honoured via python-dotenv.
load_dotenv()

mongo_uri, db_name, collection_name, groq_api_key = (
    os.getenv(name)
    for name in ("MONGO_URI", "DB_NAME", "COLLECTION_NAME", "GROQ_API_KEY")
)
26
+
27
def connect_mongo():
    """Connect to MongoDB and return the configured collection handle.

    Reads the module-level ``mongo_uri``, ``db_name`` and
    ``collection_name`` settings loaded from the environment.
    """
    # MongoClient defers actual I/O until the first operation, so this
    # chained lookup is cheap to call on every Streamlit rerun.
    return MongoClient(mongo_uri)[db_name][collection_name]
32
+
33
def get_data(collection):
    """Load every document from *collection* into a pandas DataFrame.

    The MongoDB ``_id`` column is stripped when present, since it is
    bookkeeping rather than sensor data. An empty collection yields an
    empty DataFrame.
    """
    frame = pd.DataFrame(list(collection.find()))
    # errors='ignore' makes the drop a no-op when '_id' is absent
    # (including the empty-collection case), matching the original
    # explicit membership check.
    return frame.drop(columns=['_id'], errors='ignore')
39
+
40
def train_model(X, y):
    """Fit and return a RandomForestRegressor on features X, target y.

    random_state=42 keeps the forest deterministic across app reruns.
    """
    # fit() returns the estimator itself, so train-and-return is one line.
    return RandomForestRegressor(random_state=42).fit(X, y)
45
+
46
def generate_report(feature_impact, predicted_wqi, location, timestamp, selected):
    """Generate a plain-text WQI analysis report via Groq (llama-3.3-70b).

    ``feature_impact`` supplies the parameter names to highlight; each
    parameter's raw sensor value is pulled from the ``selected`` record.
    Returns the LLM output with Markdown bold markers removed.
    """
    param_info = "\n".join(f"- {param}: {selected[param]}" for param in feature_impact)

    template = """You are an expert environmental analyst.

The predicted Water Quality Index (WQI) is {predicted_wqi} at location \"{location}\" on {timestamp}.
The top contributing parameters with their actual sensor values are:
{param_info}

Write a report that includes:
1. Likely causes for this WQI
2. Why these parameters are significant
3. Practical recommendations to improve WQI"""

    chain = LLMChain(
        llm=ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile"),
        prompt=PromptTemplate.from_template(template),
    )

    raw_report = chain.run(
        predicted_wqi=predicted_wqi,
        location=location,
        timestamp=timestamp,
        param_info=param_info,
    )

    # Strip Markdown bold markers so the downloadable TXT reads cleanly.
    return raw_report.replace("**", "")
75
+
76
# Function to save report as TXT
def save_report_as_txt(text: str, filename: str = "") -> BytesIO:
    """Encode *text* as UTF-8 into an in-memory buffer ready for download.

    The returned BytesIO is positioned at offset 0, so Streamlit's
    ``st.download_button`` can stream it directly.

    ``filename`` was previously a required parameter but was never used
    by the function body; it is kept (now optional) for backward
    compatibility with existing call sites.
    """
    # BytesIO(initial_bytes) starts positioned at 0 — no explicit
    # write/seek dance is needed.
    return BytesIO(text.encode("utf-8"))
82
+
83
# ---------- Streamlit UI ----------
st.set_page_config(page_title="Water Quality AI Analyzer", layout="wide")
st.title("πŸ’§ Water Quality Index Prediction & AI-Powered Report")

# Re-run the whole script every 60 s so fresh MongoDB documents appear.
st_autorefresh(interval=60 * 1000, key="datarefresh")
st.markdown("⏰ Auto-refreshing every 60 seconds to fetch latest data...")

# ---- Real-time data load from MongoDB ----
collection = connect_mongo()
df = get_data(collection)

if df.empty:
    st.warning("No data found in MongoDB.")
    st.stop()

st.success("βœ… Data successfully loaded from MongoDB")
st.dataframe(df.head())

# ---- Features and target ----
feature_cols = ['pH', 'turbidity', 'dissolved_oxygen', 'conductivity', 'temperature']
target_col = 'wqi'

if not all(col in df.columns for col in feature_cols + [target_col]):
    st.error("❌ Required columns are missing from the dataset.")
    st.stop()

# ---- Model training ----
X = df[feature_cols]
y = df[target_col]
model = train_model(X, y)

# ---- Global SHAP feature importance ----
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

st.subheader("πŸ“Š Feature Impact on WQI (SHAP Values)")
fig, ax = plt.subplots(figsize=(6, 4))
# show=False keeps shap from calling plt.show(); it draws on the current
# matplotlib figure, which st.pyplot then renders.
shap.summary_plot(shap_values, X, plot_type="bar", show=False)
st.pyplot(fig)

# ---- Record selection ----
st.subheader("πŸ” Select a Data Record for Detailed Analysis")
record_options = [
    f"{i}: {row.get('location', 'Unknown')} @ {row.get('timestamp', 'N/A')}"
    for i, row in df.iterrows()
]
selected_label = st.selectbox("πŸ“‹ Select a Record by Location & Time", options=record_options)
# Label format is "<index>: <location> @ <timestamp>" — everything before
# the first colon is the positional row index.
selected_index = int(selected_label.split(":")[0])
selected = df.iloc[selected_index]

st.markdown(f"πŸ”’ Selected Index: `{selected_index}`")
st.markdown(f"πŸ“ Location: `{selected.get('location', 'N/A')}`")
st.markdown(f"⏰ Timestamp: `{selected.get('timestamp', 'N/A')}`")

# ---- Prediction for the selected record ----
# Keep the single row as a 1xN DataFrame so column names reach the model.
input_data = selected[feature_cols].to_frame().T
predicted_wqi = model.predict(input_data)[0]

st.markdown("### πŸ§ͺ Selected Sensor Parameters Used for WQI Prediction")
for param in feature_cols:
    st.markdown(f"- **{param}**: `{selected[param]}`")

# Per-record SHAP: rank features by absolute contribution, keep the top 3
# for the AI report.
individual_shap = explainer(input_data)
impact = pd.Series(individual_shap.values[0], index=feature_cols).abs().sort_values(ascending=False)
top_impact = impact.head(3).to_dict()

st.markdown(f"### πŸ€– Predicted WQI: `{predicted_wqi:.2f}`")

# ---- AI report generation & download ----
if st.button("πŸ“ Generate AI Report"):
    location = selected.get("location", "Unknown")
    timestamp = selected.get("timestamp", "Unknown")
    report = generate_report(top_impact, predicted_wqi, location, timestamp, selected)

    st.subheader("πŸ“ AI-Generated Water Quality Report")
    st.markdown(report)

    # str() coercion: MongoDB may store location/timestamp as non-string
    # types (e.g. datetime), which would make .replace()/slicing raise.
    txt_file_name = f"water_quality_report_{str(location).replace(' ', '_')}_{str(timestamp)[:10]}.txt"
    report_txt = save_report_as_txt(report, txt_file_name)

    st.download_button(
        label="πŸ“„ Download Report (TXT)",
        data=report_txt,
        file_name=txt_file_name,
        mime="text/plain"
    )