# sialnoman318's picture
# Update app.py
# 132feaf verified
import os
import json
import gradio as gr
import pandas as pd
import numpy as np
import faiss
from sklearn.preprocessing import StandardScaler
from groq import Groq
import plotly.express as px
# Set up Groq client.
# The API key is read from the GROQ_API_KEY environment variable (for example a
# Hugging Face Space secret) instead of being committed to source control.
# NOTE(review): the key that used to be hard-coded here is exposed in the repo
# history and must be revoked/rotated.
# NOTE(review): the Groq SDK is expected to raise at construction when no key
# is available, so a missing secret should fail at startup - confirm.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Shared state written by load_data() and read by generate_schedule().
scaler = StandardScaler()  # feature standardiser fitted on the uploaded CSV
index = None  # FAISS index over the scaled features; None until data is loaded
data_columns = []  # column names of the most recently loaded CSV
project_ids = []  # first-column value of each CSV row, in index row order
# Function to load historical project data
def load_data(file):
    """Load a historical-project CSV and (re)build the FAISS similarity index.

    The first CSV column is treated as the project identifier and is kept out
    of the feature matrix. Every other numeric column is standardised with a
    freshly fitted ``StandardScaler`` and added to an exact L2
    (``IndexFlatL2``) FAISS index.

    The module-level ``index``, ``scaler``, ``data_columns`` and
    ``project_ids`` are replaced together, and only after every step has
    succeeded. A failed upload therefore leaves the previously loaded dataset
    intact and mutually consistent.

    Args:
        file: The value from the Gradio ``File`` component. This is an object
            whose ``.name`` attribute holds the file path, a plain path
            string, or ``None`` when nothing has been uploaded.

    Returns:
        A status message for the UI. Failures are reported as a string
        starting with ``"Error"`` rather than raised.
    """
    global index, scaler, data_columns, project_ids
    if file is None:
        return "Error: Please upload a CSV file first."
    # Depending on the Gradio version/config, the component yields either a
    # tempfile-like wrapper (with ``.name``) or the path string itself.
    path = getattr(file, "name", file)
    try:
        data = pd.read_csv(path)
        if data.empty:
            return "Error: The uploaded CSV contains no data rows."
        # Column 0 holds project IDs. Excluding it keeps a numeric ID from
        # being treated as a feature and distorting the L2 distances.
        numerical_data = data.iloc[:, 1:].select_dtypes(include=[np.number])
        if numerical_data.empty:
            return "Error: No numeric columns found in the dataset."
        new_scaler = StandardScaler()
        scaled_data = new_scaler.fit_transform(numerical_data.to_numpy())
        # StandardScaler keeps NaNs in its output. 0.0 is the column mean in
        # standardised space, so missing values become neutral instead of
        # poisoning every distance computed against that row.
        faiss_data = np.ascontiguousarray(
            np.nan_to_num(scaled_data, nan=0.0), dtype="float32"
        )
        new_index = faiss.IndexFlatL2(faiss_data.shape[1])
        new_index.add(faiss_data)
    except Exception as e:
        # UI boundary: report the failure in the status box.
        return f"Error loading data: {e}"

    # Commit the new dataset only now that everything above has succeeded.
    index = new_index
    scaler = new_scaler
    data_columns = data.columns.tolist()
    project_ids = data.iloc[:, 0].tolist()
    return f"Data loaded successfully! Columns: {', '.join(map(str, data_columns))}"
# Function to generate construction schedule
def generate_schedule(input_parameters):
global index, scaler, project_ids, data_columns
try:
# Parse input parameters
parameters = list(map(float, input_parameters.split(",")))
input_features = np.array([parameters]).astype("float32")
scaled_features = scaler.transform(input_features)
# Query FAISS for similar projects
k = 5 # Number of similar projects to retrieve
distances, indices = index.search(scaled_features, k)
similar_projects = [{"project_id": project_ids[i]} for i in indices[0]]
# Construct prompt for Groq API
prompt = f"""
Based on these similar projects:
{similar_projects}
Generate an optimized construction schedule for a project with the following parameters:
Duration: {parameters[0]} days, Milestones: {parameters[1]}, Delays: {parameters[2]} days.
Output a JSON object for a Gantt chart, including tasks, start and end dates, and dependencies.
"""
response = client.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
model="llama3-8b-8192",
)
generated_schedule = response.choices[0].message.content
# Parse the Groq API response
schedule_dict = json.loads(generated_schedule)
schedule_data = pd.DataFrame(schedule_dict)
# Generate Gantt chart
fig = px.timeline(
schedule_data,
x_start="start_date",
x_end="end_date",
y="task",
title="Optimized Construction Schedule",
labels={"task": "Task", "start_date": "Start Date", "end_date": "End Date"},
)
fig.update_yaxes(categoryorder="total ascending")
return "Schedule generated successfully!", fig
except Exception as e:
return f"Error generating schedule: {e}", None
# Gradio app definition
def _build_upload_tab():
    """Lay out the "Upload Data" tab and wire its button to ``load_data``."""
    with gr.Tab("Upload Data"):
        csv_input = gr.File(label="Upload Historical Project Data (CSV)")
        upload_btn = gr.Button("Upload")
        upload_status = gr.Textbox(label="Output", interactive=False)
        upload_btn.click(fn=load_data, inputs=csv_input, outputs=upload_status)


def _build_schedule_tab():
    """Lay out the "Generate Schedule" tab and wire it to ``generate_schedule``."""
    with gr.Tab("Generate Schedule"):
        params_box = gr.Textbox(
            label="Input Parameters (comma-separated, e.g., '120, 10, 5')",
            placeholder="Enter duration, milestones, and delays (in days)",
        )
        generate_btn = gr.Button("Generate Schedule")
        status_box = gr.Textbox(label="Output", interactive=False)
        gantt_plot = gr.Plot(label="Schedule Gantt Chart")
        generate_btn.click(
            fn=generate_schedule,
            inputs=params_box,
            outputs=[status_box, gantt_plot],
        )


def main():
    """Build and return the Gradio ``Blocks`` UI (it is not launched here).

    The UI has a title and two tabs: one to upload the historical CSV, and
    one to request a schedule. Components are created inside the active
    ``Blocks`` context, so the helpers above attach their tabs to it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Intelligent Construction Scheduling Assistant")
        _build_upload_tab()
        _build_schedule_tab()
    return demo
# Script entry point: only runs when this file is executed directly.
if __name__ == "__main__":
    app = main()  # build the Blocks UI
    app.launch()  # serve it with Gradio