"""Traffic-flow optimization demo.

Upload a traffic CSV to get an LLM-generated (Groq) optimization strategy,
summary statistics, and a stacked bar chart of vehicle counts.
"""

import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from groq import Groq
from matplotlib.figure import Figure

# The API key is read from the GROQ_API_KEY environment variable; never
# hard-code secrets in source or overwrite a key the user already exported.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Default Groq model used for strategy generation.
DEFAULT_MODEL = "llama3-8b-8192"

# Per-vehicle-type count columns, in the order they are stacked in the chart.
VEHICLE_COLUMNS = ["CarCount", "BikeCount", "BusCount", "TruckCount"]

# (column, legend label, bar color) for each stacked layer of the chart.
_CHART_LAYERS = (
    ("CarCount", "Cars", "blue"),
    ("BikeCount", "Bikes", "green"),
    ("BusCount", "Buses", "red"),
    ("TruckCount", "Trucks", "yellow"),
)

# Columns the handler and chart code read from the uploaded CSV.
REQUIRED_COLUMNS = ["Day of the week", "Time", "Traffic Situation", *VEHICLE_COLUMNS]


def generate_traffic_optimization(data: str, model: str = DEFAULT_MODEL) -> str:
    """Ask the Groq chat API for a traffic-flow optimization strategy.

    Args:
        data: Text summary of the traffic records to analyse.
        model: Groq model identifier to query.

    Returns:
        The content of the model's first reply choice.
    """
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": (
                    "Generate a detailed traffic flow optimization strategy for the "
                    "following data. Include peak hours, vehicle type distributions, "
                    f"and actionable suggestions to improve flow: {data}"
                ),
            }
        ],
        model=model,
    )
    return response.choices[0].message.content


def _build_insights(df: pd.DataFrame) -> str:
    """Format average per-vehicle counts and the peak-traffic time as text.

    *df* must be non-empty, have a unique index, and contain a ``Total`` column.
    """
    peak_time = df.loc[df["Total"].idxmax(), "Time"]
    return (
        "\n"
        f"Average Car Count: {df['CarCount'].mean():.2f}\n"
        f"Average Bike Count: {df['BikeCount'].mean():.2f}\n"
        f"Average Bus Count: {df['BusCount'].mean():.2f}\n"
        f"Average Truck Count: {df['TruckCount'].mean():.2f}\n"
        f"Peak Traffic Time: {peak_time}\n"
    )


def _save_vehicle_chart(df: pd.DataFrame, rows: int) -> str:
    """Render a stacked bar chart of the first *rows* records to a PNG; return its path."""
    head = df.head(rows)
    # Integer x positions with Time values as tick labels: a categorical axis
    # would draw rows that share the same Time string on top of each other.
    positions = np.arange(len(head))
    bottom = np.zeros(len(head))

    # Use the Figure API instead of pyplot, as Matplotlib recommends for web
    # servers: no global figure registry to leak and no GUI backend involved.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    for column, label, color in _CHART_LAYERS:
        counts = head[column].to_numpy(dtype=float)
        ax.bar(positions, counts, bottom=bottom, label=label, color=color)
        bottom = bottom + counts

    ax.set_xticks(positions)
    ax.set_xticklabels(head["Time"].astype(str), rotation=45)
    ax.set_xlabel("Time")
    ax.set_ylabel("Vehicle Count")
    ax.set_title("Traffic Flow by Vehicle Type")
    ax.legend()
    fig.tight_layout()

    # A unique file per call, so concurrent requests cannot overwrite each
    # other's chart (a fixed "/tmp/..." path is also not portable).
    fd, chart_path = tempfile.mkstemp(prefix="traffic_chart_", suffix=".png")
    os.close(fd)
    fig.savefig(chart_path)
    return chart_path


def generate_traffic_chart(df_filtered: pd.DataFrame, rows: int = 10):
    """Build the optimization strategy, summary insights and chart for *df_filtered*.

    Args:
        df_filtered: Non-empty traffic records containing ``Time``,
            ``Traffic Situation`` and the per-vehicle count columns.
            The frame is not modified.
        rows: Number of leading records sent to the LLM and plotted.

    Returns:
        ``(text, chart_path)``: the LLM strategy followed by the insights
        block, and the path of the saved PNG chart.

    Raises:
        ValueError: If *df_filtered* has no rows.
    """
    if df_filtered.empty:
        raise ValueError("generate_traffic_chart() requires at least one row of data")

    # Work on a private copy: callers typically pass a boolean-mask slice, and
    # adding a column to it would mutate it / trigger SettingWithCopyWarning.
    # A fresh RangeIndex makes the peak-row lookup unambiguous.
    df = df_filtered.copy()
    df.reset_index(drop=True, inplace=True)
    # skipna=False mirrors element-wise addition: a missing count gives a missing total.
    df["Total"] = df[VEHICLE_COLUMNS].sum(axis=1, skipna=False)

    summary_columns = ["Time", "Total", "Traffic Situation", *VEHICLE_COLUMNS]
    summary_str = df[summary_columns].head(rows).to_string(index=False)
    optimization_strategy = generate_traffic_optimization(summary_str)

    insights = _build_insights(df)
    chart_path = _save_vehicle_chart(df, rows)
    return optimization_strategy + "\n\n" + insights, chart_path


def process_traffic_file(file, day: str = "Monday"):
    """Gradio handler: load the uploaded CSV, keep one weekday, and analyse it.

    Args:
        file: Value of the ``gr.File`` input. Depending on the Gradio
            version/configuration this is a filepath string or a file-like
            object exposing the path as ``.name``; ``None`` if nothing was uploaded.
        day: Value of the ``Day of the week`` column to keep.

    Returns:
        ``(strategy_and_insights_text, chart_path)`` for the two outputs.

    Raises:
        gr.Error: If no file was uploaded, required columns are missing, or
            no rows match *day*.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file with traffic data.")
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    df = pd.read_csv(path)

    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise gr.Error(f"The CSV is missing required column(s): {', '.join(missing)}")

    df_filtered = df[df["Day of the week"] == day]
    if df_filtered.empty:
        raise gr.Error(f"No rows found where 'Day of the week' is {day!r}.")

    return generate_traffic_chart(df_filtered)


iface = gr.Interface(
    fn=process_traffic_file,
    inputs=gr.File(label="Upload CSV with Traffic Data", file_types=[".csv"]),
    outputs=[
        gr.Textbox(label="Optimization Strategy and Insights"),
        gr.Image(label="Traffic Flow Chart"),
    ],
    title="Traffic Flow Optimization",
    description=(
        "Upload a CSV file with traffic data, and the app will generate traffic "
        "optimization strategies, provide insights, and visualize the traffic flow."
    ),
)


if __name__ == "__main__":
    iface.launch()