# NaseemTahir's picture
# Update app.py
# 612b324 verified
import streamlit as st
import pandas as pd
import numpy as np
import faiss
import os
from groq import Groq
import matplotlib.pyplot as plt
# Initialize Groq API.
# The key is read from the environment rather than hard-coded: a key committed
# to source control is exposed to everyone who can read the repository (any
# previously committed key should be revoked and rotated).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail early with a readable message instead of a traceback from the SDK.
    st.error("GROQ_API_KEY environment variable is not set.")
    st.stop()
client = Groq(api_key=GROQ_API_KEY)
# Load the dataset
@st.cache_data
def load_data(path: str = "traffic.csv"):
    """Read the traffic CSV into a DataFrame.

    Cached with ``st.cache_data`` (the replacement for the deprecated
    ``st.cache`` for functions that return data) so that Streamlit reruns,
    which re-execute the whole script on every interaction, do not re-read
    the file.

    Args:
        path: CSV file to read. Defaults to ``"traffic.csv"``. The rest of
            this app reads the ``Junction``, ``Vehicles``, ``DateTime`` and
            ``ID`` columns from the result.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(path)
# Function to create FAISS index
def create_faiss_index(documents, dim=128):
    """Build an exact (flat L2) FAISS index holding one vector per document.

    NOTE: the vectors are random placeholders, not real text embeddings, so a
    nearest-neighbour search over this index returns arbitrary documents.
    Replace ``np.random.rand`` with a real embedding model (and embed queries
    with the same model) for meaningful retrieval.

    Args:
        documents: Sequence of document strings; only its length is used.
        dim: Embedding width of the index. Defaults to 128.

    Returns:
        ``(index, embeddings)`` where ``embeddings`` is the
        ``(len(documents), dim)`` float32 array that was added to ``index``.
    """
    # FAISS's Python bindings expect float32 arrays.
    embeddings = np.random.rand(len(documents), dim).astype("float32")
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, embeddings
# Load the traffic data (load_data is cached, so reruns reuse the result).
traffic_data = load_data()

# Page header.
st.title(body="Dynamic Traffic Flow Optimization")
st.write("Optimize urban traffic flow using AI strategies.")

# Show the raw dataset so the user can see what retrieval draws from.
st.write("### Traffic Dataset")
st.dataframe(data=traffic_data)
# Generate one text document per dataset row.
# Built with vectorised string concatenation instead of DataFrame.iterrows():
# iterrows creates a Series per row (slow on large CSVs, and this runs on every
# Streamlit rerun) and can upcast integer columns to float when every column
# is numeric, which would render e.g. "Vehicles: 13.0".
documents = (
    "Junction: " + traffic_data["Junction"].astype(str)
    + ", Vehicles: " + traffic_data["Vehicles"].astype(str)
    + ", DateTime: " + traffic_data["DateTime"].astype(str)
    + ", ID: " + traffic_data["ID"].astype(str)
).tolist()
# Create FAISS index (one placeholder vector per document).
index, embeddings = create_faiss_index(documents)
# User Query
query = st.text_input("Enter a traffic optimization query:")
# Skip empty or whitespace-only input so no API call is made for it.
if query and query.strip():
    # Placeholder query embedding: random, so retrieval is effectively random
    # too. Replace with the same real embedding model used for the documents.
    # The width must match the index, so it is taken from index.d rather than
    # repeating the dimension here.
    query_embedding = np.random.rand(1, index.d).astype("float32")
    # Retrieve up to 3 relevant documents. When the index holds fewer than k
    # vectors, FAISS pads the result with id -1, and documents[-1] would
    # silently pick the last row, so clamp k and drop negative ids.
    k = min(3, index.ntotal)
    retrieved_context = ""
    if k > 0:
        _, indices = index.search(query_embedding, k)
        retrieved_context = " ".join(documents[i] for i in indices[0] if i >= 0)
    # Generate optimization strategy using Groq. The retrieved context is sent
    # inside the user turn, ahead of the question, so the model reads it as
    # input; placed in an "assistant" turn it would look like the model's own
    # earlier reply.
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are an expert in traffic optimization."},
            {
                "role": "user",
                "content": f"Relevant context: {retrieved_context}\n\nQuestion: {query}",
            },
        ],
        # NOTE(review): confirm this model id is still served by Groq.
        model="llama3-8b-8192",
    )
    generated_strategy = response.choices[0].message.content
    # Display Results
    st.write("### Relevant Context:")
    st.write(retrieved_context)
    st.write("### Optimization Strategy:")
    st.write(generated_strategy)
# Visualize Vehicle Count
st.write("### Traffic Data Overview")
fig, ax = plt.subplots()
# Total vehicles per junction, summed over every row of the dataset.
traffic_data.groupby("Junction")["Vehicles"].sum().plot(kind="bar", ax=ax, color="skyblue")
ax.set_title("Total Vehicle Count by Junction")
ax.set_ylabel("Number of Vehicles")
st.pyplot(fig)
# Streamlit re-executes this script on every interaction, and pyplot keeps each
# figure created by plt.subplots() alive until it is closed. Close it so
# figures do not accumulate across reruns.
plt.close(fig)