Upload 2 files

- app.py +1571 -0
- requirements.txt +15 -0

app.py (ADDED)
@@ -0,0 +1,1571 @@
import os
import streamlit as st
import pandas as pd
import pickle
import base64
from io import BytesIO, StringIO
import sys
import operator
from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
import tempfile
import shutil
import plotly.io as pio
import io
import re
import json
import ast  # for safely parsing LLM-returned dict suggestions
import openai
# from fpdf import FPDF
from datetime import datetime
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from PIL import Image as PILImage

# Import LangChain and LangGraph components
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_experimental.utilities import PythonREPL
from langgraph.prebuilt import ToolInvocation, ToolExecutor
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState
from langgraph.graph import StateGraph, END
from reportlab.platypus import PageBreak


# Initialize session state for AI provider settings
if 'ai_provider' not in st.session_state:
    st.session_state.ai_provider = "openai"

if 'api_key' not in st.session_state:
    st.session_state.api_key = ""

if 'selected_model' not in st.session_state:
    st.session_state.selected_model = "gpt-4"

# Define model options for each provider
OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4o-mini", "gpt-3.5-turbo"]
GROQ_MODELS = ["llama-3.3-70b-versatile", "gemma2-9b-it", "llama3-8b-8192"]

# Create a temporary directory for file storage
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")

# Define the system prompt
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities
1. **Execute python code** using the `complete_python_task` tool.

## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyze their data to solve their problems.
3. Investigate whether the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see the output of `df.head()`, `df.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
  - `pandas`
  - `sklearn` (including all major ML models)
  - `plotly`
  - `numpy`

All these libraries are already imported for you.

## Machine Learning Guidelines
- For regression tasks:
  - Linear Regression: `LinearRegression`
  - Logistic Regression: `LogisticRegression`
  - Ridge Regression: `Ridge`
  - Lasso Regression: `Lasso`
  - Random Forest Regression: `RandomForestRegressor`

- For classification tasks:
  - Logistic Regression: `LogisticRegression`
  - Decision Trees: `DecisionTreeClassifier`
  - Random Forests: `RandomForestClassifier`
  - Support Vector Machines: `SVC`
  - K-Nearest Neighbors: `KNeighborsClassifier`
  - Naive Bayes: `GaussianNB`

- For clustering:
  - K-Means: `KMeans`
  - DBSCAN: `DBSCAN`

- For dimensionality reduction:
  - PCA: `PCA`

- Always preprocess data appropriately:
  - Scale numerical features with `StandardScaler` or `MinMaxScaler`
  - Encode categorical variables with `OneHotEncoder` when needed
  - Handle missing values with `SimpleImputer`

- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
  - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
  - For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
  - For clustering: `silhouette_score`

- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible

## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list; they will be saved automatically.
- Do not try to show the plots inline with `fig.show()`.
"""

# Define the state class
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    input_data: Annotated[List[Dict], operator.add]
    intermediate_outputs: Annotated[List[dict], operator.add]
    current_variables: dict
    output_image_paths: Annotated[List[str], operator.add]

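# Illustrative sketch (not part of the original upload): fields annotated with
# operator.add above are merged additively by LangGraph, so list-valued updates
# returned from each node are appended rather than overwritten, while plain
# fields such as current_variables are replaced wholesale.
def _example_state_merge():
    old = {"output_image_paths": ["a.pickle"]}
    update = {"output_image_paths": ["b.pickle"]}
    merged = {k: operator.add(old[k], update[k]) for k in old}
    assert merged["output_image_paths"] == ["a.pickle", "b.pickle"]
    return merged
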
# Initialize session state variables
if 'in_memory_datasets' not in st.session_state:
    st.session_state.in_memory_datasets = {}

if 'persistent_vars' not in st.session_state:
    st.session_state.persistent_vars = {}

if 'dataset_metadata_list' not in st.session_state:
    st.session_state.dataset_metadata_list = []

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

if 'dashboard_plots' not in st.session_state:
    st.session_state.dashboard_plots = [None, None, None, None]

if 'columns' not in st.session_state:
    st.session_state.columns = ["No columns available"]

if 'custom_plots_to_save' not in st.session_state:
    st.session_state.custom_plots_to_save = {}

# Set up the tools
repl = PythonREPL()
plotly_saving_code = """import pickle
import uuid
import os
for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""

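# Sketch (assumed usage, not in the original file): a figure pickled by the
# snippet above can be restored later with pickle.load; the path here is
# hypothetical.
def _example_load_pickled_figure(path="images/plotly_figures/pickle/example.pickle"):
    with open(path, "rb") as f:
        fig = pickle.load(f)
    return fig  # a plotly Figure, ready for st.plotly_chart(fig)
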
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization."""

    current_variables = graph_state.get("current_variables", {})

    # Load datasets from in-memory storage
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]

    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    try:
        # Capture stdout
        old_stdout = sys.stdout
        sys.stdout = StringIO()

        # Execute the code and capture the result
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)

        # Add scikit-learn modules to the execution environment
        import sklearn
        import numpy as np

        # Import scikit-learn components
        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso  # type: ignore
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier  # type: ignore
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer

        # Update the execution globals with all ML components
        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,
            # Linear models
            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,
            # Tree-based models
            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,
            # SVM models
            "SVC": SVC,
            "SVR": SVR,
            # Other models
            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,
            # Preprocessing
            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,
            # Model selection and evaluation
            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,
            # Pipeline
            "Pipeline": Pipeline
        })

        exec(python_code, exec_globals)

        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        # Get the captured stdout, then restore it
        output = sys.stdout.getvalue()
        sys.stdout = old_stdout

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)

            # Check whether any images were created
            new_image_folder_contents = os.listdir(st.session_state.images_dir)
            new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]

            if new_image_files:
                updated_state["output_image_paths"] = new_image_files

            st.session_state.persistent_vars["plotly_figures"] = []

        return output, updated_state

    except Exception as e:
        sys.stdout = old_stdout  # Restore stdout in case of error
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}

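# Minimal sketch (illustrative only) of the capture/exec/restore pattern the
# tool above relies on: stdout is swapped for a StringIO while user code runs,
# so print() output can be returned to the LLM.
def _example_capture_stdout(code="print(1 + 1)"):
    old_stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        exec(code, {})
        return sys.stdout.getvalue()  # "2\n"
    finally:
        sys.stdout = old_stdout  # always restore the real stdout
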
# Initialize the LLM based on the selected provider and model
def initialize_llm():
    api_key = st.session_state.api_key
    model = st.session_state.selected_model

    if not api_key:
        return None

    try:
        if st.session_state.ai_provider == "openai":
            os.environ["OPENAI_API_KEY"] = api_key
            return ChatOpenAI(model=model, temperature=0)
        elif st.session_state.ai_provider == "groq":
            os.environ["GROQ_API_KEY"] = api_key
            # For Groq, use the langchain-groq chat model
            from langchain_groq import ChatGroq
            return ChatGroq(model=model, temperature=0)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        return None

# Set up the tools
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)

# Load the prompt template
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])

def create_data_summary(state: AgentState) -> str:
    summary = ""
    variables = []

    # Add sample data for each dataset
    for d in state.get("input_data", []):
        var_name = d.get("variable_name")
        if var_name:
            variables.append(var_name)
            summary += f"\n\nVariable: {var_name}\n"
            summary += f"Description: {d.get('data_description', 'No description')}\n"

            # Add sample data if available
            if var_name in st.session_state.in_memory_datasets:
                df = st.session_state.in_memory_datasets[var_name]
                summary += "\nSample Data (first 5 rows):\n"
                summary += df.head(5).to_string()

    if "current_variables" in state:
        remaining_variables = [v for v in state["current_variables"] if v not in variables and not v.startswith("_")]

        for v in remaining_variables:
            var_value = state["current_variables"].get(v)

            if isinstance(var_value, pd.DataFrame):
                summary += f"\n\nVariable: {v} (DataFrame with shape {var_value.shape})"
            else:
                summary += f"\n\nVariable: {v}"

    return summary

def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Determine whether to route to tools or end the chain."""
    if messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"

    return "__end__"

def call_model(state: AgentState):
    """Call the LLM to get a response."""
    current_data_template = """The following data is available:\n{data_summary}"""
    current_data_message = HumanMessage(
        content=current_data_template.format(data_summary=create_data_summary(state))
    )
    messages = [current_data_message] + state["messages"]

    # Get the initialized LLM
    llm = initialize_llm()
    if llm is None:
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}

    # Create the model with bound tools
    model = llm.bind_tools(tools)
    model = chat_template | model

    llm_outputs = model.invoke({"messages": messages})
    return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]}

def call_tools(state: AgentState):
    """Execute tools called by the LLM."""
    last_message = state["messages"][-1]
    tool_invocations = []

    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]

    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        # Merge updates instead of overwriting
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    state_updates["messages"] = tool_messages
    return state_updates

# Set up the graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

chain = workflow.compile()

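# Illustrative sketch (assumed usage, not in the original upload) of invoking
# the compiled graph directly; "sales" is a hypothetical dataset name that
# would have to exist in st.session_state.in_memory_datasets.
def _example_invoke_chain():
    initial_state = AgentState(
        messages=[HumanMessage(content="Plot the distribution of revenue.")],
        input_data=[{"variable_name": "sales", "data_path": "in_memory",
                     "data_description": "hypothetical sales dataset"}],
        intermediate_outputs=[],
        current_variables={},
        output_image_paths=[],
    )
    result = chain.invoke(initial_state)
    return result["messages"][-1].content
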
+
def process_file_upload(files):
|
| 463 |
+
"""Process uploaded files and return dataframe previews and column names"""
|
| 464 |
+
st.session_state.in_memory_datasets = {} # Clear previous datasets
|
| 465 |
+
st.session_state.dataset_metadata_list = [] # Clear previous metadata
|
| 466 |
+
st.session_state.persistent_vars.clear() # Clear persistent variables for new session
|
| 467 |
+
|
| 468 |
+
if not files:
|
| 469 |
+
return "No files uploaded.", [], ["No columns available"]
|
| 470 |
+
|
| 471 |
+
results = []
|
| 472 |
+
all_columns = [] # Track all columns from all datasets
|
| 473 |
+
|
| 474 |
+
for file in files:
|
| 475 |
+
try:
|
| 476 |
+
# Use file object directly
|
| 477 |
+
if file.name.endswith('.csv'):
|
| 478 |
+
df = pd.read_csv(file)
|
| 479 |
+
elif file.name.endswith(('.xls', '.xlsx')):
|
| 480 |
+
df = pd.read_excel(file)
|
| 481 |
+
else:
|
| 482 |
+
results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
var_name = file.name.split('.')[0].replace('-', '_').replace(' ', '_').lower()
|
| 486 |
+
st.session_state.in_memory_datasets[var_name] = df
|
| 487 |
+
|
| 488 |
+
# Collect all columns
|
| 489 |
+
all_columns.extend(df.columns.tolist())
|
| 490 |
+
|
| 491 |
+
# Create dataset metadata
|
| 492 |
+
dataset_metadata = {
|
| 493 |
+
"variable_name": var_name,
|
| 494 |
+
"data_path": "in_memory",
|
| 495 |
+
"data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
|
| 496 |
+
"original_filename": file.name
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
st.session_state.dataset_metadata_list.append(dataset_metadata)
|
| 500 |
+
|
| 501 |
+
# Return preview of the dataset
|
| 502 |
+
preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
|
| 503 |
+
preview += df.head(10).to_markdown()
|
| 504 |
+
results.append(preview)
|
| 505 |
+
print(f"Successfully processed {file.name}")
|
| 506 |
+
|
| 507 |
+
except Exception as e:
|
| 508 |
+
print(f"Error processing {file.name}: {str(e)}")
|
| 509 |
+
results.append(f"Error processing {file.name}: {str(e)}")
|
| 510 |
+
|
| 511 |
+
# Get unique columns
|
| 512 |
+
unique_columns = []
|
| 513 |
+
seen = set()
|
| 514 |
+
|
| 515 |
+
for col in all_columns:
|
| 516 |
+
if col not in seen:
|
| 517 |
+
seen.add(col)
|
| 518 |
+
unique_columns.append(col)
|
| 519 |
+
|
| 520 |
+
if not unique_columns:
|
| 521 |
+
unique_columns = ["No columns available"]
|
| 522 |
+
|
| 523 |
+
print(f"Found {len(unique_columns)} unique columns across datasets")
|
| 524 |
+
return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
|
| 525 |
+
|
| 526 |
+
def get_columns():
|
| 527 |
+
"""Directly gets columns from in-memory datasets"""
|
| 528 |
+
all_columns = []
|
| 529 |
+
|
| 530 |
+
for var_name, df in st.session_state.in_memory_datasets.items():
|
| 531 |
+
if isinstance(df, pd.DataFrame):
|
| 532 |
+
all_columns.extend(df.columns.tolist())
|
| 533 |
+
|
| 534 |
+
# Remove duplicates while preserving order
|
| 535 |
+
unique_columns = []
|
| 536 |
+
seen = set()
|
| 537 |
+
|
| 538 |
+
for col in all_columns:
|
| 539 |
+
if col not in seen:
|
| 540 |
+
seen.add(col)
|
| 541 |
+
unique_columns.append(col)
|
| 542 |
+
|
| 543 |
+
if not unique_columns:
|
| 544 |
+
unique_columns = ["No columns available"]
|
| 545 |
+
|
| 546 |
+
print(f"Populating dropdowns with {len(unique_columns)} columns")
|
| 547 |
+
return unique_columns
|
| 548 |
+
|
| 549 |
+
# === FUNCTIONS ===
|
| 550 |
+
import openai
|
| 551 |
+
import pandas as pd
|
| 552 |
+
import json
|
| 553 |
+
import re
|
| 554 |
+
|
| 555 |
+
def standard_clean(df):
|
| 556 |
+
df.columns = [re.sub(r'\W+', '_', col).strip().lower() for col in df.columns]
|
| 557 |
+
df.drop_duplicates(inplace=True)
|
| 558 |
+
df.dropna(axis=1, how='all', inplace=True)
|
| 559 |
+
df.dropna(axis=0, how='all', inplace=True)
|
| 560 |
+
for col in df.select_dtypes(include='object').columns:
|
| 561 |
+
df[col] = df[col].astype(str).str.strip()
|
| 562 |
+
return df
|
| 563 |
+
|
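# Quick usage sketch for standard_clean (illustrative data, not from the app):
# column names are snake_cased, the all-empty row is dropped, and the stray
# whitespace around " Ann " is stripped.
def _example_standard_clean():
    raw = pd.DataFrame({"First Name": [" Ann ", "Bob", None],
                        "Salary ($)": [100, 200, None]})
    return standard_clean(raw)  # columns become "first_name" and "salary_"
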
def query_openai(prompt):
    try:
        # Use the configured API key and model from session state
        api_key = st.session_state.api_key
        model = st.session_state.selected_model

        if st.session_state.ai_provider == "openai":
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
        elif st.session_state.ai_provider == "groq":
            from groq import Groq
            client = Groq(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"

def llm_suggest_cleaning(df):
    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Return a Python dictionary with the following keys:

1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)

Do not drop any rows or columns. Your output must be a valid Python dict.

Example:
{{
    "rename_columns": {{"dob": "date_of_birth"}},
    "convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
    "fill_missing": {{"gender": "mode", "salary": -1}},
    "value_map": {{
        "gender": {{"M": "Male", "F": "Female"}},
        "subscribed": {{"Y": "Yes", "N": "No"}}
    }}
}}
Beyond these steps, study the data and apply any other cleaning this particular dataset needs.
Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    try:
        # Parse the model's response as a Python literal rather than
        # eval()-ing arbitrary code returned by the LLM
        suggestions = ast.literal_eval(raw_response)
        return suggestions
    except Exception:
        print("Could not parse suggestions.")
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {}
        }

def apply_suggestions(df, suggestions):
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            print(f"Failed to convert {col} to {dtype}")

    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            if method == "mean":
                df[col].fillna(df[col].mean(), inplace=True)
            elif method == "median":
                df[col].fillna(df[col].median(), inplace=True)
            elif method == "mode":
                df[col].fillna(df[col].mode().iloc[0], inplace=True)
            else:
                # Constant fill (e.g., 'Unknown' or 0)
                df[col].fillna(method, inplace=True)
        except Exception:
            print(f"Could not fill missing values for {col}")

    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")

    return df

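# Worked sketch (hypothetical columns) showing the schema apply_suggestions
# expects -- the same shape llm_suggest_cleaning asks the model to return.
def _example_apply_suggestions():
    df = pd.DataFrame({"dob": ["2001-05-02"], "gender": ["M"], "age": ["21"]})
    suggestions = {
        "rename_columns": {"dob": "date_of_birth"},
        "convert_types": {"age": "int", "date_of_birth": "date"},
        "fill_missing": {"gender": "mode"},
        "value_map": {"gender": {"M": "Male", "F": "Female"}},
    }
    return apply_suggestions(df, suggestions)
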
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image."""
    try:
        # Create a figure that combines all dashboard plots
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        # Create a 2x2 subplot
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        # Add each plot from the dashboard to the combined figure
        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                row = (i // 2) + 1
                col = (i % 2) + 1

                # Extract traces from the original figure and add them to our subplot
                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)

                # Copy axis titles from the source figure; only the title is
                # copied, since copying whole axis objects would clobber the
                # subplot domains. Plotly names subplot axes sequentially
                # (xaxis, xaxis2, xaxis3, xaxis4).
                for axis_type in ['xaxis', 'yaxis']:
                    subplot_idx = (row - 1) * 2 + col
                    subplot_name = f"{axis_type}{subplot_idx if subplot_idx > 1 else ''}"

                    source_axis = getattr(plot.layout, axis_type, None)
                    if source_axis is not None:
                        fig.update_layout({subplot_name: {"title": source_axis.title.text}})

        # Update the layout for better appearance
        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        # Save to a temporary file
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)  # Higher scale for better resolution
        return dashboard_path

    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None

def generate_enhanced_pdf_report():
    """Generate an enhanced PDF report with proper handling of base64 image data."""
    try:
        # Create a buffer for the PDF
        buffer = io.BytesIO()

        # Create the PDF document
        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)

        # Create custom styles with better formatting
        styles = getSampleStyleSheet()

        styles.add(ParagraphStyle(
            name='ReportTitle',
            parent=styles['Heading1'],
            fontSize=24,
            alignment=1,  # Centered
            spaceAfter=20,
            textColor='#2C3E50'  # Dark blue
        ))

        styles.add(ParagraphStyle(
            name='SectionHeader',
            parent=styles['Heading2'],
            fontSize=16,
            spaceBefore=15,
            spaceAfter=10,
            textColor='#2C3E50',
            borderWidth=1,
            borderColor='#95A5A6',
            borderPadding=5,
            borderRadius=5
        ))

        styles.add(ParagraphStyle(
            name='SubHeader',
            parent=styles['Heading3'],
            fontSize=14,
            spaceBefore=10,
            spaceAfter=8,
            textColor='#34495E',
            fontWeight='bold'
        ))

        styles.add(ParagraphStyle(
            name='UserMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=8,
            spaceAfter=4
        ))

        styles.add(ParagraphStyle(
            name='AssistantMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=4,
            spaceAfter=12,
            textColor='#2980B9'
        ))

        styles.add(ParagraphStyle(
            name='Timestamp',
            parent=styles['Italic'],
            fontSize=10,
            textColor='#7F8C8D',
            alignment=2  # Right aligned
        ))

        # Create the document content
        elements = []

        # Add title
        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))

        # Add timestamp
        elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                                  styles['Timestamp']))
        elements.append(Spacer(1, 0.5*inch))

        # Add conversation history with better formatting
        elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))

        if st.session_state.chat_history:
            for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
                # Format the user message with proper styling
                elements.append(Paragraph('<b>You:</b>', styles['SubHeader']))
                user_msg_formatted = user_msg.replace('\n', '<br/>')
                elements.append(Paragraph(user_msg_formatted, styles['UserMessage']))

                # Process the assistant message to handle visualizations:
                # look for markdown image syntax with base64 data
                base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'

                # Check whether the message contains visualizations
                if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
                    # Split the message at the Visualizations header if it exists
                    if '### Visualizations' in assistant_msg:
                        parts = assistant_msg.split('### Visualizations', 1)
                        text_part = parts[0]
                        viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else ""
                    else:
                        # No header, but the message still contains a visualization
                        match = re.search(base64_pattern, assistant_msg)
                        text_part = assistant_msg[:match.start()]
                        viz_part = assistant_msg[match.start():]

                    # Format the text part
                    elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
                    text_part = text_part.replace('\n', '<br/>')
                    elements.append(Paragraph(text_part, styles['AssistantMessage']))

                    # Process visualizations
                    matches = re.findall(base64_pattern, viz_part)
                    for j, base64_data in enumerate(matches):
                        try:
                            # Decode the base64 image
                            image_data = base64.b64decode(base64_data)

                            # Create a temporary file for the image
                            temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"

                            with open(temp_img_path, 'wb') as f:
                                f.write(image_data)

                            # Add the image to the PDF
                            elements.append(Paragraph('<b>Visualization:</b>', styles['SubHeader']))
                            elements.append(Spacer(1, 0.1*inch))
                            img = Image(temp_img_path, width=6*inch, height=4*inch)
                            elements.append(img)
                            elements.append(Spacer(1, 0.2*inch))
                        except Exception as e:
                            print(f"Error processing base64 image: {str(e)}")
                            elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
                                                      styles['Normal']))
                else:
                    # No visualizations; just format the text
                    elements.append(Paragraph('<b>Assistant:</b>', styles['SubHeader']))
                    assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
                    if len(assistant_msg_formatted) > 1500:
                        assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
                    elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))

                elements.append(Spacer(1, 0.2*inch))
        else:
            elements.append(Paragraph('No conversation history available.', styles['Normal']))

        # Force a page break before the dashboard
        elements.append(PageBreak())

        # Add the dashboard section header
        elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
        elements.append(Spacer(1, 0.2*inch))

        # Capture the dashboard as a single image
        dashboard_img_path = capture_dashboard_screenshot()

        if dashboard_img_path:
            # Calculate the available width
            available_width = doc.width

            # Open with PIL to get the image dimensions
            pil_img = PILImage.open(dashboard_img_path)
            img_width, img_height = pil_img.size

            # Scale to fit within the page width, preserving the aspect ratio
            scale_factor = available_width / img_width
            new_height = img_height * scale_factor

            # Add the image with scaled dimensions
            img = Image(dashboard_img_path, width=available_width, height=new_height)
            elements.append(img)
        else:
            # Fallback: add individual plots if the combined dashboard fails
            plot_count = 0
            for i, plot in enumerate(st.session_state.dashboard_plots):
                if plot is not None:
                    plot_count += 1

                    # Convert the plotly figure to an image
                    img_bytes = io.BytesIO()
                    plot.write_image(img_bytes, format='png', width=500, height=300)
                    img_bytes.seek(0)

                    # Create a temporary file for the image
                    temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"

                    with open(temp_img_path, 'wb') as f:
                        f.write(img_bytes.getvalue())

                    # Add to the PDF with an appropriate caption
                    elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader']))
                    elements.append(Spacer(1, 0.1*inch))

                    # Add the image with proper scaling
                    img = Image(temp_img_path, width=6.5*inch, height=4*inch)
                    elements.append(img)
                    elements.append(Spacer(1, 0.3*inch))

            if plot_count == 0:
                elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                          styles['Normal']))

        # Build the PDF
        doc.build(elements)

        # Get the value of the buffer
        pdf_value = buffer.getvalue()
        buffer.close()

        return pdf_value

    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None

def chat_with_workflow(message, history, dataset_info):
    """Send the user query to the workflow and get a response."""

    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    # Check that we have a valid API key and model
    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        # Extract chat history for context (last 3 exchanges)
        max_history = 3
        previous_messages = []

        if history:
            start_idx = max(0, len(history) - max_history)
            recent_history = history[start_idx:]

            for exchange in recent_history:
                if exchange[0]:  # User message
                    previous_messages.append(HumanMessage(content=exchange[0]))
                if exchange[1]:  # AI response
                    previous_messages.append(AIMessage(content=exchange[1]))

        # Initialize the workflow state
        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        # Execute the workflow
        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        # Extract messages from the result
        messages = result["messages"]

        # Format the response using only the latest message
        response = ""
        if messages:
            latest_message = messages[-1]
            if hasattr(latest_message, "content"):
                content = latest_message.content

                # Clean up the response: remove any echo of the user's message
                if message in content:
                    content = content.split(message)[-1].strip()

                # Remove any chat-history markers
                content_lines = content.split('\n')
                filtered_lines = [line for line in content_lines
                                  if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
                content = '\n'.join(filtered_lines)

                response = content.strip() + "\n\n"

        # Handle visualizations
        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)

                    # Convert the plotly figure to a PNG image
                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    # Embed as a base64 markdown image (the same format the
                    # PDF generator's regex parses back out)
                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"

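# Sketch of the base64 round trip used above (assumed helper, not in the
# original file): a figure is rendered to PNG bytes and embedded as a markdown
# data URI, matching the regex generate_enhanced_pdf_report parses back out.
def _example_figure_to_markdown(fig):
    buf = BytesIO()
    pio.write_image(fig, buf, format='png')
    buf.seek(0)
    b64 = base64.b64encode(buf.read()).decode()
    return f"![Visualization](data:image/png;base64,{b64})"
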
| 1042 |
+
def auto_generate_dashboard(dataset_info):
|
| 1043 |
+
"""Generate an automatic dashboard with four plots"""
|
| 1044 |
+
|
| 1045 |
+
if not dataset_info:
|
| 1046 |
+
return "Please upload a dataset first.", [None, None, None, None]
|
| 1047 |
+
|
| 1048 |
+
prompt = """
|
| 1049 |
+
You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).
|
| 1050 |
+
|
| 1051 |
+
Use plotly and store the plots in a list named plotly_figures.
|
| 1052 |
+
|
| 1053 |
+
Include multivariate plots using color/size/facets when helpful.
|
| 1054 |
+
"""
|
| 1055 |
+
|
| 1056 |
+
state = AgentState(
|
| 1057 |
+
messages=[HumanMessage(content=prompt)],
|
| 1058 |
+
input_data=dataset_info,
|
| 1059 |
+
intermediate_outputs=[],
|
| 1060 |
+
current_variables=st.session_state.persistent_vars,
|
| 1061 |
+
output_image_paths=[]
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
result = chain.invoke(state)
|
| 1065 |
+
figures = []
|
| 1066 |
+
|
| 1067 |
+
if "output_image_paths" in result:
|
| 1068 |
+
for img_path in result["output_image_paths"][:4]:
|
| 1069 |
+
try:
|
| 1070 |
+
full_path = os.path.join(st.session_state.images_dir, img_path)
|
| 1071 |
+
with open(full_path, 'rb') as f:
|
| 1072 |
+
fig = pickle.load(f)
|
| 1073 |
+
figures.append(fig)
|
| 1074 |
+
except Exception as e:
|
| 1075 |
+
print(f"Error loading figure: {e}")
|
| 1076 |
+
|
| 1077 |
+
while len(figures) < 4:
|
| 1078 |
+
figures.append(None)
|
| 1079 |
+
|
| 1080 |
+
st.session_state.dashboard_plots = figures
|
| 1081 |
+
return "Dashboard generated!", figures
|
| 1082 |
+
|
| 1083 |
+
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Generate custom plots based on user-selected columns"""

    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.

    Create 3 insightful visualizations using Plotly based on:

    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col}

    Try to find interesting relationships, trends, or clusters using appropriate chart types.

    Store the figures in a list named `plotly_figures` and do not call fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        for img_path in result["output_image_paths"][:3]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    while len(figures) < 3:
        figures.append(None)
    return figures

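# Illustrative refactor (sketch only, not wired in): the two functions above
# duplicate the same unpickle-and-pad loop. A shared helper along these lines
# could replace both copies; it relies on the same `st.session_state.images_dir`
# convention used above.
def load_pickled_figures(image_paths, max_figures):
    """Load up to max_figures pickled Plotly figures, padding with None."""
    figures = []
    for img_path in image_paths[:max_figures]:
        try:
            full_path = os.path.join(st.session_state.images_dir, img_path)
            with open(full_path, 'rb') as f:
                figures.append(pickle.load(f))
        except Exception as e:
            print(f"Error loading figure: {e}")
    while len(figures) < max_figures:
        figures.append(None)
    return figures
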
def remove_plot(index):
    """Remove a plot from the dashboard"""
    if 0 <= index < len(st.session_state.dashboard_plots):
        st.session_state.dashboard_plots[index] = None

def respond(message):
    """Handle chat message response"""
    if not st.session_state.dataset_metadata_list:
        bot_message = "Please upload at least one dataset before asking questions."
    else:
        bot_message = chat_with_workflow(message, st.session_state.chat_history, st.session_state.dataset_metadata_list)

    st.session_state.chat_history.append((message, bot_message))
    st.rerun()

def save_plot_to_dashboard(plot_index):
    """Callback for the Add Plot button"""
    for i in range(len(st.session_state.dashboard_plots)):
        if st.session_state.dashboard_plots[i] is None:
            # Found an empty slot
            st.session_state.dashboard_plots[i] = st.session_state.custom_plots_to_save[plot_index]
            return
    # If all four slots are taken, the plot is silently dropped; remove one first.

# Check whether settings are valid before the app proceeds
def is_settings_valid():
    """Check whether an API key has been configured"""
    return st.session_state.api_key != ""

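# Illustrative stricter variant (sketch only, not wired in): the gate could
# also require a model selection before unlocking the tabs:
#     return bool(st.session_state.api_key) and bool(st.session_state.selected_model)
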
# Streamlit UI with left panel settings
st.set_page_config(page_title="Data Analysis Assistant", layout="wide")

# Create side panel for settings
with st.sidebar:
    st.header("Settings")
    st.info("⚠️ Configure your API settings before uploading data")

    # AI Provider selection
    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)

    # Update session state based on selection
    st.session_state.ai_provider = provider.lower()

    # API Key input
    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")

    # Display different model options based on provider
    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:  # groq
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is the most capable. Gemma 2 9B offers a good balance. Llama 3 8B is the fastest."

    # Model selection
    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)

    # Save button
    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model

        # Test the API key and model
        try:
            # Initialize the LLM using the provided settings
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")

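    # Illustrative fallback (sketch only, assumes env vars are set in the Space):
    # the key field could be pre-seeded from the environment so it survives reloads:
    #     default_key = os.environ.get("OPENAI_API_KEY", "") or os.environ.get("GROQ_API_KEY", "")
    #     api_key = st.text_input("Enter API Key", value=st.session_state.api_key or default_key, type="password")
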
    # Display current settings
    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)

    # Provider-specific information
    if st.session_state.ai_provider == "openai":
        with st.expander("OpenAI Models Info"):
            st.info("""
            - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
            - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
            - **GPT-4-Mini**: Economical option with good performance for standard tasks
            - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
            """)
    else:
        with st.expander("Groq Models Info"):
            st.info("""
            - **llama3.3-70b-versatile**: Most powerful model for comprehensive analysis
            - **gemma2-9b-it**: Good balance of speed and capabilities
            - **llama-3-8b-8192**: Fastest option for basic analysis tasks
            """)

    # Instructions for obtaining API keys
    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key

            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key

            1. Go to [Groq's console](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to the API Keys section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: Check Groq's pricing page for current rates.
            """)

# Main content
st.title("Data Analysis Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

# Check that settings are valid before showing the tabs
if not is_settings_valid():
    st.warning("⚠️ Please configure your API key in the sidebar settings panel before proceeding.")
    st.stop()

# Create the main tabs (only reached once settings are valid)
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report"])

with tab1:
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

            # Display basic information about the processed files
            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

            # Show a detailed preview for each dataset
            st.subheader("Dataset Previews")

            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    # Display dataset shape
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                    # Build a per-column summary table
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })

                    # Show column information in a compact table
                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)

                    # Show the actual data preview
                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)

            # Hint for the next steps
            st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data, or to Chat with AI Assistant to analyze it.")

with tab2:
    st.header("Data Cleaning")

    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False

    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}

    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                    st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing values with '2000-01-01'",
                                           key=f"user_input_{name}")

                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []

                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert. A dataset was previously cleaned with these actions:

                        {summary_text}

                        The user now wants the following additional instruction applied:
                        \"{user_input.strip()}\"

                        Write only the Python code that modifies the pandas DataFrame `df` accordingly.
                        Do not include explanations or markdown.
                        """
                        correction_code = query_openai(correction_prompt)

                        try:
                            # Run the LLM-generated code against a copy of the cleaned frame
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars["df"]

                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()

                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")

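                # Illustrative hardening (sketch only, not wired in): despite the
                # prompt, LLM replies sometimes arrive wrapped in markdown fences,
                # so stripping them before exec() would make this step more robust:
                #     code = correction_code.strip()
                #     if code.startswith("```"):
                #         code = code.strip("`").strip()
                #         if code.startswith("python"):
                #             code = code[len("python"):]
                #     correction_code = code.strip()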
                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    for msg, code in st.session_state[f'corrections_{name}']:
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')

            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()

            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")

with tab3:
    st.header("Chat with AI Assistant")

    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    # Display chat history
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    # Chat input
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)

with tab4:
    st.header("Auto Dashboard Generator")

    # Dashboard controls
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            with st.spinner("Generating dashboard..."):
                message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    # Dashboard display
    st.subheader("Dashboard")

    # Row 1
    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    # Row 2
    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

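    # Illustrative refactor (sketch only, not wired in): the four slots above
    # could be rendered in a loop to remove the repetition:
    #     for row_start in (0, 2):
    #         for offset, col in enumerate(st.columns(2)):
    #             idx = row_start + offset
    #             with col:
    #                 if st.session_state.dashboard_plots[idx]:
    #                     st.plotly_chart(st.session_state.dashboard_plots[idx], use_container_width=True)
    #                     if st.button(f"Remove Plot {idx + 1}"):
    #                         remove_plot(idx)
    #                         st.rerun()
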
    # Custom plot generator
    st.subheader("Custom Plot Generator")

    # Column selection
    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)

    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)

    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        with st.spinner("Generating custom visualizations..."):
            custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
            # Store the plots in session state so the add-to-dashboard callback can reach them
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.session_state.custom_plots_to_save[i] = plot

            # Display the custom plots, each with an add-to-dashboard button
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.plotly_chart(plot, use_container_width=True)
                    st.button(
                        f"Add Plot {i+1} to Dashboard",
                        key=f"add_plot_{i}",
                        on_click=save_plot_to_dashboard,
                        args=(i,)
                    )

with tab5:
    st.header("Generate Analysis Report")

    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)

    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        with st.spinner("Generating report..."):
            pdf_data = generate_enhanced_pdf_report()
            if pdf_data:
                # Encode the PDF and build an inline download link
                b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
                pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                st.markdown("### Your report is ready!")
                st.markdown(pdf_download_link, unsafe_allow_html=True)
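                # Simpler alternative (sketch only): Streamlit's built-in download
                # widget avoids hand-rolling the base64 HTML link:
                #     st.download_button("Download PDF Report", data=pdf_data,
                #                        file_name="data_analysis_report.pdf",
                #                        mime="application/pdf")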
                # Preview option (simplified)
                with st.expander("Preview Report"):
                    st.warning("PDF preview is not available in Streamlit; please download the report to view it.")
            else:
                st.error("Failed to generate the report. Please try again.")

# Cleanup on app exit
def cleanup():
    try:
        shutil.rmtree(st.session_state.temp_dir)
        print(f"Cleaned up temporary directory: {st.session_state.temp_dir}")
    except Exception as e:
        print(f"Error cleaning up: {e}")

import atexit
atexit.register(cleanup)
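
# Caveat (sketch only, not wired in): st.session_state may no longer be
# readable by the time atexit handlers run, so capturing the path up front
# would be safer:
#     _temp_dir = st.session_state.temp_dir
#     atexit.register(lambda: shutil.rmtree(_temp_dir, ignore_errors=True))
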
requirements.txt
ADDED
@@ -0,0 +1,15 @@
streamlit
pandas
numpy
plotly
langchain
langgraph==0.2.74
langchain-core
langchain-groq
langchain_openai
langchain-experimental
reportlab
Pillow
scikit-learn
tabulate
kaleido