Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import faiss | |
| import os | |
| from groq import Groq | |
| import matplotlib.pyplot as plt | |
# Initialize the Groq API client.
# The key is read from the environment (e.g. a Hugging Face Space secret) rather
# than hard-coded: a key committed to source is exposed to anyone who can read
# the file. Any key that was previously committed here should be revoked.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; configure it as a "
        "secret instead of hard-coding the key."
    )
client = Groq(api_key=GROQ_API_KEY)
# Load the dataset
def load_data(path="traffic.csv"):
    """Read the traffic dataset from a CSV file.

    Args:
        path: Location of the CSV file. Defaults to ``"traffic.csv"`` relative
            to the current working directory, which is what the app uses.

    Returns:
        A ``pandas.DataFrame`` with the file's contents. The rest of this app
        reads the columns ``DateTime``, ``Junction``, ``Vehicles`` and ``ID``.
    """
    return pd.read_csv(path)
# Function to create FAISS index
def create_faiss_index(documents, dim=128):
    """Build a flat L2 FAISS index with one vector per document.

    NOTE: the vectors are random placeholders and are NOT derived from the
    document text, so nearest-neighbour results from this index carry no
    semantic meaning. Replace them with a real embedding model before relying
    on retrieval.

    Args:
        documents: Sequence of documents; only ``len(documents)`` is used.
        dim: Embedding dimensionality (default 128). Query vectors searched
            against the returned index must have the same width (``index.d``).

    Returns:
        ``(index, embeddings)``: the populated ``faiss.IndexFlatL2`` and the
        ``(len(documents), dim)`` float32 array that was added to it.
    """
    # FAISS requires float32 input.
    embeddings = np.random.rand(len(documents), dim).astype("float32")
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, embeddings
def _render_dataset_section(df):
    """Show the "Traffic Dataset" heading followed by *df* as an interactive table."""
    st.write("### Traffic Dataset")
    st.dataframe(df)


# Load the traffic data first; the table, retrieval documents and chart below
# all read from it.
traffic_data = load_data()

# Page header.
st.title("Dynamic Traffic Flow Optimization")
st.write("Optimize urban traffic flow using AI strategies.")

_render_dataset_section(traffic_data)
# Generate one text "document" per dataset row; retrieval returns these as
# context for the LLM.
# The needed columns are zipped directly instead of using DataFrame.iterrows():
# iterrows builds a Series for every row (slow on large files) and does not
# preserve dtypes, so in an all-numeric frame ints could be rendered as floats.
documents = [
    f"Junction: {junction}, Vehicles: {vehicles}, "
    f"DateTime: {date_time}, ID: {row_id}"
    for junction, vehicles, date_time, row_id in zip(
        traffic_data["Junction"],
        traffic_data["Vehicles"],
        traffic_data["DateTime"],
        traffic_data["ID"],
    )
]

# Create FAISS index
index, embeddings = create_faiss_index(documents)
# User query
query = st.text_input("Enter a traffic optimization query:")
if query:
    # Placeholder query embedding (replace with a real embedding model). Its
    # width is taken from the index so it always matches the stored vectors.
    query_embedding = np.random.rand(1, index.d).astype("float32")

    # Retrieve up to 3 documents. Never ask for more neighbours than the index
    # holds: FAISS pads missing results with label -1, and documents[-1] would
    # silently return the last row instead of signalling "no result".
    top_k = min(3, index.ntotal)
    if top_k > 0:
        _, indices = index.search(query_embedding, top_k)
        hits = [documents[i] for i in indices[0] if i >= 0]
    else:
        hits = []
    retrieved_context = " ".join(hits)

    # Generate the optimization strategy with Groq. The retrieved records are
    # placed in the user turn, ahead of the question, as reference material.
    # In an assistant turn after the question they would read as the model's
    # own earlier reply, and the conversation would end on an assistant turn.
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are an expert in traffic optimization."},
            {
                "role": "user",
                "content": f"Relevant context: {retrieved_context}\n\nQuestion: {query}",
            },
        ],
        # NOTE(review): confirm this model id is still served by Groq.
        model="llama3-8b-8192",
    )
    generated_strategy = response.choices[0].message.content

    # Display results
    st.write("### Relevant Context:")
    st.write(retrieved_context)
    st.write("### Optimization Strategy:")
    st.write(generated_strategy)
# Visualize the total vehicle count per junction as a bar chart.
st.write("### Traffic Data Overview")
fig, ax = plt.subplots()
vehicles_per_junction = traffic_data.groupby("Junction")["Vehicles"].sum()
vehicles_per_junction.plot(kind="bar", ax=ax, color="skyblue")
ax.set_title("Total Vehicle Count by Junction")
ax.set_ylabel("Number of Vehicles")
st.pyplot(fig)
# Streamlit re-executes this script on every interaction, and pyplot keeps each
# figure it creates alive until it is closed. Close it once it has been
# rendered so figures do not accumulate across reruns.
plt.close(fig)