srivatsavdamaraju commited on
Commit
cd11586
·
verified ·
1 Parent(s): 4163b9d

Upload 190 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
agent_duckdb.db ADDED
Binary file (12.3 kB). View file
 
app.ipynb CHANGED
@@ -1,127 +1,127 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "578b4ea2",
7
- "metadata": {
8
- "vscode": {
9
- "languageId": "plaintext"
10
- }
11
- },
12
- "outputs": [],
13
- "source": [
14
- "import json\n",
15
- "import warnings\n",
16
- "from fastapi import FastAPI, Request\n",
17
- "from fastapi.responses import JSONResponse\n",
18
- "from fastapi.middleware.cors import CORSMiddleware\n",
19
- "import os \n",
20
- "\n",
21
- "# Import routes\n",
22
- "# from routes import collections, ingestion, sessions, chat, system\n",
23
- "from Redis.sessions import Redis_session_router\n",
24
- "from Redis.sessions_new import redis_session_route_new\n",
25
- "from Redis.sessions_old import Redis_session_router_old\n",
26
- "# from Routes.main_chat_bot import main_chatbot_route\n",
27
- "from Routes.generate_report import Report_Generation_Router\n",
28
- "# from report_generation.report_generation import Report_Generation_Router_v2\n",
29
- "from s3.file_insertion_s3 import s3_bucket_router\n",
30
- "from s3.rough2 import test_router\n",
31
- "from backend.main import login_apis_router\n",
32
- "from s3.r3 import s3_bucket_router1\n",
33
- "from s3_viewer_backend.s3_viewer import s3_viewer_router \n",
34
- "# from vector_db.retrival_qa_agent import RetrievalQA_router\n",
35
- "from main_chat_2.main_chat import main_chat_router_v4\n",
36
- "from main_chat5.main_chat_router import main_chat_router_v5\n",
37
- "from report_generation.ppt_v2 import Report_Generation_Router_ppt_v2\n",
38
- "from report_generation.pdf_v3 import Report_Generation_Router_pdf_v2\n",
39
- "\n",
40
- "\n",
41
- "\n",
42
- "from Routes.main_agent_chat_bot_v2 import main_chatbot_route_v2\n",
43
- "from vector_db import qdrant_crud\n",
44
- "\n",
45
- "from vector_db.qdrant_crud import Qdrant_router\n",
46
- "from agent_tools.autovis_tool import Autoviz_router\n",
47
- "# from main_chat_2.main_chat import main_chat_router_v3\n",
48
- "from DB_store_backup.agentic_context_convoid_management import Db_store_router\n",
49
- "# Suppress warnings\n",
50
- "warnings.filterwarnings(\"ignore\", message=\"Qdrant client version.*is incompatible.*\")\n",
51
- "\n",
52
- "app = FastAPI(title=\"Combined AI Agent with Qdrant Collections and Redis Session Management\")\n",
53
- "\n",
54
- "# Add CORS middleware\n",
55
- "app.add_middleware(\n",
56
- " CORSMiddleware,\n",
57
- " allow_origins=[\"*\"],\n",
58
- " allow_credentials=True,\n",
59
- " allow_methods=[\"*\"],\n",
60
- " allow_headers=[\"*\"],\n",
61
- ")\n",
62
- "\n",
63
- "print([route.path for route in app.routes])\n",
64
- "# Include routers\n",
65
- "\n",
66
- "# app.include_router(main_chatbot_route)\n",
67
- "app.include_router(Report_Generation_Router)\n",
68
- "app.include_router(Redis_session_router)\n",
69
- "app.include_router(redis_session_route_new)\n",
70
- "# app.include_router(Redis_session_router_old)\n",
71
- "app.include_router(Db_store_router)\n",
72
- "app.include_router(s3_bucket_router)\n",
73
- "app.include_router(Autoviz_router)\n",
74
- "# app.include_router(main_chat_router_v3)\n",
75
- "app.include_router(s3_viewer_router)\n",
76
- "\n",
77
- "app.include_router(s3_bucket_router1)\n",
78
- "app.include_router(main_chat_router_v4)\n",
79
- "app.include_router(test_router)\n",
80
- "app.include_router(Qdrant_router)\n",
81
- "# app.include_router(RetrievalQA_router)\n",
82
- "app.include_router(main_chatbot_route_v2)\n",
83
- "# app.include_router(Report_Generation_Router_v2)\n",
84
- "app.include_router(Report_Generation_Router_ppt_v2)\n",
85
- "app.include_router(Report_Generation_Router_pdf_v2)\n",
86
- "\n",
87
- "app.include_router(main_chat_router_v5)\n",
88
- "#==================================login and user management routes==================================\n",
89
- "app.include_router(login_apis_router)\n",
90
- "\n",
91
- "\n",
92
- "# ------------------- MIDDLEWARE -------------------\n",
93
- "\n",
94
- "@app.middleware(\"http\")\n",
95
- "async def add_success_flag(request: Request, call_next):\n",
96
- " response = await call_next(request)\n",
97
- "\n",
98
- " # Only modify JSON responses\n",
99
- " if \"application/json\" in response.headers.get(\"content-type\", \"\"):\n",
100
- " try:\n",
101
- " body = b\"\".join([chunk async for chunk in response.body_iterator])\n",
102
- " data = json.loads(body.decode())\n",
103
- "\n",
104
- " # Add success flag\n",
105
- " data[\"success\"] = 200 <= response.status_code < 300\n",
106
- "\n",
107
- " # Build new JSONResponse (auto handles Content-Length)\n",
108
- " response = JSONResponse(\n",
109
- " content=data,\n",
110
- " status_code=response.status_code,\n",
111
- " headers={k: v for k, v in response.headers.items() if k.lower() != \"content-length\"},\n",
112
- " )\n",
113
- " except Exception:\n",
114
- " # fallback if response is not JSON parseable\n",
115
- " pass\n",
116
- " return response"
117
- ]
118
- }
119
- ],
120
- "metadata": {
121
- "language_info": {
122
- "name": "python"
123
- }
124
- },
125
- "nbformat": 4,
126
- "nbformat_minor": 5
127
- }
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "578b4ea2",
7
+ "metadata": {
8
+ "vscode": {
9
+ "languageId": "plaintext"
10
+ }
11
+ },
12
+ "outputs": [],
13
+ "source": [
14
+ "import json\n",
15
+ "import warnings\n",
16
+ "from fastapi import FastAPI, Request\n",
17
+ "from fastapi.responses import JSONResponse\n",
18
+ "from fastapi.middleware.cors import CORSMiddleware\n",
19
+ "import os \n",
20
+ "\n",
21
+ "# Import routes\n",
22
+ "# from routes import collections, ingestion, sessions, chat, system\n",
23
+ "from Redis.sessions import Redis_session_router\n",
24
+ "from Redis.sessions_new import redis_session_route_new\n",
25
+ "from Redis.sessions_old import Redis_session_router_old\n",
26
+ "# from Routes.main_chat_bot import main_chatbot_route\n",
27
+ "from Routes.generate_report import Report_Generation_Router\n",
28
+ "# from report_generation.report_generation import Report_Generation_Router_v2\n",
29
+ "from s3.file_insertion_s3 import s3_bucket_router\n",
30
+ "from s3.rough2 import test_router\n",
31
+ "from backend.main import login_apis_router\n",
32
+ "from s3.r3 import s3_bucket_router1\n",
33
+ "from s3_viewer_backend.s3_viewer import s3_viewer_router \n",
34
+ "# from vector_db.retrival_qa_agent import RetrievalQA_router\n",
35
+ "from main_chat_2.main_chat import main_chat_router_v4\n",
36
+ "from main_chat5.main_chat_router import main_chat_router_v5\n",
37
+ "from report_generation.ppt_v2 import Report_Generation_Router_ppt_v2\n",
38
+ "from report_generation.pdf_v3 import Report_Generation_Router_pdf_v2\n",
39
+ "\n",
40
+ "\n",
41
+ "\n",
42
+ "from Routes.main_agent_chat_bot_v2 import main_chatbot_route_v2\n",
43
+ "from vector_db import qdrant_crud\n",
44
+ "\n",
45
+ "from vector_db.qdrant_crud import Qdrant_router\n",
46
+ "from agent_tools.autovis_tool import Autoviz_router\n",
47
+ "# from main_chat_2.main_chat import main_chat_router_v3\n",
48
+ "from DB_store_backup.agentic_context_convoid_management import Db_store_router\n",
49
+ "# Suppress warnings\n",
50
+ "warnings.filterwarnings(\"ignore\", message=\"Qdrant client version.*is incompatible.*\")\n",
51
+ "\n",
52
+ "app = FastAPI(title=\"Combined AI Agent with Qdrant Collections and Redis Session Management\")\n",
53
+ "\n",
54
+ "# Add CORS middleware\n",
55
+ "app.add_middleware(\n",
56
+ " CORSMiddleware,\n",
57
+ " allow_origins=[\"*\"],\n",
58
+ " allow_credentials=True,\n",
59
+ " allow_methods=[\"*\"],\n",
60
+ " allow_headers=[\"*\"],\n",
61
+ ")\n",
62
+ "\n",
63
+ "print([route.path for route in app.routes])\n",
64
+ "# Include routers\n",
65
+ "\n",
66
+ "# app.include_router(main_chatbot_route)\n",
67
+ "app.include_router(Report_Generation_Router)\n",
68
+ "app.include_router(Redis_session_router)\n",
69
+ "app.include_router(redis_session_route_new)\n",
70
+ "# app.include_router(Redis_session_router_old)\n",
71
+ "app.include_router(Db_store_router)\n",
72
+ "app.include_router(s3_bucket_router)\n",
73
+ "app.include_router(Autoviz_router)\n",
74
+ "# app.include_router(main_chat_router_v3)\n",
75
+ "app.include_router(s3_viewer_router)\n",
76
+ "\n",
77
+ "app.include_router(s3_bucket_router1)\n",
78
+ "app.include_router(main_chat_router_v4)\n",
79
+ "app.include_router(test_router)\n",
80
+ "app.include_router(Qdrant_router)\n",
81
+ "# app.include_router(RetrievalQA_router)\n",
82
+ "app.include_router(main_chatbot_route_v2)\n",
83
+ "# app.include_router(Report_Generation_Router_v2)\n",
84
+ "app.include_router(Report_Generation_Router_ppt_v2)\n",
85
+ "app.include_router(Report_Generation_Router_pdf_v2)\n",
86
+ "\n",
87
+ "app.include_router(main_chat_router_v5)\n",
88
+ "#==================================login and user management routes==================================\n",
89
+ "app.include_router(login_apis_router)\n",
90
+ "\n",
91
+ "\n",
92
+ "# ------------------- MIDDLEWARE -------------------\n",
93
+ "\n",
94
+ "@app.middleware(\"http\")\n",
95
+ "async def add_success_flag(request: Request, call_next):\n",
96
+ " response = await call_next(request)\n",
97
+ "\n",
98
+ " # Only modify JSON responses\n",
99
+ " if \"application/json\" in response.headers.get(\"content-type\", \"\"):\n",
100
+ " try:\n",
101
+ " body = b\"\".join([chunk async for chunk in response.body_iterator])\n",
102
+ " data = json.loads(body.decode())\n",
103
+ "\n",
104
+ " # Add success flag\n",
105
+ " data[\"success\"] = 200 <= response.status_code < 300\n",
106
+ "\n",
107
+ " # Build new JSONResponse (auto handles Content-Length)\n",
108
+ " response = JSONResponse(\n",
109
+ " content=data,\n",
110
+ " status_code=response.status_code,\n",
111
+ " headers={k: v for k, v in response.headers.items() if k.lower() != \"content-length\"},\n",
112
+ " )\n",
113
+ " except Exception:\n",
114
+ " # fallback if response is not JSON parseable\n",
115
+ " pass\n",
116
+ " return response"
117
+ ]
118
+ }
119
+ ],
120
+ "metadata": {
121
+ "language_info": {
122
+ "name": "python"
123
+ }
124
+ },
125
+ "nbformat": 4,
126
+ "nbformat_minor": 5
127
+ }
app.py CHANGED
@@ -1,115 +1,115 @@
1
- import json
2
- import warnings
3
- from fastapi import FastAPI, Request
4
- from fastapi.responses import JSONResponse
5
- from fastapi.middleware.cors import CORSMiddleware
6
- import os
7
- import json
8
- import warnings
9
- from fastapi import FastAPI, Request
10
- from fastapi.responses import JSONResponse
11
- from fastapi.middleware.cors import CORSMiddleware
12
- import os
13
-
14
- # Import routes
15
- # from routes import collections, ingestion, sessions, chat, system
16
- from Redis.sessions import Redis_session_router
17
- from Redis.sessions_new import redis_session_route_new
18
- from Redis.sessions_old import Redis_session_router_old
19
- # from Routes.main_chat_bot import main_chatbot_route
20
- from Routes.generate_report import Report_Generation_Router
21
- # from report_generation.report_generation import Report_Generation_Router_v2
22
- from s3.file_insertion_s3 import s3_bucket_router
23
- from s3.rough2 import test_router
24
- from backend.main import login_apis_router
25
- # from s3.r3 import s3_bucket_router1
26
- from s3.r4 import s3_bucket_router1
27
- from s3_viewer_backend.s3_viewer import s3_viewer_router
28
- # from vector_db.retrival_qa_agent import RetrievalQA_router
29
- from main_chat_2.main_chat import main_chat_router_v4
30
- from main_chat5.main_chat_router import main_chat_router_v5
31
- from main_chat_v6.main_chat_router import main_chat_router_v6
32
- from report_generation.ppt_v2 import Report_Generation_Router_ppt_v2
33
- from report_generation.pdf_v3 import Report_Generation_Router_pdf_v2
34
-
35
-
36
-
37
- from Routes.main_agent_chat_bot_v2 import main_chatbot_route_v2
38
- # from vector_db import qdrant_crud
39
-
40
- # from vector_db.qdrant_crud import Qdrant_router
41
- from agent_tools.autovis_tool import Autoviz_router
42
- # from main_chat_2.main_chat import main_chat_router_v3
43
- from DB_store_backup.agentic_context_convoid_management import Db_store_router
44
- # Suppress warnings
45
- warnings.filterwarnings("ignore", message="Qdrant client version.*is incompatible.*")
46
-
47
- app = FastAPI(title="Combined AI Agent with Qdrant Collections and Redis Session Management")
48
-
49
- # Add CORS middleware
50
- app.add_middleware(
51
- CORSMiddleware,
52
- allow_origins=["*"],
53
- allow_credentials=True,
54
- allow_methods=["*"],
55
- allow_headers=["*"],
56
- )
57
-
58
- print([route.path for route in app.routes])
59
- # Include routers
60
-
61
- # app.include_router(main_chatbot_route)
62
- app.include_router(Report_Generation_Router)
63
- app.include_router(Redis_session_router)
64
- app.include_router(redis_session_route_new)
65
- # app.include_router(Redis_session_router_old)
66
- app.include_router(Db_store_router)
67
- app.include_router(s3_bucket_router)
68
- app.include_router(Autoviz_router)
69
- # app.include_router(main_chat_router_v3)
70
- app.include_router(s3_viewer_router)
71
-
72
- app.include_router(s3_bucket_router1)
73
- app.include_router(main_chat_router_v4)
74
- app.include_router(test_router)
75
-
76
- # app.include_router(RetrievalQA_router)
77
- app.include_router(main_chatbot_route_v2)
78
- # app.include_router(Report_Generation_Router_v2)
79
- app.include_router(Report_Generation_Router_ppt_v2)
80
- app.include_router(Report_Generation_Router_pdf_v2)
81
-
82
- app.include_router(main_chat_router_v5)
83
- app.include_router(main_chat_router_v6)
84
- #==================================login and user management routes==================================
85
- app.include_router(login_apis_router)
86
- #---------------------------Qdrant router -----------------------
87
-
88
- # app.include_router(Qdrant_router)
89
- # ------------------- MIDDLEWARE -------------------
90
-
91
- @app.middleware("http")
92
- async def add_success_flag(request: Request, call_next):
93
- response = await call_next(request)
94
-
95
- # Only modify JSON responses
96
- if "application/json" in response.headers.get("content-type", ""):
97
- try:
98
- body = b"".join([chunk async for chunk in response.body_iterator])
99
- data = json.loads(body.decode())
100
-
101
- # Add success flag
102
- data["success"] = 200 <= response.status_code < 300
103
-
104
- # Build new JSONResponse (auto handles Content-Length)
105
- response = JSONResponse(
106
- content=data,
107
- status_code=response.status_code,
108
- headers={k: v for k, v in response.headers.items() if k.lower() != "content-length"},
109
- )
110
- except Exception:
111
- # fallback if response is not JSON parseable
112
- pass
113
- return response
114
-
115
-
 
1
+ import json
2
+ import warnings
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ import os
7
+ import json
8
+ import warnings
9
+ from fastapi import FastAPI, Request
10
+ from fastapi.responses import JSONResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ import os
13
+
14
+ # Import routes
15
+ # from routes import collections, ingestion, sessions, chat, system
16
+ from Redis.sessions import Redis_session_router
17
+ from Redis.sessions_new import redis_session_route_new
18
+ from Redis.sessions_old import Redis_session_router_old
19
+ # from Routes.main_chat_bot import main_chatbot_route
20
+ from Routes.generate_report import Report_Generation_Router
21
+ # from report_generation.report_generation import Report_Generation_Router_v2
22
+ from s3.file_insertion_s3 import s3_bucket_router
23
+ from s3.rough2 import test_router
24
+ from backend.main import login_apis_router
25
+ # from s3.r3 import s3_bucket_router1
26
+ from s3.r6 import s3_bucket_router1
27
+ from s3_viewer_backend.s3_viewer import s3_viewer_router
28
+ # from vector_db.retrival_qa_agent import RetrievalQA_router
29
+ from main_chat_2.main_chat import main_chat_router_v4
30
+ from main_chat5.main_chat_router import main_chat_router_v5
31
+ from main_chat_v6.main_chat_router import main_chat_router_v6
32
+ from report_generation.ppt_v2 import Report_Generation_Router_ppt_v2
33
+ from report_generation.pdf_v3 import Report_Generation_Router_pdf_v2
34
+
35
+
36
+
37
+ from Routes.main_agent_chat_bot_v2 import main_chatbot_route_v2
38
+ # from vector_db import qdrant_crud
39
+
40
+ # from vector_db.qdrant_crud import Qdrant_router
41
+ from agent_tools.autovis_tool import Autoviz_router
42
+ # from main_chat_2.main_chat import main_chat_router_v3
43
+ from DB_store_backup.agentic_context_convoid_management import Db_store_router
44
+ # Suppress warnings
45
+ warnings.filterwarnings("ignore", message="Qdrant client version.*is incompatible.*")
46
+
47
+ app = FastAPI(title="Combined AI Agent with Qdrant Collections and Redis Session Management")
48
+
49
+ # Add CORS middleware
50
+ app.add_middleware(
51
+ CORSMiddleware,
52
+ allow_origins=["*"],
53
+ allow_credentials=True,
54
+ allow_methods=["*"],
55
+ allow_headers=["*"],
56
+ )
57
+
58
+ print([route.path for route in app.routes])
59
+ # Include routers
60
+
61
+ # app.include_router(main_chatbot_route)
62
+ app.include_router(Report_Generation_Router)
63
+ app.include_router(Redis_session_router)
64
+ app.include_router(redis_session_route_new)
65
+ # app.include_router(Redis_session_router_old)
66
+ app.include_router(Db_store_router)
67
+ app.include_router(s3_bucket_router)
68
+ app.include_router(Autoviz_router)
69
+ # app.include_router(main_chat_router_v3)
70
+ app.include_router(s3_viewer_router)
71
+
72
+ app.include_router(s3_bucket_router1)
73
+ app.include_router(main_chat_router_v4)
74
+ app.include_router(test_router)
75
+
76
+ # app.include_router(RetrievalQA_router)
77
+ app.include_router(main_chatbot_route_v2)
78
+ # app.include_router(Report_Generation_Router_v2)
79
+ app.include_router(Report_Generation_Router_ppt_v2)
80
+ app.include_router(Report_Generation_Router_pdf_v2)
81
+
82
+ # app.include_router(main_chat_router_v5)
83
+ app.include_router(main_chat_router_v6)
84
+ #==================================login and user management routes==================================
85
+ app.include_router(login_apis_router)
86
+ #---------------------------Qdrant router -----------------------
87
+
88
+ # app.include_router(Qdrant_router)
89
+ # ------------------- MIDDLEWARE -------------------
90
+
91
+ @app.middleware("http")
92
+ async def add_success_flag(request: Request, call_next):
93
+ response = await call_next(request)
94
+
95
+ # Only modify JSON responses
96
+ if "application/json" in response.headers.get("content-type", ""):
97
+ try:
98
+ body = b"".join([chunk async for chunk in response.body_iterator])
99
+ data = json.loads(body.decode())
100
+
101
+ # Add success flag
102
+ data["success"] = 200 <= response.status_code < 300
103
+
104
+ # Build new JSONResponse (auto handles Content-Length)
105
+ response = JSONResponse(
106
+ content=data,
107
+ status_code=response.status_code,
108
+ headers={k: v for k, v in response.headers.items() if k.lower() != "content-length"},
109
+ )
110
+ except Exception:
111
+ # fallback if response is not JSON parseable
112
+ pass
113
+ return response
114
+
115
+
backend/__pycache__/crud.cpython-311.pyc CHANGED
Binary files a/backend/__pycache__/crud.cpython-311.pyc and b/backend/__pycache__/crud.cpython-311.pyc differ
 
backend/__pycache__/main.cpython-311.pyc CHANGED
Binary files a/backend/__pycache__/main.cpython-311.pyc and b/backend/__pycache__/main.cpython-311.pyc differ
 
backend/__pycache__/models.cpython-311.pyc CHANGED
Binary files a/backend/__pycache__/models.cpython-311.pyc and b/backend/__pycache__/models.cpython-311.pyc differ
 
backend/__pycache__/schemas.cpython-311.pyc CHANGED
Binary files a/backend/__pycache__/schemas.cpython-311.pyc and b/backend/__pycache__/schemas.cpython-311.pyc differ
 
backend/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/backend/__pycache__/utils.cpython-311.pyc and b/backend/__pycache__/utils.cpython-311.pyc differ
 
main_chat5/__pycache__/data_agent.cpython-311.pyc CHANGED
Binary files a/main_chat5/__pycache__/data_agent.cpython-311.pyc and b/main_chat5/__pycache__/data_agent.cpython-311.pyc differ
 
main_chat_v6/__pycache__/data_agent.cpython-311.pyc CHANGED
Binary files a/main_chat_v6/__pycache__/data_agent.cpython-311.pyc and b/main_chat_v6/__pycache__/data_agent.cpython-311.pyc differ
 
main_chat_v6/__pycache__/main_chat_router.cpython-311.pyc CHANGED
Binary files a/main_chat_v6/__pycache__/main_chat_router.cpython-311.pyc and b/main_chat_v6/__pycache__/main_chat_router.cpython-311.pyc differ
 
main_chat_v6/data_agent.py CHANGED
@@ -1,452 +1,448 @@
1
  """
2
- Improved Data Agent with Real-time Phased Streaming
3
- Phases: thinking (agent reasoning), answer (final response), charts (visualizations)
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import asyncio
7
  import json
 
8
  import re
9
- from typing import AsyncGenerator, List, Tuple
10
- import duckdb
 
 
11
  import pandas as pd
12
- from openai.types.responses import ResponseTextDeltaEvent
13
- from agents import Agent, Runner, function_tool
14
- from s3.read_files import read_parquet_from_s3
15
 
16
- # =====================================
17
- # CONFIGURATION
18
- # =====================================
 
19
 
20
- ALLOWED_SQL_KEYWORDS = ("select", "with")
 
 
 
 
 
 
 
21
  MAX_RESULT_ROWS = 10000
22
 
23
- _registered_datasets = {}
24
- last_sql_query: str = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # =====================================
27
- # DATASET LOADING
28
- # =====================================
29
 
30
  def generate_dataset_context(df: pd.DataFrame, table_name: str) -> str:
31
- """Generate comprehensive dataset context"""
32
- context_parts = []
33
- context_parts.append(f"Table: {table_name}")
34
- context_parts.append(f"Rows: {len(df)}, Columns: {len(df.columns)}")
35
- context_parts.append(f"Column Names: {', '.join(df.columns.tolist())}")
36
- context_parts.append("\nData Types:")
37
- for col, dtype in df.dtypes.items():
38
- context_parts.append(f" - {col}: {dtype}")
39
-
40
- context_parts.append("\nSample Data (first 3 rows):")
41
- sample_str = df.head(3).to_string(index=False, max_cols=10)
42
- context_parts.append(sample_str)
43
-
44
- numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
45
- if numeric_cols:
46
- context_parts.append("\nNumeric Column Statistics:")
47
- desc = df[numeric_cols].describe().transpose()
48
- for col in numeric_cols[:5]:
49
  if col in desc.index:
50
- context_parts.append(
51
- f" - {col}: mean={desc.loc[col, 'mean']:.2f}, "
52
- f"min={desc.loc[col, 'min']:.2f}, max={desc.loc[col, 'max']:.2f}"
53
- )
54
-
55
- return "\n".join(context_parts)
56
-
57
- def load_and_register_datasets(file_paths: List[str]) -> str:
58
- """Load datasets and register them for SQL queries"""
59
  global _registered_datasets
60
- _registered_datasets.clear()
61
- context_parts = ["Available datasets for SQL queries:\n"]
62
-
63
- for idx, file_path in enumerate(file_paths, start=1):
64
- table_name = f"CSV{idx}"
65
- try:
66
- print(f"Loading {file_path} as {table_name}")
67
- df = read_parquet_from_s3(file_path)
68
- _registered_datasets[table_name] = df
69
- dataset_context = generate_dataset_context(df, table_name)
70
- context_parts.append(f"\n{dataset_context}\n")
71
- context_parts.append("="*70)
72
- print(f"Loaded {table_name}: {len(df)} rows, {len(df.columns)} columns")
73
- except Exception as e:
74
- print(f"Error loading {file_path}: {str(e)}")
75
- raise
76
-
77
- return "\n".join(context_parts)
78
-
79
- # =====================================
80
- # SQL TOOL
81
- # =====================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  @function_tool
84
  async def sql_tool(sql_query: str, user_query: str = "") -> dict:
85
- """Execute SQL query and return results"""
86
- global _registered_datasets, last_sql_query
87
- last_sql_query = sql_query
88
-
89
- print(f"\n{'='*70}")
90
- print(f"Executing SQL:\n{sql_query}")
91
- print(f"{'='*70}\n")
92
-
93
- normalized_query = sql_query.strip().lower()
94
- if not normalized_query.startswith(ALLOWED_SQL_KEYWORDS):
95
- raise ValueError(
96
- f"Only {', '.join(k.upper() for k in ALLOWED_SQL_KEYWORDS)} queries allowed."
97
- )
98
 
99
  if not _registered_datasets:
100
- raise ValueError("No datasets loaded.")
101
 
102
- conn = duckdb.connect(":memory:")
103
  try:
104
- for name, df in _registered_datasets.items():
105
- conn.register(name, df)
106
- result_df = conn.execute(sql_query).df()
 
 
 
107
  except Exception as e:
108
- conn.close()
 
109
  raise ValueError(f"SQL execution failed: {str(e)}")
110
- finally:
111
- conn.close()
112
 
113
- if len(result_df) > MAX_RESULT_ROWS:
114
- result_df = result_df.head(MAX_RESULT_ROWS)
115
-
116
- summary_lines = [
117
- f"Returned {len(result_df)} rows, {len(result_df.columns)} columns"
118
- ]
119
- if not result_df.empty:
120
- summary_lines.append("\nPreview (first 10 rows):")
121
- summary_lines.append(result_df.head(10).to_string(index=False, max_rows=10))
122
-
123
- result_summary = "\n".join(summary_lines)
124
 
125
- print(f"SQL executed successfully\n")
 
 
 
 
126
  return {
127
  "sql_query": sql_query,
128
- "result_summary": result_summary,
129
- "columns": result_df.columns.tolist(),
130
- "row_count": len(result_df)
 
131
  }
132
 
133
- # =====================================
134
- # DATA ANALYST AGENT WITH STRUCTURED OUTPUT
135
- # =====================================
136
 
137
  def create_data_analyst_agent(dataset_context: str) -> Agent:
138
- """Create data analyst agent with explicit phase instructions"""
139
- return Agent(
140
- name="DataStory_AI",
141
- model="gpt-4o",
142
- tools=[sql_tool],
143
- instructions=f"""You are **DataStory AI**, an elite data analyst who combines deep analytical reasoning with compelling storytelling.
 
144
 
145
  DATASET CONTEXT:
146
  {dataset_context}
147
 
148
- ===== CRITICAL: STRUCTURED RESPONSE FORMAT =====
149
-
150
- You MUST structure your response in THREE distinct phases using XML tags:
151
-
152
- **PHASE 1: THINKING (Your Reasoning Process)**
153
- <thinking>
154
- - Explain your analytical approach step-by-step
155
- - Describe what SQL queries you'll run and WHY
156
- - Outline the insights you're looking for
157
- - Discuss any data quality considerations
158
- - Map out your visualization strategy
159
- Example:
160
- "To answer this question, I need to:
161
- 1. First examine the distribution of sales across regions by querying CSV1
162
- 2. Calculate year-over-year growth rates to identify trends
163
- 3. Identify top performers and outliers
164
- 4. Create visualizations showing: (a) regional comparison, (b) time series trend, (c) top 10 performers"
165
- </thinking>
166
-
167
- **PHASE 2: ANSWER (Your Analysis & Narrative)**
168
- <answer>
169
- Write your complete data story here:
170
- - Present key findings with specific numbers
171
- - Provide context and business implications
172
- - Explain patterns, trends, and anomalies
173
- - Use clear section headers
174
- - Be specific: "Sales increased by 23.5% in Q3" not "Sales increased significantly"
175
- </answer>
176
-
177
- **PHASE 3: CHARTS (Visualizations)**
178
- For each visualization, use this EXACT format:
179
-
180
- <datastory_chart>
181
- <chart_title>Descriptive Title</chart_title>
182
- <chart_description>What this chart shows with specific numbers</chart_description>
183
- <vega_spec>
184
- {{
185
- "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
186
- "title": {{"text": "Chart Title", "fontSize": 18, "fontWeight": "bold"}},
187
- "width": 700,
188
- "height": 400,
189
- "data": {{"values": [
190
- {{"category": "A", "value": 100}},
191
- {{"category": "B", "value": 150}}
192
- ]}},
193
- "mark": {{"type": "bar", "tooltip": true, "cornerRadius": 4}},
194
- "encoding": {{
195
- "x": {{"field": "category", "type": "nominal", "axis": {{"labelAngle": 0}}}},
196
- "y": {{"field": "value", "type": "quantitative", "axis": {{"title": "Value"}}}}
197
- }},
198
- "config": {{
199
- "axis": {{"labelFontSize": 12, "titleFontSize": 14}},
200
- "legend": {{"labelFontSize": 12, "titleFontSize": 14}}
201
- }}
202
- }}
203
- </vega_spec>
204
- </datastory_chart>
205
-
206
- ===== WORKFLOW =====
207
- 1. **START WITH THINKING**: Always begin with <thinking> tags explaining your analytical plan
208
- 2. **Execute SQL queries**: Use sql_tool(sql_query="...", user_query="...")
209
- 3. **WRITE ANSWER**: Present insights in <answer> tags with specific numbers
210
- 4. **CREATE CHARTS**: Generate 2-3 visualizations wrapped in <datastory_chart> tags
211
-
212
- ===== SQL QUERY RULES =====
213
- - Table names: CSV1, CSV2, CSV3, etc. (from dataset context above)
214
- - Only SELECT and WITH statements allowed
215
- - Always include user_query parameter in sql_tool calls
216
- - Use meaningful column aliases
217
- - Limit results appropriately for visualizations
218
-
219
- ===== VISUALIZATION BEST PRACTICES =====
220
- - Include ACTUAL data in "values" array (not placeholders)
221
- - Use appropriate chart types: bar, line, point, area, arc (pie)
222
- - Add tooltips: "mark": {{"type": "bar", "tooltip": true}}
223
- - Use clear axis labels and titles
224
- - Color schemes: use "color" encoding for categorical data
225
- - For time series: parse dates with "timeUnit"
226
-
227
- ===== QUALITY STANDARDS =====
228
- ✓ Be specific with numbers: "Revenue increased by $2.3M (18.7%)"
229
- ✓ Provide context: Compare to benchmarks, previous periods, or goals
230
- ✓ Explain WHY patterns exist, not just WHAT they are
231
- ✓ Structure with clear headers: ## Key Finding 1, ## Key Finding 2
232
- ✓ Generate 2-3 complementary visualizations
233
- ✓ End with actionable recommendations
234
-
235
- ===== EXAMPLE RESPONSE STRUCTURE =====
236
-
237
- <thinking>
238
- Let me analyze the sales data systematically:
239
-
240
- 1. **Data Exploration**: I'll query CSV1 to understand sales distribution across regions and time periods
241
- 2. **Trend Analysis**: Calculate monthly growth rates to identify seasonal patterns
242
- 3. **Performance Ranking**: Identify top 10 products by revenue
243
- 4. **Visualization Plan**:
244
- - Chart 1: Regional sales comparison (bar chart)
245
- - Chart 2: Monthly trend with growth rate (line chart)
246
- - Chart 3: Top 10 products (horizontal bar)
247
-
248
- SQL Strategy: First, aggregate sales by region, then by month, finally rank products.
249
- </thinking>
250
-
251
- <answer>
252
- ## Executive Summary
253
- Analysis of 50,000 transactions reveals strong regional disparities and seasonal trends...
254
-
255
- ## Regional Performance
256
- The West region dominates with $4.2M in sales (42% of total), followed by East at $3.1M (31%)...
257
-
258
- [Continue with detailed findings, specific numbers, and business implications]
259
- </answer>
260
-
261
- <datastory_chart>
262
- <chart_title>Sales by Region - West Region Leads with 42% Share</chart_title>
263
- <chart_description>West region generated $4.2M, 35% ahead of second-place East region ($3.1M)</chart_description>
264
- <vega_spec>
265
- {{
266
- "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
267
- "title": {{"text": "Sales by Region ($ Millions)", "fontSize": 18}},
268
- "width": 700,
269
- "height": 400,
270
- "data": {{"values": [
271
- {{"region": "West", "sales": 4.2}},
272
- {{"region": "East", "sales": 3.1}},
273
- {{"region": "North", "sales": 1.8}},
274
- {{"region": "South", "sales": 1.5}}
275
- ]}},
276
- "mark": {{"type": "bar", "tooltip": true, "color": "#4C78A8"}},
277
- "encoding": {{
278
- "x": {{"field": "region", "type": "nominal", "sort": "-y"}},
279
- "y": {{"field": "sales", "type": "quantitative", "title": "Sales ($M)"}}
280
- }}
281
- }}
282
- </vega_spec>
283
- </datastory_chart>
284
-
285
- REMEMBER: Always use the three-phase structure with XML tags. Your reasoning goes in <thinking>, analysis in <answer>, and charts in <datastory_chart> blocks.
286
  """
287
- )
288
-
289
- # =====================================
290
- # PHASE-AWARE CHART BUFFER
291
- # =====================================
292
-
293
- class PhaseAwareBuffer:
294
- """Buffer that separates content into thinking, answer, and chart phases"""
295
-
296
- def __init__(self):
297
- self.full_buffer = ""
298
- self.thinking_complete = False
299
- self.answer_complete = False
300
- self.extracted_charts = []
301
-
302
- def add_text(self, text: str) -> List[Tuple[str, str]]:
303
- """
304
- Add text and extract phase-separated content.
305
- Returns: List of (phase, content) tuples
306
- """
307
- self.full_buffer += text
308
- emissions = []
309
-
310
- # Extract thinking phase
311
- if not self.thinking_complete:
312
- thinking_match = re.search(r'<thinking>(.*?)</thinking>', self.full_buffer, re.DOTALL)
313
- if thinking_match:
314
- thinking_content = thinking_match.group(1).strip()
315
- emissions.append(("thinking", thinking_content))
316
- self.thinking_complete = True
317
- # Remove from buffer
318
- self.full_buffer = self.full_buffer[thinking_match.end():]
319
-
320
- # Extract answer phase
321
- if self.thinking_complete and not self.answer_complete:
322
- answer_match = re.search(r'<answer>(.*?)</answer>', self.full_buffer, re.DOTALL)
323
- if answer_match:
324
- answer_content = answer_match.group(1).strip()
325
- emissions.append(("answer", answer_content))
326
- self.answer_complete = True
327
- # Remove from buffer
328
- self.full_buffer = self.full_buffer[answer_match.end():]
329
-
330
- # Extract charts (can be multiple)
331
- if self.answer_complete:
332
- chart_pattern = re.compile(r'<datastory_chart>(.*?)</datastory_chart>', re.DOTALL)
333
- matches = list(chart_pattern.finditer(self.full_buffer))
334
-
335
- for match in matches:
336
- chart_content = match.group(1)
337
-
338
- # Parse Vega spec
339
- vega_pattern = r'<vega_spec>(.*?)</vega_spec>'
340
- vega_match = re.search(vega_pattern, chart_content, re.DOTALL)
341
-
342
- if vega_match:
343
- try:
344
- vega_json_str = vega_match.group(1).strip()
345
- vega_spec = json.loads(vega_json_str)
346
-
347
- # Only emit if not already extracted
348
- if vega_spec not in self.extracted_charts:
349
- self.extracted_charts.append(vega_spec)
350
- emissions.append(("charts", vega_spec))
351
- except json.JSONDecodeError as e:
352
- print(f"Failed to parse chart JSON: {e}")
353
-
354
- # Remove from buffer
355
- self.full_buffer = self.full_buffer[match.end():]
356
-
357
- return emissions
358
 
359
- # =====================================
360
- # PHASED STREAMING WITH AGENT REASONING
361
- # =====================================
362
 
363
- async def stream_analysis(query: str, file_paths: List[str]) -> AsyncGenerator[Tuple[str, str], None]:
 
 
364
  """
365
- Stream analysis with agent's actual reasoning in thinking phase.
366
- Yields: (phase, content) tuples where phase is:
367
- - THINKING: Agent's reasoning process
368
- - ANSWER: Final analysis and narrative
369
- - CHART: Vega-Lite chart specification
370
- - ERROR, DONE: Control signals
371
  """
372
- print(f"\n{'='*70}")
373
- print("ANALYSIS REQUEST (phased streaming with agent reasoning)")
374
- print(f"Files: {', '.join(file_paths)}")
375
- print(f"Query: {query}")
376
- print(f"{'='*70}\n")
377
-
378
- # Load datasets
379
  try:
380
  dataset_context = load_and_register_datasets(file_paths)
381
- print(f"\nDataset Context loaded: {len(dataset_context)} chars\n")
382
  except Exception as e:
 
383
  yield ("ERROR", f"Failed to load datasets: {str(e)}")
384
  yield ("DONE", "")
385
  return
386
 
387
- # Create agent
388
  agent = create_data_analyst_agent(dataset_context)
389
-
390
- # Initialize phase-aware buffer
391
- phase_buffer = PhaseAwareBuffer()
392
-
393
  try:
394
  result = Runner.run_streamed(agent, input=query)
395
-
396
- async for event in result.stream_events():
397
- if (
398
- event.type == "raw_response_event"
399
- and isinstance(event.data, ResponseTextDeltaEvent)
400
- ):
401
- delta: str = event.data.delta
402
-
403
- # Add to buffer and extract phase-separated content
404
- emissions = phase_buffer.add_text(delta)
405
-
406
- for phase, content in emissions:
407
- if phase == "thinking":
408
- # Stream thinking in chunks for better UX
409
- sentences = content.split('\n')
410
- for sentence in sentences:
411
- if sentence.strip():
412
- yield ("THINKING", sentence.strip() + "\n")
413
- await asyncio.sleep(0.05)
414
-
415
- elif phase == "answer":
416
- # Stream answer in chunks
417
- paragraphs = content.split('\n\n')
418
- for para in paragraphs:
419
- if para.strip():
420
- yield ("ANSWER", para.strip() + "\n\n")
421
- await asyncio.sleep(0.05)
422
-
423
- elif phase == "charts":
424
- # Emit complete chart
425
- yield ("CHART", content)
426
- await asyncio.sleep(0.1)
427
 
428
- except Exception as e:
429
- yield ("ERROR", f"Streaming error: {str(e)}")
430
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  yield ("DONE", "")
432
 
433
- # =====================================
434
- # LEGACY STREAMING (BACKWARD COMPATIBILITY)
435
- # =====================================
436
 
437
- async def stream_analysis_legacy(query: str, file_paths: List[str]) -> AsyncGenerator[str, None]:
438
- """
439
- Original streaming without phases - for backward compatibility
440
- """
441
- dataset_context = load_and_register_datasets(file_paths)
442
- agent = create_data_analyst_agent(dataset_context)
443
- result = Runner.run_streamed(agent, input=query)
444
-
445
- async for event in result.stream_events():
446
- if (
447
- event.type == "raw_response_event"
448
- and isinstance(event.data, ResponseTextDeltaEvent)
449
- ):
450
- yield event.data.delta
451
-
452
- yield "[DONE]"
 
1
  """
2
+ data_agent_v3.py
3
+
4
+ Robust production-ready data agent with:
5
+ - RobustPhaseDetector: hybrid phase detection (heuristics + JSON balancing + tool-call hooks)
6
+ - sql_tool: runs SQL via DuckDB safely using asyncio.to_thread + timeout
7
+ - load_and_register_datasets: loads Parquet/CSV via your s3.read_files helper
8
+ - stream_analysis: streams model output, detects phases robustly, and yields (phase, content) tuples
9
+
10
+ Notes:
11
+ - Replace `Runner.run_streamed` with your actual model runner
12
+ - `ResponseTextDeltaEvent` from OpenAI SDK is used as an example; adapt if your runner emits a different shape
13
  """
14
 
15
  import asyncio
16
  import json
17
+ import logging
18
  import re
19
+ from typing import AsyncGenerator, Dict, List, Tuple, Optional
20
+ from dataclasses import dataclass, field
21
+ from concurrent.futures import ThreadPoolExecutor
22
+
23
  import pandas as pd
24
+ import duckdb
 
 
25
 
26
+ # Import your Runner & ResponseTextDeltaEvent types (adapt if different)
27
+ from openai.types.responses import ResponseTextDeltaEvent # keep for type-checking; adapt as needed
28
+ from agents import Agent, Runner, function_tool # adapt to your agent system
29
+ from s3.read_files import read_parquet_from_s3 # existing helper to load Parquet from S3
30
 
31
+ logger = logging.getLogger("data_agent_v3")
32
+ logging.basicConfig(level=logging.INFO)
33
+
34
+ # ========= Configuration =========
35
+ DUCKDB_FILE = "agent_duckdb.db" # file-backed DB to allow connection reuse
36
+ DUCKDB_THREADS = 2
37
+ DUCKDB_MEMORY_LIMIT = "1GB"
38
+ SQL_TIMEOUT_SECONDS = 15 # per-query timeout
39
  MAX_RESULT_ROWS = 10000
40
 
41
+ # ThreadPoolExecutor for blocking DuckDB operations
42
+ _BLOCKING_EXECUTOR = ThreadPoolExecutor(max_workers=8)
43
+
44
+
45
+ # ========= DuckDB manager (safe, pooled via to_thread) =========
46
+ class DuckDBManager:
47
+ """
48
+ Simple manager that returns a connection that's safe to use from a thread.
49
+ We use thread-based execution of DuckDB queries via asyncio.to_thread to avoid blocking the event loop.
50
+ """
51
+
52
+ def __init__(self, db_path: str = DUCKDB_FILE):
53
+ self.db_path = db_path
54
+ # A lightweight connection initializer to warm pragmas
55
+ self._init_db()
56
+
57
+ def _init_db(self):
58
+ # Establish a short-lived connection and set pragmas
59
+ conn = duckdb.connect(self.db_path)
60
+ try:
61
+ conn.execute(f"PRAGMA threads={DUCKDB_THREADS}")
62
+ # Memory limit (DuckDB may accept memory_limit pragma as a string in newer versions)
63
+ try:
64
+ conn.execute(f"PRAGMA memory_limit='{DUCKDB_MEMORY_LIMIT}'")
65
+ except Exception:
66
+ # Older duckdb versions may ignore
67
+ pass
68
+ finally:
69
+ conn.close()
70
+
71
+ def execute_query(self, sql: str, registered_tables: Dict[str, pd.DataFrame]) -> pd.DataFrame:
72
+ """
73
+ Blocking function: executes sql and returns a pandas DataFrame.
74
+ Designed to run in a worker thread via asyncio.to_thread.
75
+ """
76
+ conn = duckdb.connect(self.db_path)
77
+ try:
78
+ # Register provided DataFrames into the connection
79
+ for name, df in registered_tables.items():
80
+ conn.register(name, df)
81
+ # Execute query
82
+ res = conn.execute(sql).fetchdf()
83
+ return res
84
+ finally:
85
+ conn.close()
86
+
87
+
88
+ # Create global manager instance
89
+ _duckdb_manager = DuckDBManager()
90
+
91
+
92
+ # ========= Dataset loading =========
93
+ _registered_datasets: Dict[str, pd.DataFrame] = {}
94
 
 
 
 
95
 
96
  def generate_dataset_context(df: pd.DataFrame, table_name: str) -> str:
97
+ parts = [f"Table: {table_name}", f"Rows: {len(df)}, Columns: {len(df.columns)}", f"Column Names: {', '.join(df.columns.tolist())}"]
98
+ parts.append("\nSample (first 3 rows):")
99
+ parts.append(df.head(3).to_string(index=False, max_cols=10))
100
+ numeric = df.select_dtypes(include=["number"]).columns.tolist()
101
+ if numeric:
102
+ parts.append("\nNumeric stats (first 5 numeric cols):")
103
+ desc = df[numeric].describe().transpose()
104
+ for col in numeric[:5]:
 
 
 
 
 
 
 
 
 
 
105
  if col in desc.index:
106
+ parts.append(f" - {col}: mean={desc.loc[col,'mean']:.2f}, min={desc.loc[col,'min']:.2f}, max={desc.loc[col,'max']:.2f}")
107
+ return "\n".join(parts)
108
+
109
+
110
+ def load_and_register_datasets(file_paths):
 
 
 
 
111
  global _registered_datasets
112
+ _registered_datasets = {}
113
+
114
+ ctx_parts = []
115
+ for i, fp in enumerate(file_paths, start=1):
116
+ table_name = f"CSV{i}"
117
+ df = read_parquet_from_s3(fp)
118
+
119
+ # 🔥 FIX COLUMN NAMES HERE
120
+ df.columns = (
121
+ df.columns
122
+ .str.strip()
123
+ .str.replace(" ", "_")
124
+ .str.replace(r"[^\w]+", "", regex=True)
125
+ )
126
+
127
+ _registered_datasets[table_name] = df
128
+ ctx_parts.append(generate_dataset_context(df, table_name))
129
+
130
+ return "\n\n".join(ctx_parts)
131
+
132
+
133
+
134
+ # ========= Robust Phase Detector =========
135
+
136
+ @dataclass
137
+ class RobustPhaseDetector:
138
+ """
139
+ Hybrid phase detector that combines:
140
+ - soft XML/tag hints (best-effort)
141
+ - content heuristics (thinking/reasoning/answer)
142
+ - tool-call hooks (you should emit explicit tokens in Runner for tool calls)
143
+ - JSON brace balancing for charts
144
+ """
145
+ buffer: str = ""
146
+ thinking_emitted: bool = False
147
+ reasoning_emitted: bool = False
148
+ answer_started: bool = False
149
+ charts: List[Dict] = field(default_factory=list)
150
+ brace_count: int = 0
151
+ current_json_chunk: str = ""
152
+ chart_candidates: List[str] = field(default_factory=list)
153
+
154
+ thinking_keywords = [
155
+ "let me", "first", "need to", "i will", "step", "query", "sql", "check", "analyze", "understand",
156
+ "explore", "calculate", "group by", "join", "filter", "plan", "strategy"
157
+ ]
158
+ reasoning_keywords = [
159
+ "i will run", "running query", "execute", "to compute", "will compute", "plan to", "step 1", "step 2",
160
+ ]
161
+ answer_markers = [
162
+ "results show", "analysis reveals", "key finding", "## ", "conclusion", "in summary", "the results",
163
+ "final", "therefore", "overall", "we find"
164
+ ]
165
+
166
+ def feed(self, delta: str) -> List[Tuple[str, object]]:
167
+ """
168
+ Accept raw streamed text (delta) and return list of (phase, content)
169
+ phases: THINKING, REASONING, ANSWER, CHART, ERROR, DONE
170
+ """
171
+ emissions: List[Tuple[str, object]] = []
172
+ if not delta:
173
+ return emissions
174
+
175
+ # Append delta to buffer for heuristic detection
176
+ self.buffer += delta
177
+
178
+ # --------- 1) JSON/Chart detection via brace balancing ----------
179
+ # Look for a possible JSON start (with "$schema" or vega indicator) in incoming data
180
+ # We try to capture complete JSON objects using brace balancing.
181
+ # If JSON is extracted successfully, emit CHART with parsed dict.
182
+ # Do this before other heuristics so charts show asap.
183
+ # Scan buffer char by char (efficient on chunk sizes)
184
+ i = 0
185
+ while i < len(delta):
186
+ ch = delta[i]
187
+ if ch == "{":
188
+ # if starting new JSON, and previous brace_count==0, start capturing
189
+ if self.brace_count == 0:
190
+ # start new chunk
191
+ self.current_json_chunk = "{"
192
+ else:
193
+ self.current_json_chunk += "{"
194
+ self.brace_count += 1
195
+ elif ch == "}":
196
+ self.current_json_chunk += "}"
197
+ self.brace_count -= 1
198
+ if self.brace_count == 0:
199
+ # Potential full JSON captured
200
+ candidate = self.current_json_chunk.strip()
201
+ # quick check for vega/schema
202
+ if '"$schema"' in candidate or "vega" in candidate.lower() or "data" in candidate.lower():
203
+ try:
204
+ parsed = json.loads(candidate)
205
+ # dedupe by content
206
+ if parsed not in self.charts:
207
+ self.charts.append(parsed)
208
+ emissions.append(("CHART", parsed))
209
+ # remove candidate from buffer to avoid reprocessing
210
+ self.buffer = self.buffer.replace(candidate, "", 1)
211
+ except Exception:
212
+ # Not yet valid JSON - maybe streaming partial, continue trying next deltas
213
+ # Keep candidate in current_json_chunk for next feed
214
+ pass
215
+ self.current_json_chunk = ""
216
+ else:
217
+ if self.brace_count > 0:
218
+ self.current_json_chunk += ch
219
+ i += 1
220
+
221
+ # --------- 2) Tool-call / reasoning boundary detection ----------
222
+ # If Runner emits explicit markers for tool calls, you should hook into them.
223
+ # Here we support a common pattern: if buffer contains something like "[TOOL_CALL_START]" or "SQL TOOL:",
224
+ # we treat it as REASONING beginning.
225
+ tool_markers = ["[TOOL_CALL_START]", "[SQL_TOOL]", "sql_tool(", "EXECUTE SQL:", "RUN SQL:"]
226
+ lower_buffer = self.buffer.lower()
227
+
228
+ # THINKING detection (soft)
229
+ if not self.thinking_emitted:
230
+ if len(self.buffer) > 80 and any(kw in lower_buffer for kw in self.thinking_keywords):
231
+ snippet = self._take_and_trim_for_emission(self.buffer, max_chars=1000)
232
+ emissions.append(("THINKING", snippet))
233
+ self.thinking_emitted = True
234
+ # keep buffer for subsequent phases
235
+ # REASONING detection
236
+ if not self.reasoning_emitted:
237
+ if any(m.lower() in lower_buffer for m in tool_markers) or any(kw in lower_buffer for kw in self.reasoning_keywords):
238
+ snippet = self._take_and_trim_for_emission(self.buffer, max_chars=2000)
239
+ emissions.append(("REASONING", snippet))
240
+ self.reasoning_emitted = True
241
+
242
+ # ANSWER detection
243
+ if (self.thinking_emitted or self.reasoning_emitted) and not self.answer_started:
244
+ if any(marker in lower_buffer for marker in self.answer_markers) or len(self.buffer) > 1500:
245
+ # consider this the start of the answer/narrative
246
+ snippet = self._take_and_trim_for_emission(self.buffer, max_chars=4000)
247
+ emissions.append(("ANSWER", snippet))
248
+ self.answer_started = True
249
+ # keep buffer for subsequent answer growth
250
+
251
+ # Also attempt to detect inline "final" answer markers even after answer_started to emit deltas
252
+ if self.answer_started:
253
+ # Split newly appended buffer into reasonable paragraph chunks for streaming UX
254
+ new_paragraphs = self._drain_buffer_into_paragraphs()
255
+ for p in new_paragraphs:
256
+ emissions.append(("ANSWER", p))
257
+
258
+ return emissions
259
+
260
+ def _take_and_trim_for_emission(self, buf: str, max_chars: int = 1000) -> str:
261
+ """Return upto max_chars of buffer for emission (non-destructive)"""
262
+ snippet = buf[:max_chars]
263
+ return snippet.strip()
264
+
265
+ def _drain_buffer_into_paragraphs(self) -> List[str]:
266
+ """
267
+ Extract paragraphs from buffer (up to some limit), leaving remainder in buffer.
268
+ This helps stream incremental answer paragraphs rather than massive blobs.
269
+ """
270
+ # Split on two newlines or single newline when long
271
+ paragraphs = []
272
+ # If very long buffer, cut into 800-char chunks
273
+ if len(self.buffer) > 0:
274
+ if "\n\n" in self.buffer:
275
+ parts = self.buffer.split("\n\n")
276
+ else:
277
+ # chunky split
278
+ parts = [self.buffer[i:i+800] for i in range(0, len(self.buffer), 800)]
279
+ # emit first 1-3 parts and keep rest
280
+ take = min(3, len(parts))
281
+ for _ in range(take):
282
+ part = parts.pop(0).strip()
283
+ if part:
284
+ paragraphs.append(part)
285
+ # reconstruct buffer with remaining parts
286
+ self.buffer = "\n\n".join(parts)
287
+ return paragraphs
288
+
289
+
290
+ # ========= SQL TOOL =========
291
 
292
  @function_tool
293
  async def sql_tool(sql_query: str, user_query: str = "") -> dict:
294
+ """
295
+ Async SQL tool that executes on DuckDB using asyncio.to_thread (thread-safe)
296
+ - Enforces a per-query timeout using asyncio.wait_for
297
+ - Registers in-memory pandas DataFrames for tables (CSV1, CSV2, ...)
298
+ """
299
+ if not sql_query:
300
+ raise ValueError("Empty SQL query")
301
+
302
+ normalized = sql_query.strip().lower()
303
+ if not (normalized.startswith("select") or normalized.startswith("with")):
304
+ raise ValueError("Only SELECT/WITH queries allowed for safety")
 
 
305
 
306
  if not _registered_datasets:
307
+ raise ValueError("No datasets registered for SQL execution")
308
 
309
+ # Execute in a thread pool and enforce timeout
310
  try:
311
+ df = await asyncio.wait_for(
312
+ asyncio.to_thread(_duckdb_manager.execute_query, sql_query, _registered_datasets),
313
+ timeout=SQL_TIMEOUT_SECONDS,
314
+ )
315
+ except asyncio.TimeoutError:
316
+ raise TimeoutError(f"SQL execution exceeded {SQL_TIMEOUT_SECONDS}s timeout")
317
  except Exception as e:
318
+ # Bubble up a clear error
319
+ logger.exception("SQL execution failed")
320
  raise ValueError(f"SQL execution failed: {str(e)}")
 
 
321
 
322
+ # Trim large results
323
+ if len(df) > MAX_RESULT_ROWS:
324
+ df = df.head(MAX_RESULT_ROWS)
 
 
 
 
 
 
 
 
325
 
326
+ # Build summary
327
+ result_summary = [f"Returned {len(df)} rows, {len(df.columns)} columns"]
328
+ if not df.empty:
329
+ result_summary.append("\nPreview (first 10 rows):")
330
+ result_summary.append(df.head(10).to_string(index=False, max_rows=10))
331
  return {
332
  "sql_query": sql_query,
333
+ "result_summary": "\n".join(result_summary),
334
+ "columns": df.columns.tolist(),
335
+ "row_count": len(df),
336
+ "data_preview": df.head(50).to_dict(orient="records") # useful for immediate charts
337
  }
338
 
339
+
340
+ # ========= Agent creation utility =========
 
341
 
342
  def create_data_analyst_agent(dataset_context: str) -> Agent:
343
+ """
344
+ Build the agent with detailed instructions. Keep structured hints, but DO NOT rely on them.
345
+ The Runner should stream raw text; the RobustPhaseDetector will detect phases.
346
+ """
347
+ instructions = f"""
348
+ You are DataStory AI. Use the dataset context below to plan SQL queries, run them with sql_tool(sql_query="..."),
349
+ and produce a final data story. You may include chart JSON blobs (Vega-Lite) in your output; when you do, ensure valid JSON.
350
 
351
  DATASET CONTEXT:
352
  {dataset_context}
353
 
354
+ Important: Your structured tags (like <thinking>) are optional; the system detects phases heuristically.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  """
356
+ return Agent(name="DataStory_AI", model="gpt-4o", tools=[sql_tool], instructions=instructions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
 
 
 
358
 
359
+ # ========= Streaming analysis =========
360
+
361
+ async def stream_analysis(query: str, file_paths: List[str]) -> AsyncGenerator[Tuple[str, object], None]:
362
  """
363
+ Stream analysis for data queries. Yields tuples (phase_tag, content)
364
+ Phases: THINKING, REASONING, ANSWER, CHART, ERROR, DONE
 
 
 
 
365
  """
366
+ logger.info("Starting stream_analysis: query=%s files=%s", query, file_paths)
367
+ # 1) Load datasets
 
 
 
 
 
368
  try:
369
  dataset_context = load_and_register_datasets(file_paths)
 
370
  except Exception as e:
371
+ logger.exception("Failed to load datasets")
372
  yield ("ERROR", f"Failed to load datasets: {str(e)}")
373
  yield ("DONE", "")
374
  return
375
 
 
376
  agent = create_data_analyst_agent(dataset_context)
377
+ detector = RobustPhaseDetector()
378
+
379
+ # IMPORTANT: Runner.run_streamed must produce an object with .stream_events() or be adapted.
380
+ # The event stream must yield events that contain incremental text deltas, e.g. ResponseTextDeltaEvent
381
  try:
382
  result = Runner.run_streamed(agent, input=query)
383
 
384
+ # Example: the Runner here is assumed to have an async generator `stream_events()`
385
+ async for event in result.stream_events():
386
+ # If your runner provides a different event shape, adapt this block
387
+ # We look for ResponseTextDeltaEvent style events containing .data.delta
388
+ try:
389
+ if getattr(event, "type", "") == "raw_response_event" and isinstance(getattr(event, "data", None), ResponseTextDeltaEvent):
390
+ delta = event.data.delta
391
+ emissions = detector.feed(delta)
392
+
393
+ for phase, content in emissions:
394
+ # Normalize outputs so router can wrap them into SSE easily
395
+ if phase == "THINKING":
396
+ # Stream line-by-line for better UX
397
+ for line in str(content).splitlines():
398
+ if line.strip():
399
+ yield ("THINKING", line.strip())
400
+ elif phase == "REASONING":
401
+ # REASONING can be multi-line; stream in paragraphs
402
+ for para in str(content).split("\n\n"):
403
+ if para.strip():
404
+ yield ("REASONING", para.strip())
405
+ elif phase == "ANSWER":
406
+ # Answer may be paragraph chunk
407
+ yield ("ANSWER", content if isinstance(content, str) else str(content))
408
+ elif phase == "CHART":
409
+ # content is parsed JSON dict
410
+ yield ("CHART", content)
411
+ else:
412
+ yield ("ANSWER", str(content))
413
+ else:
414
+ # Unknown event type: try to extract text if possible
415
+ # Many runners emit plain delta strings
416
+ delta_text = None
417
+ # Try attributes that are commonly present on delta events
418
+ if hasattr(event, "data") and isinstance(event.data, str):
419
+ delta_text = event.data
420
+ elif hasattr(event, "delta"):
421
+ delta_text = event.delta
422
+ elif isinstance(event, str):
423
+ delta_text = event
424
+
425
+ if delta_text:
426
+ emissions = detector.feed(delta_text)
427
+ for phase, content in emissions:
428
+ if phase == "CHART":
429
+ yield ("CHART", content)
430
+ else:
431
+ yield (phase, content)
432
+ except Exception as inner_e:
433
+ logger.exception("Error processing event chunk")
434
+ yield ("ERROR", f"Processing error: {str(inner_e)}")
435
+
436
+ except asyncio.CancelledError:
437
+ logger.info("stream_analysis cancelled (client disconnected or server shutdown)")
438
+ yield ("ERROR", "Stream cancelled by client")
439
+ except Exception as exc:
440
+ logger.exception("stream_analysis uncaught exception")
441
+ yield ("ERROR", f"Streaming error: {str(exc)}")
442
+
443
+ # final done sentinel
444
  yield ("DONE", "")
445
 
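A minimal consumption sketch for stream_analysis, assuming an async entry point; the query and S3 path are placeholders. The v7 router below consumes the same (phase, content) tuples when building its SSE stream.

async def demo() -> None:
    async for phase, content in stream_analysis(
        "Top 10 products by revenue",            # illustrative query
        ["s3://bucket/sales.parquet"],           # illustrative S3 path
    ):
        if phase == "DONE":
            break
        if phase == "CHART":
            print("CHART keys:", list(content.keys()))
        else:
            print(f"[{phase}] {content}")

# asyncio.run(demo())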
 
 
 
446
 
447
+ # ========= Legacy helper (optional) =========
448
+
 
main_chat_v6/main_chat_router.py CHANGED
@@ -1,41 +1,49 @@
1
  """
2
- Main Router with True Agent Reasoning in Thinking Phase
3
- Phases: thinking (agent's reasoning), answer (analysis), charts (visualizations)
 
 
 
 
4
  """
5
 
6
- from fastapi import APIRouter, HTTPException
7
- from fastapi.responses import StreamingResponse
8
- from datetime import datetime
9
- from typing import AsyncGenerator, List, Optional, Dict, Tuple
10
  import json
11
- import asyncio
12
- import os
13
- import sys
14
  import uuid
15
- from pathlib import Path
16
- import httpx
17
- import re
 
18
 
19
- PROJECT_ROOT = Path(__file__).resolve().parents[2]
20
- if str(PROJECT_ROOT) not in sys.path:
21
- sys.path.insert(0, str(PROJECT_ROOT))
22
 
23
- from Redis.sessions_new import *
24
- from main_chat_v6.data_agent import stream_analysis
25
- from main_chat_v6.multiagent import stream_multiagent
 
 
 
 
 
 
 
 
 
 
 
 
26
  from pydantic import BaseModel
27
- from dotenv import load_dotenv
28
- from retrieve_secret import *
29
 
30
- load_dotenv()
 
 
 
31
 
32
- # =====================================
33
- # MODELS
34
- # =====================================
35
 
36
  class Author(BaseModel):
37
  role: str
38
 
 
39
  class ChatRequest(BaseModel):
40
  user_login_id: str
41
  session_id: Optional[str] = None
@@ -47,423 +55,184 @@ class ChatRequest(BaseModel):
47
  create_time: Optional[str] = None
48
  token: Optional[str] = None
49
 
50
- class ChatResponse(BaseModel):
51
- session_id: str
52
- user_message: str
53
- assistant_response: Dict
54
- is_new_session: bool
55
- session_title: str
56
- timestamp: str
57
- files_processed: Optional[List[str]] = None
58
- artifacts_paths: Optional[List[str]] = None
59
-
60
- main_chat_router_v6 = APIRouter(prefix="/main_chatbot/v6", tags=["main_chatbot_v6"])
61
-
62
- # =====================================
63
- # HELPER FUNCTIONS
64
- # =====================================
65
 
66
  async def make_convo_id(suffix: str = "ss") -> List[str]:
67
- """Generate UUID with suffix in mutable list"""
68
  base_uuid = str(uuid.uuid4())
69
  convo_id = base_uuid[:-2] + suffix
70
  return [convo_id]
71
 
72
- def determine_agent_type(request: ChatRequest) -> bool:
73
- """Determine which agent to use. Returns True for data_analyst."""
74
- if request.metadata and "file_paths" in request.metadata:
75
- file_paths = request.metadata["file_paths"]
76
- if file_paths and len(file_paths) > 0:
77
- return True
78
-
79
- query_lower = request.query.lower()
80
- data_keywords = [
81
- "analyze", "analysis", "dataset", "csv", "data", "statistics",
82
- "visualize", "visualization", "chart", "graph", "plot",
83
- "sql", "query", "select", "table", "column", "row"
84
- ]
85
-
86
- return any(keyword in query_lower for keyword in data_keywords)
87
-
88
- def create_phase_chunk(content: str, phase: str) -> dict:
89
- """Create standardized phase chunk"""
90
- return {
91
- "type": "chat:completion",
92
- "data": {
93
- "delta_content": content,
94
- "phase": phase
95
- }
96
- }
97
 
98
- async def store_conversation_to_db(
99
- user_login_id: str,
100
- session_id: str,
101
- user_query: str,
102
- ai_response: str,
103
- metadata: Dict[str, any],
104
- artifacts: list = None,
105
- suffix: str = "ss",
106
- convo_ids: List[str] = None,
107
- db_api_url: str = "https://mr-mvp-api-dev.dev.ingenspark.com/Db_store_router/conversations"
108
- ) -> Optional[Dict]:
109
- """Store conversation to database API"""
110
  try:
111
- if convo_ids is None or not convo_ids:
112
- convo_ids = await make_convo_id(suffix)
113
-
114
- convo_id = convo_ids[0]
115
- query_id = f"q_{uuid.uuid4()}"
116
-
117
- payload = {
118
- "user_id": user_login_id,
119
- "convo_id": convo_id,
120
- "data": {
121
- "user_query": {
122
- "query_id": query_id,
123
- "text": user_query,
124
- "user_metadata": {
125
- "location": metadata.get("user", {}).get("location", "Unknown"),
126
- "language": metadata.get("user", {}).get("language", "English")
127
- }
128
- },
129
- "response": {
130
- "text": ai_response,
131
- "status": "success",
132
- "response_time": datetime.now().isoformat() + "Z",
133
- "duration": f"{metadata.get('processing_time', 0):.2f}s",
134
- "artifacts": artifacts if artifacts else []
135
- },
136
- "metadata": {
137
- "processing_node": {
138
- "model_version": metadata.get("model_used", "gpt-4o"),
139
- "api_latency_ms": int(metadata.get("processing_time", 0) * 1000),
140
- "search_mode": metadata.get("search_mode", "sql_analysis")
141
- }
142
- }
143
- },
144
- "is_saved": False
145
- }
146
-
147
  async with httpx.AsyncClient(timeout=30.0) as client:
148
- response = await client.post(db_api_url, json=payload, headers={
149
- "accept": "application/json",
150
- "Content-Type": "application/json"
151
- })
152
- response.raise_for_status()
153
- return response.json()
154
-
155
  except Exception as e:
156
- print(f"Error storing conversation to DB: {e}")
157
  return None
158
 
159
- # =====================================
160
- # PHASED STREAMING WITH AGENT REASONING
161
- # =====================================
162
-
163
- async def stream_chat_response_phased(
164
- request: ChatRequest,
165
- session_id: str,
166
- is_new_session: bool,
167
- conversation_history: List,
168
- use_data_analyst: bool,
169
- convo_ids: List[str]
170
- ) -> AsyncGenerator[str, None]:
171
- """
172
- Stream response with phases from agent:
173
- - thinking: Agent's actual reasoning process
174
- - answer: Analysis and narrative
175
- - charts: Visualizations
176
- """
177
- full_response = ""
178
- artifacts = []
179
- user_message_id = str(uuid.uuid4())
180
- ai_message_id = str(uuid.uuid4())
181
- request_id = f"req_{uuid.uuid4().hex[:26].upper()}"
182
- start_time = datetime.now()
183
-
184
- thinking_buffer = ""
185
- answer_buffer = ""
186
-
187
- try:
188
- # ===== INITIAL METADATA =====
189
- initial_metadata = {
190
- "status": "ok",
191
- "api_version": "v1",
192
- "request_id": request_id,
193
- "event": "stream_started",
194
- "data": {
195
- "session_id": session_id,
196
- "user_message_id": user_message_id,
197
- "ai_message_id": ai_message_id,
198
- "is_new_session": is_new_session,
199
- "agent_type": "data_analyst_gpt4o" if use_data_analyst else "multiagent",
200
- "timestamp": datetime.now().timestamp(),
201
- "userLoginId": request.user_login_id
202
- }
203
- }
204
- yield f"data: {json.dumps(initial_metadata)}\n\n"
205
- await asyncio.sleep(0)
206
-
207
- # ===== STREAM FROM AGENT =====
208
- if use_data_analyst:
209
- file_paths = request.metadata.get("file_paths", []) if request.metadata else []
210
-
211
- if not file_paths:
212
- error_msg = "Data analyst requires file_paths in metadata"
213
- yield f"data: {json.dumps({'error': error_msg})}\n\n"
214
- yield "data: [DONE]\n\n"
215
- return
216
-
217
- chart_count = 0
218
-
219
- # Stream from data analyst agent
220
- async for phase, content in stream_analysis(request.query, file_paths):
221
- if phase == "DONE":
222
- break
223
-
224
- elif phase == "ERROR":
225
- error_response = {
226
- "status": "error",
227
- "event": "error_response",
228
- "data": {
229
- "error": {
230
- "code": "STREAM_ERROR",
231
- "message": content,
232
- "details": {}
233
- },
234
- "request_id": request_id,
235
- "timestamp": datetime.now().timestamp()
236
- }
237
- }
238
- yield f"data: {json.dumps(error_response)}\n\n"
239
-
240
- elif phase == "THINKING":
241
- # Agent's reasoning - stream as thinking phase
242
- thinking_buffer += content
243
- phase_chunk = create_phase_chunk(content, "thinking")
244
- yield f"data: {json.dumps(phase_chunk)}\n\n"
245
- await asyncio.sleep(0)
246
-
247
- elif phase == "ANSWER":
248
- # Agent's analysis - stream as answer phase
249
- answer_buffer += content
250
- full_response += content
251
- phase_chunk = create_phase_chunk(content, "answer")
252
- yield f"data: {json.dumps(phase_chunk)}\n\n"
253
- await asyncio.sleep(0)
254
-
255
- elif phase == "CHART":
256
- # Visualization artifact
257
- chart_count += 1
258
- artifacts.append(content)
259
-
260
- chart_event = {
261
- "type": "chart:artifact",
262
- "data": {
263
- "chart_index": chart_count,
264
- "vega_spec": content,
265
- "phase": "charts"
266
- }
267
- }
268
- yield f"data: {json.dumps(chart_event)}\n\n"
269
- await asyncio.sleep(0.1)
270
-
271
- else:
272
- # Multiagent flow (if needed)
273
- async for chunk in stream_multiagent(request.query):
274
- if chunk == "[DONE]":
275
- break
276
- elif chunk.startswith("[ERROR]"):
277
- error_response = {
278
- "status": "error",
279
- "event": "error_response",
280
- "data": {
281
- "error": {
282
- "code": "STREAM_ERROR",
283
- "message": chunk[7:],
284
- "details": {}
285
- },
286
- "request_id": request_id,
287
- "timestamp": datetime.now().timestamp()
288
- }
289
- }
290
- yield f"data: {json.dumps(error_response)}\n\n"
291
- else:
292
- full_response += chunk
293
- phase_chunk = create_phase_chunk(chunk, "answer")
294
- yield f"data: {json.dumps(phase_chunk)}\n\n"
295
- await asyncio.sleep(0)
296
-
297
- # ===== FINAL PAYLOAD =====
298
- processing_time = (datetime.now() - start_time).total_seconds()
299
- session_data = get_session(request.user_login_id, session_id)
300
-
301
- # Combine thinking + answer for full response storage
302
- full_response_for_storage = thinking_buffer + "\n\n" + answer_buffer
303
-
304
- final_payload = {
305
- "status": "ok",
306
- "api_version": "v1",
307
- "request_id": request_id,
308
- "event": "message_response",
309
- "data": {
310
- "type": "assistant_message",
311
- "conversation_id": convo_ids[0],
312
- "session_id": session_id,
313
- "user_message_id": user_message_id,
314
- "ai_message_id": ai_message_id,
315
- "userLoginId": request.user_login_id,
316
- "agent_used": "data_analyst_gpt4o" if use_data_analyst else "multiagent",
317
- "timestamp": datetime.now().timestamp(),
318
- "title": session_data.get("title", "New Chat"),
319
- "total_messages": len(get_message_history(request.user_login_id, session_id)),
320
- "message": {
321
- "id": ai_message_id,
322
- "author": {"role": "assistant", "name": "DataStory AI"},
323
- "metadata": {
324
- "model_used": "gpt-4o",
325
- "processing_time": processing_time,
326
- "thinking_length": len(thinking_buffer),
327
- "answer_length": len(answer_buffer),
328
- "charts_generated": len(artifacts)
329
- },
330
- "content": {
331
- "content_type": "text",
332
- "parts": [full_response_for_storage],
333
- "text": full_response_for_storage
334
- },
335
- "status": "finished_successfully",
336
- "timestamp_": datetime.now().timestamp()
337
- },
338
- "artifacts": artifacts
339
- }
340
- }
341
-
342
- yield f"data: {json.dumps(final_payload)}\n\n"
343
-
344
- # Save to Redis
345
- if full_response_for_storage:
346
- add_message(request.user_login_id, session_id, "assistant", full_response_for_storage)
347
- update_session_title_if_needed(request.user_login_id, session_id, is_new_session)
348
-
349
- # Store to database
350
- suffix = "dd" if use_data_analyst else "ss"
351
- db_result = await store_conversation_to_db(
352
- user_login_id=request.user_login_id,
353
- session_id=session_id,
354
- user_query=request.query,
355
- ai_response=full_response_for_storage,
356
- metadata={
357
- "model_used": "gpt-4o",
358
- "processing_time": processing_time,
359
- "search_mode": "sql_analysis" if use_data_analyst else "hybrid"
360
- },
361
- artifacts=artifacts,
362
- suffix=suffix,
363
- convo_ids=convo_ids
364
- )
365
-
366
- # Completion signal
367
- completion_response = {
368
- "event": "stream_complete",
369
- "data": {
370
- "status": "finished_successfully",
371
- "timestamp": datetime.now().timestamp(),
372
- "db_stored": db_result is not None,
373
- "artifacts_count": len(artifacts),
374
- "phases_completed": ["thinking", "answer", "charts"] if artifacts else ["thinking", "answer"]
375
- }
376
- }
377
- yield f"data: {json.dumps(completion_response)}\n\n"
378
- yield "data: [DONE]\n\n"
379
 
380
- except Exception as e:
381
- error_data = {
382
- "status": "error",
383
- "api_version": "v1",
384
- "request_id": request_id,
385
- "event": "error_response",
386
- "data": {
387
- "error": {
388
- "code": "STREAM_ERROR",
389
- "message": str(e),
390
- "details": {"type": "stream_error"}
391
- },
392
- "timestamp": datetime.now().timestamp()
393
- }
394
- }
395
- yield f"data: {json.dumps(error_data)}\n\n"
396
- yield "data: [DONE]\n\n"
397
 
398
- # =====================================
399
- # ENDPOINTS
400
- # =====================================
401
 
402
  @main_chat_router_v6.get("/")
403
  async def root():
404
  return {
405
- "message": "AI Data Analysis Chatbot API with Agent Reasoning",
406
- "version": "2.2.0",
407
- "model": "gpt-4o",
408
- "features": [
409
- "True Agent Reasoning in Thinking Phase",
410
- "Phased Streaming (thinking → answer → charts)",
411
- "Session Management",
412
- "Vega-Lite Artifacts",
413
- "DB Storage"
414
- ],
415
- "phases": {
416
- "thinking": "Agent's step-by-step reasoning process",
417
- "answer": "Final analysis with specific numbers",
418
- "charts": "Vega-Lite visualizations with actual data"
419
- }
420
  }
421
 
 
422
  @main_chat_router_v6.post("/chat/stream", response_class=StreamingResponse)
423
- async def chat_stream(request: ChatRequest):
424
- """Streaming endpoint with agent reasoning phases"""
 
 
 
 
 
425
  try:
426
  is_new_session = False
427
-
428
- if request.session_id is None:
429
- session_data = create_session(request.user_login_id, request.org_id, request.metadata)
430
  session_id = session_data["session_id"]
431
  is_new_session = True
432
  conversation_history = []
433
  else:
 
434
  try:
435
- session_data = get_session(request.user_login_id, request.session_id)
436
- session_id = request.session_id
437
- conversation_history = get_message_history(request.user_login_id, session_id, limit=5)
438
  except HTTPException as e:
439
  if e.status_code == 404:
440
- session_data = create_session(request.user_login_id, request.org_id, request.metadata)
441
  session_id = session_data["session_id"]
442
  is_new_session = True
443
  conversation_history = []
444
  else:
445
  raise
446
 
447
- add_message(request.user_login_id, session_id, "user", request.query)
448
- use_data_analyst = determine_agent_type(request)
 
 
 
 
 
 
449
  convo_ids = await make_convo_id(suffix="dd" if use_data_analyst else "ss")
450
 
451
- return StreamingResponse(
452
- stream_chat_response_phased(
453
- request=request,
454
- session_id=session_id,
455
- is_new_session=is_new_session,
456
- conversation_history=conversation_history,
457
- use_data_analyst=use_data_analyst,
458
- convo_ids=convo_ids
459
- ),
460
- media_type="text/event-stream",
461
- headers={
462
- "Cache-Control": "no-cache",
463
- "Connection": "keep-alive",
464
- "X-Accel-Buffering": "no",
 
 
 
 
 
 
 
465
  }
466
- )
467
 
 
 
468
  except Exception as e:
469
- raise HTTPException(status_code=500, detail=f"Error in chat stream: {str(e)}")
 
 
1
  """
2
+ main_chat_router_v7.py
3
+ Production-ready FastAPI router with phased SSE streaming.
4
+
5
+ - Uses RobustPhaseDetector from data_agent_v3
6
+ - Streams JSON events with phases: thinking, reasoning, answer, charts, done
7
+ - Handles cancellation and timeouts gracefully
8
  """
9
 
 
 
 
 
10
  import json
 
 
 
11
  import uuid
12
+ import asyncio
13
+ import logging
14
+ from datetime import datetime
15
+ from typing import AsyncGenerator, List, Dict, Optional
16
 
17
+ from fastapi import APIRouter, HTTPException, Request
18
+ from fastapi.responses import StreamingResponse
 
19
 
20
+ # Import the new data agent implementation
21
+ from main_chat_v6.data_agent import *
22
+ from main_chat_v6.multiagent import *
23
+
24
+ # Your session/redis helpers (preserve existing API)
25
+ from Redis.sessions_new import (
26
+ create_session,
27
+ get_session,
28
+ get_message_history,
29
+ add_message,
30
+ update_session_title_if_needed,
31
+ )
32
+
33
+ # DB storage helper
34
+ import httpx
35
  from pydantic import BaseModel
 
 
36
 
37
+ logger = logging.getLogger("main_chat_router_v7")
38
+ logging.basicConfig(level=logging.INFO)
39
+
40
+ main_chat_router_v6 = APIRouter(prefix="/main_chatbot/v7", tags=["main_chatbot_v7"])
41
 
 
 
 
42
 
43
  class Author(BaseModel):
44
  role: str
45
 
46
+
47
  class ChatRequest(BaseModel):
48
  user_login_id: str
49
  session_id: Optional[str] = None
 
55
  create_time: Optional[str] = None
56
  token: Optional[str] = None
57
 
58
 
59
  async def make_convo_id(suffix: str = "ss") -> List[str]:
 
60
  base_uuid = str(uuid.uuid4())
61
  convo_id = base_uuid[:-2] + suffix
62
  return [convo_id]
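For example, with suffix "dd" (used for data-analyst conversations) the helper returns a single-element list whose id ends with the suffix; the UUID below is illustrative:

# >>> await make_convo_id("dd")
# ['3f1c2a9e-7b44-4d1e-9a2b-0c5e6f7a90dd']   # last two hex characters replaced by the suffix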
63
64
 
65
+ async def store_conversation_to_db(payload: dict, db_api_url: str = "https://mr-mvp-api-dev.dev.ingenspark.com/Db_store_router/conversations") -> Optional[Dict]:
66
  try:
67
  async with httpx.AsyncClient(timeout=30.0) as client:
68
+ r = await client.post(db_api_url, json=payload, headers={"accept": "application/json", "Content-Type": "application/json"})
69
+ r.raise_for_status()
70
+ return r.json()
 
 
 
 
71
  except Exception as e:
72
+ logger.exception("Error storing conversation to DB")
73
  return None
74
 
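A hedged call sketch; the payload shape follows the v6 router's structure (user_id, convo_id, data.user_query, data.response), with illustrative values:

async def example_store() -> None:
    convo_id = (await make_convo_id("dd"))[0]
    payload = {
        "user_id": "user-123",                                       # illustrative user id
        "convo_id": convo_id,
        "data": {
            "user_query": {"query_id": f"q_{uuid.uuid4()}", "text": "Top regions by sales"},
            "response": {"text": "Sales grew 18.7% in Q3.", "status": "success", "artifacts": []},
        },
        "is_saved": False,
    }
    result = await store_conversation_to_db(payload)
    logger.info("DB store succeeded: %s", result is not None)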
75
 
76
+ def create_phase_chunk(content: str, phase: str, idx: int = 0) -> dict:
77
+ """Standardized chunk definition for SSE"""
78
+ return {
79
+ "type": "chat:completion",
80
+ "data": {
81
+ "delta_content": content,
82
+ "phase": phase,
83
+ "index": idx,
84
+ "timestamp": datetime.utcnow().isoformat() + "Z",
85
+ },
86
+ }
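Each chunk is framed as one SSE event by the generator below; a minimal sketch of the wire format (timestamp illustrative):

# One SSE event per chunk: a "data: <json>" line terminated by a blank line.
chunk = create_phase_chunk("Examining regional sales...", "thinking", idx=1)
sse_event = f"data: {json.dumps(chunk)}\n\n"
# data: {"type": "chat:completion", "data": {"delta_content": "Examining regional sales...",
#        "phase": "thinking", "index": 1, "timestamp": "2025-01-01T00:00:00Z"}}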
 
 
 
 
 
 
87
 
 
 
 
88
 
89
  @main_chat_router_v6.get("/")
90
  async def root():
91
  return {
92
+ "message": "AI Data Analysis Chatbot API (v7)",
93
+ "version": "3.0.0",
94
+ "notes": "Production-grade phased streaming (thinking → reasoning → answer → charts → ...)",
95
  }
96
 
97
+
98
  @main_chat_router_v6.post("/chat/stream", response_class=StreamingResponse)
99
+ async def chat_stream(request: Request, payload: ChatRequest):
100
+ """
101
+ Streaming endpoint.
102
+ The SSE stream yields JSON lines of the form: data: { ... }\n\n
103
+ Emitted phases: THINKING, REASONING, ANSWER, CHART, ERROR, DONE
104
+ """
105
+ # Create or resume session
106
  try:
107
  is_new_session = False
108
+ if payload.session_id is None:
109
+ session_data = create_session(payload.user_login_id, payload.org_id, payload.metadata)
 
110
  session_id = session_data["session_id"]
111
  is_new_session = True
112
  conversation_history = []
113
  else:
114
+ session_id = payload.session_id
115
  try:
116
+ session_data = get_session(payload.user_login_id, session_id)
117
+ conversation_history = get_message_history(payload.user_login_id, session_id, limit=5)
 
118
  except HTTPException as e:
119
  if e.status_code == 404:
120
+ session_data = create_session(payload.user_login_id, payload.org_id, payload.metadata)
121
  session_id = session_data["session_id"]
122
  is_new_session = True
123
  conversation_history = []
124
  else:
125
  raise
126
 
127
+ # Persist the user's message in session storage immediately
128
+ add_message(payload.user_login_id, session_id, "user", payload.query)
129
+
130
+ # Decide which streaming generator to call
131
+ use_data_analyst = False
132
+ if payload.metadata and payload.metadata.get("file_paths"):
133
+ use_data_analyst = True
134
+
135
  convo_ids = await make_convo_id(suffix="dd" if use_data_analyst else "ss")
136
 
137
+ # The actual generator that will yield SSE chunks
138
+ async def event_generator() -> AsyncGenerator[str, None]:
139
+ request_id = f"req_{uuid.uuid4().hex[:26].upper()}"
140
+ start_time = datetime.utcnow()
141
+ idx = 0
142
+
143
+ # Initial metadata event
144
+ initial_metadata = {
145
+ "status": "ok",
146
+ "api_version": "v1",
147
+ "request_id": request_id,
148
+ "event": "stream_started",
149
+ "data": {
150
+ "session_id": session_id,
151
+ "user_message_id": str(uuid.uuid4()),
152
+ "ai_message_id": str(uuid.uuid4()),
153
+ "is_new_session": is_new_session,
154
+ "agent_type": "data_analyst" if use_data_analyst else "multiagent",
155
+ "timestamp": datetime.utcnow().timestamp(),
156
+ "userLoginId": payload.user_login_id,
157
+ },
158
  }
159
+ yield f"data: {json.dumps(initial_metadata)}\n\n"
160
+
161
+ # Choose the appropriate stream
162
+ try:
163
+ if use_data_analyst:
164
+ # stream_analysis yields (phase_tag, content) tuples
165
+ file_paths = payload.metadata.get("file_paths", [])
166
+ async for phase_tag, content in stream_analysis(payload.query, file_paths):
167
+ # If client disconnected, generator will be cancelled; check here
168
+ if await request.is_disconnected():
169
+ logger.info("Client disconnected, cancelling stream.")
170
+ break
171
+
172
+ idx += 1
173
+ if phase_tag in ("THINKING", "REASONING", "ANSWER"):
174
+ chunk = create_phase_chunk(content, phase_tag.lower(), idx)
175
+ yield f"data: {json.dumps(chunk)}\n\n"
176
+ elif phase_tag == "CHART":
177
+ chart_event = {
178
+ "type": "chart:artifact",
179
+ "data": {
180
+ "chart_index": idx,
181
+ "vega_spec": content,
182
+ "phase": "charts",
183
+ "timestamp": datetime.utcnow().isoformat() + "Z",
184
+ },
185
+ }
186
+ yield f"data: {json.dumps(chart_event)}\n\n"
187
+ elif phase_tag == "ERROR":
188
+ err = {"status": "error", "event": "error_response", "data": {"error": {"message": content}, "request_id": request_id}}
189
+ yield f"data: {json.dumps(err)}\n\n"
190
+ elif phase_tag == "DONE":
191
+ break
192
+
193
+ else:
194
+ # Multiagent streaming: expected to yield raw chunks or control tokens
195
+ async for chunk in stream_multiagent(payload.query):
196
+ if await request.is_disconnected():
197
+ logger.info("Client disconnected, cancelling multiagent stream.")
198
+ break
199
+ idx += 1
200
+ if chunk == "[DONE]":
201
+ break
202
+ elif chunk.startswith("[ERROR]"):
203
+ err = {"status": "error", "event": "error_response", "data": {"error": {"message": chunk[7:]}, "request_id": request_id}}
204
+ yield f"data: {json.dumps(err)}\n\n"
205
+ else:
206
+ chunk_obj = create_phase_chunk(chunk, "answer", idx)
207
+ yield f"data: {json.dumps(chunk_obj)}\n\n"
208
+
209
+ # Wrap up: send the completion notification (persisting the final payload can be added here)
210
+ processing_time = (datetime.utcnow() - start_time).total_seconds()
211
+ # Compose final payload minimally; you can enrich and store as needed
212
+ completion_response = {
213
+ "event": "stream_complete",
214
+ "data": {
215
+ "status": "finished_successfully",
216
+ "timestamp": datetime.utcnow().timestamp(),
217
+ "processing_time": processing_time,
218
+ "convo_id": convo_ids[0],
219
+ },
220
+ }
221
+ yield f"data: {json.dumps(completion_response)}\n\n"
222
+ yield "data: [DONE]\n\n"
223
+
224
+ except asyncio.CancelledError:
225
+ logger.info("Stream cancelled by client or server; performing cleanup.")
226
+ # Do any cleanup: close DB connections, rollback, etc.
227
+ raise
228
+ except Exception as exc:
229
+ logger.exception("Unhandled exception in event_generator")
230
+ error_payload = {"status": "error", "event": "error_response", "data": {"error": {"message": str(exc)}, "request_id": request_id}}
231
+ yield f"data: {json.dumps(error_payload)}\n\n"
232
+ yield "data: [DONE]\n\n"
233
 
234
+ # Return StreamingResponse
235
+ return StreamingResponse(event_generator(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"})
236
  except Exception as e:
237
+ logger.exception("chat_stream top-level error")
238
+ raise HTTPException(status_code=500, detail=f"Error in chat stream: {str(e)}")
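A hedged client-side sketch for consuming the stream with httpx; the host/port and parquet path are assumptions, while the /main_chatbot/v7/chat/stream prefix comes from the router above:

import json
import httpx

async def consume_stream() -> None:
    body = {
        "user_login_id": "user-123",                                    # illustrative user
        "query": "Analyze sales by region",
        "metadata": {"file_paths": ["s3://bucket/sales.parquet"]},      # triggers the data-analyst path
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/main_chatbot/v7/chat/stream", json=body
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":
                    break
                event = json.loads(payload)
                phase = event.get("data", {}).get("phase")
                print(event.get("event") or event.get("type"), "->", phase)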
main_chat_v6/rough.ipynb ADDED
@@ -0,0 +1,974 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "cf14a6eb",
6
+ "metadata": {},
7
+ "source": [
8
+ "main_chart_router"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "9abfde57",
15
+ "metadata": {
16
+ "vscode": {
17
+ "languageId": "plaintext"
18
+ }
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "\"\"\"\n",
23
+ "Main Router with True Agent Reasoning in Thinking Phase\n",
24
+ "Phases: thinking (agent's reasoning), answer (analysis), charts (visualizations)\n",
25
+ "\"\"\"\n",
26
+ "\n",
27
+ "from fastapi import APIRouter, HTTPException\n",
28
+ "from fastapi.responses import StreamingResponse\n",
29
+ "from datetime import datetime\n",
30
+ "from typing import AsyncGenerator, List, Optional, Dict, Tuple\n",
31
+ "import json\n",
32
+ "import asyncio\n",
33
+ "import os\n",
34
+ "import sys\n",
35
+ "import uuid\n",
36
+ "from pathlib import Path\n",
37
+ "import httpx\n",
38
+ "import re\n",
39
+ "\n",
40
+ "PROJECT_ROOT = Path(__file__).resolve().parents[2]\n",
41
+ "if str(PROJECT_ROOT) not in sys.path:\n",
42
+ " sys.path.insert(0, str(PROJECT_ROOT))\n",
43
+ "\n",
44
+ "from Redis.sessions_new import *\n",
45
+ "from main_chat_v6.data_agent import stream_analysis\n",
46
+ "from main_chat_v6.multiagent import stream_multiagent\n",
47
+ "from pydantic import BaseModel\n",
48
+ "from dotenv import load_dotenv\n",
49
+ "from retrieve_secret import *\n",
50
+ "\n",
51
+ "load_dotenv()\n",
52
+ "\n",
53
+ "# =====================================\n",
54
+ "# MODELS\n",
55
+ "# =====================================\n",
56
+ "\n",
57
+ "class Author(BaseModel):\n",
58
+ " role: str\n",
59
+ "\n",
60
+ "class ChatRequest(BaseModel):\n",
61
+ " user_login_id: str\n",
62
+ " session_id: Optional[str] = None\n",
63
+ " query: str\n",
64
+ " org_id: Optional[str] = None\n",
65
+ " metadata: Optional[Dict] = None\n",
66
+ " list_of_files_path: Optional[List[str]] = None\n",
67
+ " author: Optional[Author] = None\n",
68
+ " create_time: Optional[str] = None\n",
69
+ " token: Optional[str] = None\n",
70
+ "\n",
71
+ "class ChatResponse(BaseModel):\n",
72
+ " session_id: str\n",
73
+ " user_message: str\n",
74
+ " assistant_response: Dict\n",
75
+ " is_new_session: bool\n",
76
+ " session_title: str\n",
77
+ " timestamp: str\n",
78
+ " files_processed: Optional[List[str]] = None\n",
79
+ " artifacts_paths: Optional[List[str]] = None\n",
80
+ "\n",
81
+ "main_chat_router_v6 = APIRouter(prefix=\"/main_chatbot/v6\", tags=[\"main_chatbot_v6\"])\n",
82
+ "\n",
83
+ "# =====================================\n",
84
+ "# HELPER FUNCTIONS\n",
85
+ "# =====================================\n",
86
+ "\n",
87
+ "async def make_convo_id(suffix: str = \"ss\") -> List[str]:\n",
88
+ " \"\"\"Generate UUID with suffix in mutable list\"\"\"\n",
89
+ " base_uuid = str(uuid.uuid4())\n",
90
+ " convo_id = base_uuid[:-2] + suffix\n",
91
+ " return [convo_id]\n",
92
+ "\n",
93
+ "def determine_agent_type(request: ChatRequest) -> bool:\n",
94
+ " \"\"\"Determine which agent to use. Returns True for data_analyst.\"\"\"\n",
95
+ " if request.metadata and \"file_paths\" in request.metadata:\n",
96
+ " file_paths = request.metadata[\"file_paths\"]\n",
97
+ " if file_paths and len(file_paths) > 0:\n",
98
+ " return True\n",
99
+ " \n",
100
+ " query_lower = request.query.lower()\n",
101
+ " data_keywords = [\n",
102
+ " \"analyze\", \"analysis\", \"dataset\", \"csv\", \"data\", \"statistics\",\n",
103
+ " \"visualize\", \"visualization\", \"chart\", \"graph\", \"plot\",\n",
104
+ " \"sql\", \"query\", \"select\", \"table\", \"column\", \"row\"\n",
105
+ " ]\n",
106
+ " \n",
107
+ " return any(keyword in query_lower for keyword in data_keywords)\n",
108
+ "\n",
109
+ "def create_phase_chunk(content: str, phase: str) -> dict:\n",
110
+ " \"\"\"Create standardized phase chunk\"\"\"\n",
111
+ " return {\n",
112
+ " \"type\": \"chat:completion\",\n",
113
+ " \"data\": {\n",
114
+ " \"delta_content\": content,\n",
115
+ " \"phase\": phase\n",
116
+ " }\n",
117
+ " }\n",
118
+ "\n",
119
+ "async def store_conversation_to_db(\n",
120
+ " user_login_id: str,\n",
121
+ " session_id: str,\n",
122
+ " user_query: str,\n",
123
+ " ai_response: str,\n",
124
+ " metadata: Dict[str, any],\n",
125
+ " artifacts: list = None,\n",
126
+ " suffix: str = \"ss\",\n",
127
+ " convo_ids: List[str] = None,\n",
128
+ " db_api_url: str = \"https://mr-mvp-api-dev.dev.ingenspark.com/Db_store_router/conversations\"\n",
129
+ ") -> Optional[Dict]:\n",
130
+ " \"\"\"Store conversation to database API\"\"\"\n",
131
+ " try:\n",
132
+ " if convo_ids is None or not convo_ids:\n",
133
+ " convo_ids = await make_convo_id(suffix)\n",
134
+ " \n",
135
+ " convo_id = convo_ids[0]\n",
136
+ " query_id = f\"q_{uuid.uuid4()}\"\n",
137
+ " \n",
138
+ " payload = {\n",
139
+ " \"user_id\": user_login_id,\n",
140
+ " \"convo_id\": convo_id,\n",
141
+ " \"data\": {\n",
142
+ " \"user_query\": {\n",
143
+ " \"query_id\": query_id,\n",
144
+ " \"text\": user_query,\n",
145
+ " \"user_metadata\": {\n",
146
+ " \"location\": metadata.get(\"user\", {}).get(\"location\", \"Unknown\"),\n",
147
+ " \"language\": metadata.get(\"user\", {}).get(\"language\", \"English\")\n",
148
+ " }\n",
149
+ " },\n",
150
+ " \"response\": {\n",
151
+ " \"text\": ai_response,\n",
152
+ " \"status\": \"success\",\n",
153
+ " \"response_time\": datetime.now().isoformat() + \"Z\",\n",
154
+ " \"duration\": f\"{metadata.get('processing_time', 0):.2f}s\",\n",
155
+ " \"artifacts\": artifacts if artifacts else []\n",
156
+ " },\n",
157
+ " \"metadata\": {\n",
158
+ " \"processing_node\": {\n",
159
+ " \"model_version\": metadata.get(\"model_used\", \"gpt-4o\"),\n",
160
+ " \"api_latency_ms\": int(metadata.get(\"processing_time\", 0) * 1000),\n",
161
+ " \"search_mode\": metadata.get(\"search_mode\", \"sql_analysis\")\n",
162
+ " }\n",
163
+ " }\n",
164
+ " },\n",
165
+ " \"is_saved\": False\n",
166
+ " }\n",
167
+ " \n",
168
+ " async with httpx.AsyncClient(timeout=30.0) as client:\n",
169
+ " response = await client.post(db_api_url, json=payload, headers={\n",
170
+ " \"accept\": \"application/json\",\n",
171
+ " \"Content-Type\": \"application/json\"\n",
172
+ " })\n",
173
+ " response.raise_for_status()\n",
174
+ " return response.json()\n",
175
+ " \n",
176
+ " except Exception as e:\n",
177
+ " print(f\"Error storing conversation to DB: {e}\")\n",
178
+ " return None\n",
179
+ "\n",
180
+ "# =====================================\n",
181
+ "# PHASED STREAMING WITH AGENT REASONING\n",
182
+ "# =====================================\n",
183
+ "\n",
184
+ "async def stream_chat_response_phased(\n",
185
+ " request: ChatRequest,\n",
186
+ " session_id: str,\n",
187
+ " is_new_session: bool,\n",
188
+ " conversation_history: List,\n",
189
+ " use_data_analyst: bool,\n",
190
+ " convo_ids: List[str]\n",
191
+ ") -> AsyncGenerator[str, None]:\n",
192
+ " \"\"\"\n",
193
+ " Stream response with phases from agent:\n",
194
+ " - thinking: Agent's actual reasoning process\n",
195
+ " - answer: Analysis and narrative\n",
196
+ " - charts: Visualizations\n",
197
+ " \"\"\"\n",
198
+ " full_response = \"\"\n",
199
+ " artifacts = []\n",
200
+ " user_message_id = str(uuid.uuid4())\n",
201
+ " ai_message_id = str(uuid.uuid4())\n",
202
+ " request_id = f\"req_{uuid.uuid4().hex[:26].upper()}\"\n",
203
+ " start_time = datetime.now()\n",
204
+ " \n",
205
+ " thinking_buffer = \"\"\n",
206
+ " answer_buffer = \"\"\n",
207
+ " \n",
208
+ " try:\n",
209
+ " # ===== INITIAL METADATA =====\n",
210
+ " initial_metadata = {\n",
211
+ " \"status\": \"ok\",\n",
212
+ " \"api_version\": \"v1\",\n",
213
+ " \"request_id\": request_id,\n",
214
+ " \"event\": \"stream_started\",\n",
215
+ " \"data\": {\n",
216
+ " \"session_id\": session_id,\n",
217
+ " \"user_message_id\": user_message_id,\n",
218
+ " \"ai_message_id\": ai_message_id,\n",
219
+ " \"is_new_session\": is_new_session,\n",
220
+ " \"agent_type\": \"data_analyst_gpt4o\" if use_data_analyst else \"multiagent\",\n",
221
+ " \"timestamp\": datetime.now().timestamp(),\n",
222
+ " \"userLoginId\": request.user_login_id\n",
223
+ " }\n",
224
+ " }\n",
225
+ " yield f\"data: {json.dumps(initial_metadata)}\\n\\n\"\n",
226
+ " await asyncio.sleep(0)\n",
227
+ "\n",
228
+ " # ===== STREAM FROM AGENT =====\n",
229
+ " if use_data_analyst:\n",
230
+ " file_paths = request.metadata.get(\"file_paths\", []) if request.metadata else []\n",
231
+ " \n",
232
+ " if not file_paths:\n",
233
+ " error_msg = \"Data analyst requires file_paths in metadata\"\n",
234
+ " yield f\"data: {json.dumps({'error': error_msg})}\\n\\n\"\n",
235
+ " yield \"data: [DONE]\\n\\n\"\n",
236
+ " return\n",
237
+ " \n",
238
+ " chart_count = 0\n",
239
+ " \n",
240
+ " # Stream from data analyst agent\n",
241
+ " async for phase, content in stream_analysis(request.query, file_paths):\n",
242
+ " if phase == \"DONE\":\n",
243
+ " break\n",
244
+ " \n",
245
+ " elif phase == \"ERROR\":\n",
246
+ " error_response = {\n",
247
+ " \"status\": \"error\",\n",
248
+ " \"event\": \"error_response\",\n",
249
+ " \"data\": {\n",
250
+ " \"error\": {\n",
251
+ " \"code\": \"STREAM_ERROR\",\n",
252
+ " \"message\": content,\n",
253
+ " \"details\": {}\n",
254
+ " },\n",
255
+ " \"request_id\": request_id,\n",
256
+ " \"timestamp\": datetime.now().timestamp()\n",
257
+ " }\n",
258
+ " }\n",
259
+ " yield f\"data: {json.dumps(error_response)}\\n\\n\"\n",
260
+ " \n",
261
+ " elif phase == \"THINKING\":\n",
262
+ " # Agent's reasoning - stream as thinking phase\n",
263
+ " thinking_buffer += content\n",
264
+ " phase_chunk = create_phase_chunk(content, \"thinking\")\n",
265
+ " yield f\"data: {json.dumps(phase_chunk)}\\n\\n\"\n",
266
+ " await asyncio.sleep(0)\n",
267
+ " \n",
268
+ " elif phase == \"ANSWER\":\n",
269
+ " # Agent's analysis - stream as answer phase\n",
270
+ " answer_buffer += content\n",
271
+ " full_response += content\n",
272
+ " phase_chunk = create_phase_chunk(content, \"answer\")\n",
273
+ " yield f\"data: {json.dumps(phase_chunk)}\\n\\n\"\n",
274
+ " await asyncio.sleep(0)\n",
275
+ " \n",
276
+ " elif phase == \"CHART\":\n",
277
+ " # Visualization artifact\n",
278
+ " chart_count += 1\n",
279
+ " artifacts.append(content)\n",
280
+ " \n",
281
+ " chart_event = {\n",
282
+ " \"type\": \"chart:artifact\",\n",
283
+ " \"data\": {\n",
284
+ " \"chart_index\": chart_count,\n",
285
+ " \"vega_spec\": content,\n",
286
+ " \"phase\": \"charts\"\n",
287
+ " }\n",
288
+ " }\n",
289
+ " yield f\"data: {json.dumps(chart_event)}\\n\\n\"\n",
290
+ " await asyncio.sleep(0.1)\n",
291
+ " \n",
292
+ " else:\n",
293
+ " # Multiagent flow (if needed)\n",
294
+ " async for chunk in stream_multiagent(request.query):\n",
295
+ " if chunk == \"[DONE]\":\n",
296
+ " break\n",
297
+ " elif chunk.startswith(\"[ERROR]\"):\n",
298
+ " error_response = {\n",
299
+ " \"status\": \"error\",\n",
300
+ " \"event\": \"error_response\",\n",
301
+ " \"data\": {\n",
302
+ " \"error\": {\n",
303
+ " \"code\": \"STREAM_ERROR\",\n",
304
+ " \"message\": chunk[7:],\n",
305
+ " \"details\": {}\n",
306
+ " },\n",
307
+ " \"request_id\": request_id,\n",
308
+ " \"timestamp\": datetime.now().timestamp()\n",
309
+ " }\n",
310
+ " }\n",
311
+ " yield f\"data: {json.dumps(error_response)}\\n\\n\"\n",
312
+ " else:\n",
313
+ " full_response += chunk\n",
314
+ " phase_chunk = create_phase_chunk(chunk, \"answer\")\n",
315
+ " yield f\"data: {json.dumps(phase_chunk)}\\n\\n\"\n",
316
+ " await asyncio.sleep(0)\n",
317
+ "\n",
318
+ " # ===== FINAL PAYLOAD =====\n",
319
+ " processing_time = (datetime.now() - start_time).total_seconds()\n",
320
+ " session_data = get_session(request.user_login_id, session_id)\n",
321
+ " \n",
322
+ " # Combine thinking + answer for full response storage\n",
323
+ " full_response_for_storage = thinking_buffer + \"\\n\\n\" + answer_buffer\n",
324
+ " \n",
325
+ " final_payload = {\n",
326
+ " \"status\": \"ok\",\n",
327
+ " \"api_version\": \"v1\",\n",
328
+ " \"request_id\": request_id,\n",
329
+ " \"event\": \"message_response\",\n",
330
+ " \"data\": {\n",
331
+ " \"type\": \"assistant_message\",\n",
332
+ " \"conversation_id\": convo_ids[0],\n",
333
+ " \"session_id\": session_id,\n",
334
+ " \"user_message_id\": user_message_id,\n",
335
+ " \"ai_message_id\": ai_message_id,\n",
336
+ " \"userLoginId\": request.user_login_id,\n",
337
+ " \"agent_used\": \"data_analyst_gpt4o\" if use_data_analyst else \"multiagent\",\n",
338
+ " \"timestamp\": datetime.now().timestamp(),\n",
339
+ " \"title\": session_data.get(\"title\", \"New Chat\"),\n",
340
+ " \"total_messages\": len(get_message_history(request.user_login_id, session_id)),\n",
341
+ " \"message\": {\n",
342
+ " \"id\": ai_message_id,\n",
343
+ " \"author\": {\"role\": \"assistant\", \"name\": \"DataStory AI\"},\n",
344
+ " \"metadata\": {\n",
345
+ " \"model_used\": \"gpt-4o\",\n",
346
+ " \"processing_time\": processing_time,\n",
347
+ " \"thinking_length\": len(thinking_buffer),\n",
348
+ " \"answer_length\": len(answer_buffer),\n",
349
+ " \"charts_generated\": len(artifacts)\n",
350
+ " },\n",
351
+ " \"content\": {\n",
352
+ " \"content_type\": \"text\",\n",
353
+ " \"parts\": [full_response_for_storage],\n",
354
+ " \"text\": full_response_for_storage\n",
355
+ " },\n",
356
+ " \"status\": \"finished_successfully\",\n",
357
+ " \"timestamp_\": datetime.now().timestamp()\n",
358
+ " },\n",
359
+ " \"artifacts\": artifacts\n",
360
+ " }\n",
361
+ " }\n",
362
+ " \n",
363
+ " yield f\"data: {json.dumps(final_payload)}\\n\\n\"\n",
364
+ " \n",
365
+ " # Save to Redis\n",
366
+ " if full_response_for_storage:\n",
367
+ " add_message(request.user_login_id, session_id, \"assistant\", full_response_for_storage)\n",
368
+ " update_session_title_if_needed(request.user_login_id, session_id, is_new_session)\n",
369
+ " \n",
370
+ " # Store to database\n",
371
+ " suffix = \"dd\" if use_data_analyst else \"ss\"\n",
372
+ " db_result = await store_conversation_to_db(\n",
373
+ " user_login_id=request.user_login_id,\n",
374
+ " session_id=session_id,\n",
375
+ " user_query=request.query,\n",
376
+ " ai_response=full_response_for_storage,\n",
377
+ " metadata={\n",
378
+ " \"model_used\": \"gpt-4o\",\n",
379
+ " \"processing_time\": processing_time,\n",
380
+ " \"search_mode\": \"sql_analysis\" if use_data_analyst else \"hybrid\"\n",
381
+ " },\n",
382
+ " artifacts=artifacts,\n",
383
+ " suffix=suffix,\n",
384
+ " convo_ids=convo_ids\n",
385
+ " )\n",
386
+ " \n",
387
+ " # Completion signal\n",
388
+ " completion_response = {\n",
389
+ " \"event\": \"stream_complete\",\n",
390
+ " \"data\": {\n",
391
+ " \"status\": \"finished_successfully\",\n",
392
+ " \"timestamp\": datetime.now().timestamp(),\n",
393
+ " \"db_stored\": db_result is not None,\n",
394
+ " \"artifacts_count\": len(artifacts),\n",
395
+ " \"phases_completed\": [\"thinking\", \"answer\", \"charts\"] if artifacts else [\"thinking\", \"answer\"]\n",
396
+ " }\n",
397
+ " }\n",
398
+ " yield f\"data: {json.dumps(completion_response)}\\n\\n\"\n",
399
+ " yield \"data: [DONE]\\n\\n\"\n",
400
+ "\n",
401
+ " except Exception as e:\n",
402
+ " error_data = {\n",
403
+ " \"status\": \"error\",\n",
404
+ " \"api_version\": \"v1\",\n",
405
+ " \"request_id\": request_id,\n",
406
+ " \"event\": \"error_response\",\n",
407
+ " \"data\": {\n",
408
+ " \"error\": {\n",
409
+ " \"code\": \"STREAM_ERROR\",\n",
410
+ " \"message\": str(e),\n",
411
+ " \"details\": {\"type\": \"stream_error\"}\n",
412
+ " },\n",
413
+ " \"timestamp\": datetime.now().timestamp()\n",
414
+ " }\n",
415
+ " }\n",
416
+ " yield f\"data: {json.dumps(error_data)}\\n\\n\"\n",
417
+ " yield \"data: [DONE]\\n\\n\"\n",
418
+ "\n",
419
+ "# =====================================\n",
420
+ "# ENDPOINTS\n",
421
+ "# =====================================\n",
422
+ "\n",
423
+ "@main_chat_router_v6.get(\"/\")\n",
424
+ "async def root():\n",
425
+ " return {\n",
426
+ " \"message\": \"AI Data Analysis Chatbot API with Agent Reasoning\",\n",
427
+ " \"version\": \"2.2.0\",\n",
428
+ " \"model\": \"gpt-4o\",\n",
429
+ " \"features\": [\n",
430
+ " \"True Agent Reasoning in Thinking Phase\",\n",
431
+ " \"Phased Streaming (thinking → answer → charts)\",\n",
432
+ " \"Session Management\",\n",
433
+ " \"Vega-Lite Artifacts\",\n",
434
+ " \"DB Storage\"\n",
435
+ " ],\n",
436
+ " \"phases\": {\n",
437
+ " \"thinking\": \"Agent's step-by-step reasoning process\",\n",
438
+ " \"answer\": \"Final analysis with specific numbers\",\n",
439
+ " \"charts\": \"Vega-Lite visualizations with actual data\"\n",
440
+ " }\n",
441
+ " }\n",
442
+ "\n",
443
+ "@main_chat_router_v6.post(\"/chat/stream\", response_class=StreamingResponse)\n",
444
+ "async def chat_stream(request: ChatRequest):\n",
445
+ " \"\"\"Streaming endpoint with agent reasoning phases\"\"\"\n",
446
+ " try:\n",
447
+ " is_new_session = False\n",
448
+ " \n",
449
+ " if request.session_id is None:\n",
450
+ " session_data = create_session(request.user_login_id, request.org_id, request.metadata)\n",
451
+ " session_id = session_data[\"session_id\"]\n",
452
+ " is_new_session = True\n",
453
+ " conversation_history = []\n",
454
+ " else:\n",
455
+ " try:\n",
456
+ " session_data = get_session(request.user_login_id, request.session_id)\n",
457
+ " session_id = request.session_id\n",
458
+ " conversation_history = get_message_history(request.user_login_id, session_id, limit=5)\n",
459
+ " except HTTPException as e:\n",
460
+ " if e.status_code == 404:\n",
461
+ " session_data = create_session(request.user_login_id, request.org_id, request.metadata)\n",
462
+ " session_id = session_data[\"session_id\"]\n",
463
+ " is_new_session = True\n",
464
+ " conversation_history = []\n",
465
+ " else:\n",
466
+ " raise\n",
467
+ "\n",
468
+ " add_message(request.user_login_id, session_id, \"user\", request.query)\n",
469
+ " use_data_analyst = determine_agent_type(request)\n",
470
+ " convo_ids = await make_convo_id(suffix=\"dd\" if use_data_analyst else \"ss\")\n",
471
+ "\n",
472
+ " return StreamingResponse(\n",
473
+ " stream_chat_response_phased(\n",
474
+ " request=request,\n",
475
+ " session_id=session_id,\n",
476
+ " is_new_session=is_new_session,\n",
477
+ " conversation_history=conversation_history,\n",
478
+ " use_data_analyst=use_data_analyst,\n",
479
+ " convo_ids=convo_ids\n",
480
+ " ),\n",
481
+ " media_type=\"text/event-stream\",\n",
482
+ " headers={\n",
483
+ " \"Cache-Control\": \"no-cache\",\n",
484
+ " \"Connection\": \"keep-alive\",\n",
485
+ " \"X-Accel-Buffering\": \"no\",\n",
486
+ " }\n",
487
+ " )\n",
488
+ "\n",
489
+ " except Exception as e:\n",
490
+ " raise HTTPException(status_code=500, detail=f\"Error in chat stream: {str(e)}\")"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "markdown",
495
+ "id": "e3406210",
496
+ "metadata": {},
497
+ "source": [
498
+ "data agent"
499
+ ]
500
+ },
501
+ {
502
+ "cell_type": "code",
503
+ "execution_count": null,
504
+ "id": "809c9222",
505
+ "metadata": {
506
+ "vscode": {
507
+ "languageId": "plaintext"
508
+ }
509
+ },
510
+ "outputs": [],
511
+ "source": [
512
+ "\"\"\"\n",
513
+ "Improved Data Agent with Real-time Phased Streaming\n",
514
+ "Phases: thinking (agent reasoning), answer (final response), charts (visualizations)\n",
515
+ "\"\"\"\n",
516
+ "\n",
517
+ "import asyncio\n",
518
+ "import json\n",
519
+ "import re\n",
520
+ "from typing import AsyncGenerator, List, Tuple\n",
521
+ "import duckdb\n",
522
+ "import pandas as pd\n",
523
+ "from openai.types.responses import ResponseTextDeltaEvent\n",
524
+ "from agents import Agent, Runner, function_tool\n",
525
+ "from s3.read_files import read_parquet_from_s3\n",
526
+ "\n",
527
+ "# =====================================\n",
528
+ "# CONFIGURATION\n",
529
+ "# =====================================\n",
530
+ "\n",
531
+ "ALLOWED_SQL_KEYWORDS = (\"select\", \"with\")\n",
532
+ "MAX_RESULT_ROWS = 10000\n",
533
+ "\n",
534
+ "_registered_datasets = {}\n",
535
+ "last_sql_query: str = \"\"\n",
536
+ "\n",
537
+ "# =====================================\n",
538
+ "# DATASET LOADING\n",
539
+ "# =====================================\n",
540
+ "\n",
541
+ "def generate_dataset_context(df: pd.DataFrame, table_name: str) -> str:\n",
542
+ " \"\"\"Generate comprehensive dataset context\"\"\"\n",
543
+ " context_parts = []\n",
544
+ " context_parts.append(f\"Table: {table_name}\")\n",
545
+ " context_parts.append(f\"Rows: {len(df)}, Columns: {len(df.columns)}\")\n",
546
+ " context_parts.append(f\"Column Names: {', '.join(df.columns.tolist())}\")\n",
547
+ " context_parts.append(\"\\nData Types:\")\n",
548
+ " for col, dtype in df.dtypes.items():\n",
549
+ " context_parts.append(f\" - {col}: {dtype}\")\n",
550
+ " \n",
551
+ " context_parts.append(\"\\nSample Data (first 3 rows):\")\n",
552
+ " sample_str = df.head(3).to_string(index=False, max_cols=10)\n",
553
+ " context_parts.append(sample_str)\n",
554
+ " \n",
555
+ " numeric_cols = df.select_dtypes(include=['number']).columns.tolist()\n",
556
+ " if numeric_cols:\n",
557
+ " context_parts.append(\"\\nNumeric Column Statistics:\")\n",
558
+ " desc = df[numeric_cols].describe().transpose()\n",
559
+ " for col in numeric_cols[:5]:\n",
560
+ " if col in desc.index:\n",
561
+ " context_parts.append(\n",
562
+ " f\" - {col}: mean={desc.loc[col, 'mean']:.2f}, \"\n",
563
+ " f\"min={desc.loc[col, 'min']:.2f}, max={desc.loc[col, 'max']:.2f}\"\n",
564
+ " )\n",
565
+ " \n",
566
+ " return \"\\n\".join(context_parts)\n",
567
+ "\n",
568
+ "def load_and_register_datasets(file_paths: List[str]) -> str:\n",
569
+ " \"\"\"Load datasets and register them for SQL queries\"\"\"\n",
570
+ " global _registered_datasets\n",
571
+ " _registered_datasets.clear()\n",
572
+ " context_parts = [\"Available datasets for SQL queries:\\n\"]\n",
573
+ " \n",
574
+ " for idx, file_path in enumerate(file_paths, start=1):\n",
575
+ " table_name = f\"CSV{idx}\"\n",
576
+ " try:\n",
577
+ " print(f\"Loading {file_path} as {table_name}\")\n",
578
+ " df = read_parquet_from_s3(file_path)\n",
579
+ " _registered_datasets[table_name] = df\n",
580
+ " dataset_context = generate_dataset_context(df, table_name)\n",
581
+ " context_parts.append(f\"\\n{dataset_context}\\n\")\n",
582
+ " context_parts.append(\"=\"*70)\n",
583
+ " print(f\"Loaded {table_name}: {len(df)} rows, {len(df.columns)} columns\")\n",
584
+ " except Exception as e:\n",
585
+ " print(f\"Error loading {file_path}: {str(e)}\")\n",
586
+ " raise\n",
587
+ " \n",
588
+ " return \"\\n\".join(context_parts)\n",
589
+ "\n",
590
+ "# =====================================\n",
591
+ "# SQL TOOL\n",
592
+ "# =====================================\n",
593
+ "\n",
594
+ "@function_tool\n",
595
+ "async def sql_tool(sql_query: str, user_query: str = \"\") -> dict:\n",
596
+ " \"\"\"Execute SQL query and return results\"\"\"\n",
597
+ " global _registered_datasets, last_sql_query\n",
598
+ " last_sql_query = sql_query\n",
599
+ "\n",
600
+ " print(f\"\\n{'='*70}\")\n",
601
+ " print(f\"Executing SQL:\\n{sql_query}\")\n",
602
+ " print(f\"{'='*70}\\n\")\n",
603
+ "\n",
604
+ " normalized_query = sql_query.strip().lower()\n",
605
+ " if not normalized_query.startswith(ALLOWED_SQL_KEYWORDS):\n",
606
+ " raise ValueError(\n",
607
+ " f\"Only {', '.join(k.upper() for k in ALLOWED_SQL_KEYWORDS)} queries allowed.\"\n",
608
+ " )\n",
609
+ "\n",
610
+ " if not _registered_datasets:\n",
611
+ " raise ValueError(\"No datasets loaded.\")\n",
612
+ "\n",
613
+ " conn = duckdb.connect(\":memory:\")\n",
614
+ " try:\n",
615
+ " for name, df in _registered_datasets.items():\n",
616
+ " conn.register(name, df)\n",
617
+ " result_df = conn.execute(sql_query).df()\n",
618
+ " except Exception as e:\n",
619
+ " conn.close()\n",
620
+ " raise ValueError(f\"SQL execution failed: {str(e)}\")\n",
621
+ " finally:\n",
622
+ " conn.close()\n",
623
+ "\n",
624
+ " if len(result_df) > MAX_RESULT_ROWS:\n",
625
+ " result_df = result_df.head(MAX_RESULT_ROWS)\n",
626
+ "\n",
627
+ " summary_lines = [\n",
628
+ " f\"Returned {len(result_df)} rows, {len(result_df.columns)} columns\"\n",
629
+ " ]\n",
630
+ " if not result_df.empty:\n",
631
+ " summary_lines.append(\"\\nPreview (first 10 rows):\")\n",
632
+ " summary_lines.append(result_df.head(10).to_string(index=False, max_rows=10))\n",
633
+ " \n",
634
+ " result_summary = \"\\n\".join(summary_lines)\n",
635
+ "\n",
636
+ " print(f\"SQL executed successfully\\n\")\n",
637
+ " return {\n",
638
+ " \"sql_query\": sql_query,\n",
639
+ " \"result_summary\": result_summary,\n",
640
+ " \"columns\": result_df.columns.tolist(),\n",
641
+ " \"row_count\": len(result_df)\n",
642
+ " }\n",
643
+ "\n",
644
+ "# =====================================\n",
645
+ "# DATA ANALYST AGENT WITH STRUCTURED OUTPUT\n",
646
+ "# =====================================\n",
647
+ "\n",
648
+ "def create_data_analyst_agent(dataset_context: str) -> Agent:\n",
649
+ " \"\"\"Create data analyst agent with explicit phase instructions\"\"\"\n",
650
+ " return Agent(\n",
651
+ " name=\"DataStory_AI\",\n",
652
+ " model=\"gpt-4o\",\n",
653
+ " tools=[sql_tool],\n",
654
+ " instructions=f\"\"\"You are **DataStory AI**, an elite data analyst who combines deep analytical reasoning with compelling storytelling.\n",
655
+ "\n",
656
+ "DATASET CONTEXT:\n",
657
+ "{dataset_context}\n",
658
+ "\n",
659
+ "===== CRITICAL: STRUCTURED RESPONSE FORMAT =====\n",
660
+ "\n",
661
+ "You MUST structure your response in THREE distinct phases using XML tags:\n",
662
+ "\n",
663
+ "**PHASE 1: THINKING (Your Reasoning Process)**\n",
664
+ "<thinking>\n",
665
+ "- Explain your analytical approach step-by-step\n",
666
+ "- Describe what SQL queries you'll run and WHY\n",
667
+ "- Outline the insights you're looking for\n",
668
+ "- Discuss any data quality considerations\n",
669
+ "- Map out your visualization strategy\n",
670
+ "Example:\n",
671
+ "\"To answer this question, I need to:\n",
672
+ "1. First examine the distribution of sales across regions by querying CSV1\n",
673
+ "2. Calculate year-over-year growth rates to identify trends\n",
674
+ "3. Identify top performers and outliers\n",
675
+ "4. Create visualizations showing: (a) regional comparison, (b) time series trend, (c) top 10 performers\"\n",
676
+ "</thinking>\n",
677
+ "\n",
678
+ "**PHASE 2: ANSWER (Your Analysis & Narrative)**\n",
679
+ "<answer>\n",
680
+ "Write your complete data story here:\n",
681
+ "- Present key findings with specific numbers\n",
682
+ "- Provide context and business implications \n",
683
+ "- Explain patterns, trends, and anomalies\n",
684
+ "- Use clear section headers\n",
685
+ "- Be specific: \"Sales increased by 23.5% in Q3\" not \"Sales increased significantly\"\n",
686
+ "</answer>\n",
687
+ "\n",
688
+ "**PHASE 3: CHARTS (Visualizations)**\n",
689
+ "For each visualization, use this EXACT format:\n",
690
+ "\n",
691
+ "<datastory_chart>\n",
692
+ "<chart_title>Descriptive Title</chart_title>\n",
693
+ "<chart_description>What this chart shows with specific numbers</chart_description>\n",
694
+ "<vega_spec>\n",
695
+ "{{\n",
696
+ " \"$schema\": \"https://vega.github.io/schema/vega-lite/v5.json\",\n",
697
+ " \"title\": {{\"text\": \"Chart Title\", \"fontSize\": 18, \"fontWeight\": \"bold\"}},\n",
698
+ " \"width\": 700,\n",
699
+ " \"height\": 400,\n",
700
+ " \"data\": {{\"values\": [\n",
701
+ " {{\"category\": \"A\", \"value\": 100}},\n",
702
+ " {{\"category\": \"B\", \"value\": 150}}\n",
703
+ " ]}},\n",
704
+ " \"mark\": {{\"type\": \"bar\", \"tooltip\": true, \"cornerRadius\": 4}},\n",
705
+ " \"encoding\": {{\n",
706
+ " \"x\": {{\"field\": \"category\", \"type\": \"nominal\", \"axis\": {{\"labelAngle\": 0}}}},\n",
707
+ " \"y\": {{\"field\": \"value\", \"type\": \"quantitative\", \"axis\": {{\"title\": \"Value\"}}}}\n",
708
+ " }},\n",
709
+ " \"config\": {{\n",
710
+ " \"axis\": {{\"labelFontSize\": 12, \"titleFontSize\": 14}},\n",
711
+ " \"legend\": {{\"labelFontSize\": 12, \"titleFontSize\": 14}}\n",
712
+ " }}\n",
713
+ "}}\n",
714
+ "</vega_spec>\n",
715
+ "</datastory_chart>\n",
716
+ "\n",
717
+ "===== WORKFLOW =====\n",
718
+ "1. **START WITH THINKING**: Always begin with <thinking> tags explaining your analytical plan\n",
719
+ "2. **Execute SQL queries**: Use sql_tool(sql_query=\"...\", user_query=\"...\")\n",
720
+ "3. **WRITE ANSWER**: Present insights in <answer> tags with specific numbers\n",
721
+ "4. **CREATE CHARTS**: Generate 2-3 visualizations wrapped in <datastory_chart> tags\n",
722
+ "\n",
723
+ "===== SQL QUERY RULES =====\n",
724
+ "- Table names: CSV1, CSV2, CSV3, etc. (from dataset context above)\n",
725
+ "- Only SELECT and WITH statements allowed\n",
726
+ "- Always include user_query parameter in sql_tool calls\n",
727
+ "- Use meaningful column aliases\n",
728
+ "- Limit results appropriately for visualizations\n",
729
+ "\n",
730
+ "===== VISUALIZATION BEST PRACTICES =====\n",
731
+ "- Include ACTUAL data in \"values\" array (not placeholders)\n",
732
+ "- Use appropriate chart types: bar, line, point, area, arc (pie)\n",
733
+ "- Add tooltips: \"mark\": {{\"type\": \"bar\", \"tooltip\": true}}\n",
734
+ "- Use clear axis labels and titles\n",
735
+ "- Color schemes: use \"color\" encoding for categorical data\n",
736
+ "- For time series: parse dates with \"timeUnit\"\n",
737
+ "\n",
738
+ "===== QUALITY STANDARDS =====\n",
739
+ "✓ Be specific with numbers: \"Revenue increased by $2.3M (18.7%)\" \n",
740
+ "✓ Provide context: Compare to benchmarks, previous periods, or goals\n",
741
+ "✓ Explain WHY patterns exist, not just WHAT they are\n",
742
+ "✓ Structure with clear headers: ## Key Finding 1, ## Key Finding 2\n",
743
+ "✓ Generate 2-3 complementary visualizations\n",
744
+ "✓ End with actionable recommendations\n",
745
+ "\n",
746
+ "===== EXAMPLE RESPONSE STRUCTURE =====\n",
747
+ "\n",
748
+ "<thinking>\n",
749
+ "Let me analyze the sales data systematically:\n",
750
+ "\n",
751
+ "1. **Data Exploration**: I'll query CSV1 to understand sales distribution across regions and time periods\n",
752
+ "2. **Trend Analysis**: Calculate monthly growth rates to identify seasonal patterns \n",
753
+ "3. **Performance Ranking**: Identify top 10 products by revenue\n",
754
+ "4. **Visualization Plan**:\n",
755
+ " - Chart 1: Regional sales comparison (bar chart)\n",
756
+ " - Chart 2: Monthly trend with growth rate (line chart)\n",
757
+ " - Chart 3: Top 10 products (horizontal bar)\n",
758
+ "\n",
759
+ "SQL Strategy: First, aggregate sales by region, then by month, finally rank products.\n",
760
+ "</thinking>\n",
761
+ "\n",
762
+ "<answer>\n",
763
+ "## Executive Summary\n",
764
+ "Analysis of 50,000 transactions reveals strong regional disparities and seasonal trends...\n",
765
+ "\n",
766
+ "## Regional Performance\n",
767
+ "The West region dominates with $4.2M in sales (42% of total), followed by East at $3.1M (31%)...\n",
768
+ "\n",
769
+ "[Continue with detailed findings, specific numbers, and business implications]\n",
770
+ "</answer>\n",
771
+ "\n",
772
+ "<datastory_chart>\n",
773
+ "<chart_title>Sales by Region - West Region Leads with 42% Share</chart_title>\n",
774
+ "<chart_description>West region generated $4.2M, 35% ahead of second-place East region ($3.1M)</chart_description>\n",
775
+ "<vega_spec>\n",
776
+ "{{\n",
777
+ " \"$schema\": \"https://vega.github.io/schema/vega-lite/v5.json\",\n",
778
+ " \"title\": {{\"text\": \"Sales by Region ($ Millions)\", \"fontSize\": 18}},\n",
779
+ " \"width\": 700,\n",
780
+ " \"height\": 400,\n",
781
+ " \"data\": {{\"values\": [\n",
782
+ " {{\"region\": \"West\", \"sales\": 4.2}},\n",
783
+ " {{\"region\": \"East\", \"sales\": 3.1}},\n",
784
+ " {{\"region\": \"North\", \"sales\": 1.8}},\n",
785
+ " {{\"region\": \"South\", \"sales\": 1.5}}\n",
786
+ " ]}},\n",
787
+ " \"mark\": {{\"type\": \"bar\", \"tooltip\": true, \"color\": \"#4C78A8\"}},\n",
788
+ " \"encoding\": {{\n",
789
+ " \"x\": {{\"field\": \"region\", \"type\": \"nominal\", \"sort\": \"-y\"}},\n",
790
+ " \"y\": {{\"field\": \"sales\", \"type\": \"quantitative\", \"title\": \"Sales ($M)\"}}\n",
791
+ " }}\n",
792
+ "}}\n",
793
+ "</vega_spec>\n",
794
+ "</datastory_chart>\n",
795
+ "\n",
796
+ "REMEMBER: Always use the three-phase structure with XML tags. Your reasoning goes in <thinking>, analysis in <answer>, and charts in <datastory_chart> blocks.\n",
797
+ "\"\"\"\n",
798
+ " )\n",
799
+ "\n",
800
+ "# =====================================\n",
801
+ "# PHASE-AWARE CHART BUFFER\n",
802
+ "# =====================================\n",
803
+ "\n",
804
+ "class PhaseAwareBuffer:\n",
805
+ " \"\"\"Buffer that separates content into thinking, answer, and chart phases\"\"\"\n",
806
+ " \n",
807
+ " def __init__(self):\n",
808
+ " self.full_buffer = \"\"\n",
809
+ " self.thinking_complete = False\n",
810
+ " self.answer_complete = False\n",
811
+ " self.extracted_charts = []\n",
812
+ " \n",
813
+ " def add_text(self, text: str) -> List[Tuple[str, str]]:\n",
814
+ " \"\"\"\n",
815
+ " Add text and extract phase-separated content.\n",
816
+ " Returns: List of (phase, content) tuples\n",
817
+ " \"\"\"\n",
818
+ " self.full_buffer += text\n",
819
+ " emissions = []\n",
820
+ " \n",
821
+ " # Extract thinking phase\n",
822
+ " if not self.thinking_complete:\n",
823
+ " thinking_match = re.search(r'<thinking>(.*?)</thinking>', self.full_buffer, re.DOTALL)\n",
824
+ " if thinking_match:\n",
825
+ " thinking_content = thinking_match.group(1).strip()\n",
826
+ " emissions.append((\"thinking\", thinking_content))\n",
827
+ " self.thinking_complete = True\n",
828
+ " # Remove from buffer\n",
829
+ " self.full_buffer = self.full_buffer[thinking_match.end():]\n",
830
+ " \n",
831
+ " # Extract answer phase\n",
832
+ " if self.thinking_complete and not self.answer_complete:\n",
833
+ " answer_match = re.search(r'<answer>(.*?)</answer>', self.full_buffer, re.DOTALL)\n",
834
+ " if answer_match:\n",
835
+ " answer_content = answer_match.group(1).strip()\n",
836
+ " emissions.append((\"answer\", answer_content))\n",
837
+ " self.answer_complete = True\n",
838
+ " # Remove from buffer\n",
839
+ " self.full_buffer = self.full_buffer[answer_match.end():]\n",
840
+ " \n",
841
+ " # Extract charts (can be multiple)\n",
842
+ " if self.answer_complete:\n",
843
+ " chart_pattern = re.compile(r'<datastory_chart>(.*?)</datastory_chart>', re.DOTALL)\n",
844
+ " matches = list(chart_pattern.finditer(self.full_buffer))\n",
845
+ " \n",
846
+ " for match in matches:\n",
847
+ " chart_content = match.group(1)\n",
848
+ " \n",
849
+ " # Parse Vega spec\n",
850
+ " vega_pattern = r'<vega_spec>(.*?)</vega_spec>'\n",
851
+ " vega_match = re.search(vega_pattern, chart_content, re.DOTALL)\n",
852
+ " \n",
853
+ " if vega_match:\n",
854
+ " try:\n",
855
+ " vega_json_str = vega_match.group(1).strip()\n",
856
+ " vega_spec = json.loads(vega_json_str)\n",
857
+ " \n",
858
+ " # Only emit if not already extracted\n",
859
+ " if vega_spec not in self.extracted_charts:\n",
860
+ " self.extracted_charts.append(vega_spec)\n",
861
+ " emissions.append((\"charts\", vega_spec))\n",
862
+ " except json.JSONDecodeError as e:\n",
863
+ " print(f\"Failed to parse chart JSON: {e}\")\n",
864
+ " \n",
865
+ " # Remove from buffer\n",
866
+ " self.full_buffer = self.full_buffer[match.end():]\n",
867
+ " \n",
868
+ " return emissions\n",
869
+ "\n",
870
+ "# =====================================\n",
871
+ "# PHASED STREAMING WITH AGENT REASONING\n",
872
+ "# =====================================\n",
873
+ "\n",
874
+ "async def stream_analysis(query: str, file_paths: List[str]) -> AsyncGenerator[Tuple[str, str], None]:\n",
875
+ " \"\"\"\n",
876
+ " Stream analysis with agent's actual reasoning in thinking phase.\n",
877
+ " Yields: (phase, content) tuples where phase is:\n",
878
+ " - THINKING: Agent's reasoning process\n",
879
+ " - ANSWER: Final analysis and narrative\n",
880
+ " - CHART: Vega-Lite chart specification\n",
881
+ " - ERROR, DONE: Control signals\n",
882
+ " \"\"\"\n",
883
+ " print(f\"\\n{'='*70}\")\n",
884
+ " print(\"ANALYSIS REQUEST (phased streaming with agent reasoning)\")\n",
885
+ " print(f\"Files: {', '.join(file_paths)}\")\n",
886
+ " print(f\"Query: {query}\")\n",
887
+ " print(f\"{'='*70}\\n\")\n",
888
+ "\n",
889
+ " # Load datasets\n",
890
+ " try:\n",
891
+ " dataset_context = load_and_register_datasets(file_paths)\n",
892
+ " print(f\"\\nDataset Context loaded: {len(dataset_context)} chars\\n\")\n",
893
+ " except Exception as e:\n",
894
+ " yield (\"ERROR\", f\"Failed to load datasets: {str(e)}\")\n",
895
+ " yield (\"DONE\", \"\")\n",
896
+ " return\n",
897
+ "\n",
898
+ " # Create agent\n",
899
+ " agent = create_data_analyst_agent(dataset_context)\n",
900
+ " \n",
901
+ " # Initialize phase-aware buffer\n",
902
+ " phase_buffer = PhaseAwareBuffer()\n",
903
+ " \n",
904
+ " try:\n",
905
+ " result = Runner.run_streamed(agent, input=query)\n",
906
+ " \n",
907
+ " async for event in result.stream_events():\n",
908
+ " if (\n",
909
+ " event.type == \"raw_response_event\"\n",
910
+ " and isinstance(event.data, ResponseTextDeltaEvent)\n",
911
+ " ):\n",
912
+ " delta: str = event.data.delta\n",
913
+ " \n",
914
+ " # Add to buffer and extract phase-separated content\n",
915
+ " emissions = phase_buffer.add_text(delta)\n",
916
+ " \n",
917
+ " for phase, content in emissions:\n",
918
+ " if phase == \"thinking\":\n",
919
+ " # Stream thinking in chunks for better UX\n",
920
+ " sentences = content.split('\\n')\n",
921
+ " for sentence in sentences:\n",
922
+ " if sentence.strip():\n",
923
+ " yield (\"THINKING\", sentence.strip() + \"\\n\")\n",
924
+ " await asyncio.sleep(0.05)\n",
925
+ " \n",
926
+ " elif phase == \"answer\":\n",
927
+ " # Stream answer in chunks\n",
928
+ " paragraphs = content.split('\\n\\n')\n",
929
+ " for para in paragraphs:\n",
930
+ " if para.strip():\n",
931
+ " yield (\"ANSWER\", para.strip() + \"\\n\\n\")\n",
932
+ " await asyncio.sleep(0.05)\n",
933
+ " \n",
934
+ " elif phase == \"charts\":\n",
935
+ " # Emit complete chart\n",
936
+ " yield (\"CHART\", content)\n",
937
+ " await asyncio.sleep(0.1)\n",
938
+ "\n",
939
+ " except Exception as e:\n",
940
+ " yield (\"ERROR\", f\"Streaming error: {str(e)}\")\n",
941
+ " \n",
942
+ " yield (\"DONE\", \"\")\n",
943
+ "\n",
944
+ "# =====================================\n",
945
+ "# LEGACY STREAMING (BACKWARD COMPATIBILITY)\n",
946
+ "# =====================================\n",
947
+ "\n",
948
+ "async def stream_analysis_legacy(query: str, file_paths: List[str]) -> AsyncGenerator[str, None]:\n",
949
+ " \"\"\"\n",
950
+ " Original streaming without phases - for backward compatibility\n",
951
+ " \"\"\"\n",
952
+ " dataset_context = load_and_register_datasets(file_paths)\n",
953
+ " agent = create_data_analyst_agent(dataset_context)\n",
954
+ " result = Runner.run_streamed(agent, input=query)\n",
955
+ " \n",
956
+ " async for event in result.stream_events():\n",
957
+ " if (\n",
958
+ " event.type == \"raw_response_event\"\n",
959
+ " and isinstance(event.data, ResponseTextDeltaEvent)\n",
960
+ " ):\n",
961
+ " yield event.data.delta\n",
962
+ " \n",
963
+ " yield \"[DONE]\""
964
+ ]
965
+ }
966
+ ],
967
+ "metadata": {
968
+ "language_info": {
969
+ "name": "python"
970
+ }
971
+ },
972
+ "nbformat": 4,
973
+ "nbformat_minor": 5
974
+ }
report_generation/__pycache__/pdf_v3.cpython-311.pyc CHANGED
Binary files a/report_generation/__pycache__/pdf_v3.cpython-311.pyc and b/report_generation/__pycache__/pdf_v3.cpython-311.pyc differ
 
report_generation/pdf_v3.py CHANGED
@@ -1,394 +1,394 @@
1
-
2
- from fastapi import FastAPI, HTTPException
3
- from fastapi.responses import HTMLResponse
4
- from pydantic import BaseModel
5
- from typing import List, Dict
6
- import os
7
- import json
8
- import httpx
9
- import asyncio
10
- from datetime import datetime
11
- import google.generativeai as genai
12
- from fastapi import APIRouter
13
-
14
- import base64
15
- import altair as alt
16
- import re
17
-
18
- # #app = FastAPI(title="Gemini Report Generator API")
19
- # Report_Generation_Router_ppt_v2 = APIRouter(prefix="/report_generation_ppt_v2", tags=["Report_Generation_ppt_v2"])
20
- # # ================= CONFIG =================
21
- # GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyBWACJwKQVwEwACVcoVANgYOXXinwuPNFw")
22
- # genai.configure(api_key=GEMINI_API_KEY)
23
-
24
-
25
- # === CONFIG ===
26
- genai.configure(api_key="AIzaSyBWACJwKQVwEwACVcoVANgYOXXinwuPNFw")
27
- MODEL_NAME = "gemini-2.5-flash" # CORRECT MODEL
28
-
29
- # app = FastAPI(
30
- # title="AI Stroke Report Generator",
31
- # description="Receive user input → Extract Q&A → Generate 9-page HTML report",
32
- # version="1.0"
33
- # )
34
- Report_Generation_Router_pdf_v2 = APIRouter(prefix="/report_generation_pdf_v2", tags=["Report_Generation_pdf_v2"])
35
- # Dynamic model: Get current valid one (fixes 404; adapts to 2025+)
36
- def get_valid_model():
37
- try:
38
- models = [m.name for m in genai.list_models() if 'generate_content' in m.supported_generation_methods]
39
- flash_model = next((m for m in models if 'gemini-2.0-flash' in m or 'gemini-2.5-flash' in m), models[0] if models else 'gemini-2.0-flash')
40
- print(f"✅ Using model: {flash_model}")
41
- return genai.GenerativeModel(flash_model)
42
- except Exception as e:
43
- print(f"⚠️ Model fetch failed ({e}); fallback to gemini-2.0-flash")
44
- return genai.GenerativeModel('gemini-2.0-flash')
45
-
46
- MODEL = get_valid_model()
47
-
48
- OUTPUT_DIR = "pdf_reports"
49
- CHARTS_DIR = "temp_charts"
50
- os.makedirs(OUTPUT_DIR, exist_ok=True)
51
- os.makedirs(CHARTS_DIR, exist_ok=True)
52
-
53
- # Themes (minimal)
54
- THEMES = {
55
- "execBlue": {"name": "Executive Blue", "h": "#1B56A6", "a": "#06B6D4", "g1": "#1B56A6", "g2": "#06B6D4"},
56
- "sunsetOrange": {"name": "Sunset Orange", "h": "#FF7A1A", "a": "#F59E0B", "g1": "#FF7A1A", "g2": "#F59E0B"},
57
- "indigoExec": {"name": "Indigo Executive", "h": "#4C1D95", "a": "#A78BFA", "g1": "#4C1D95", "g2": "#A78BFA"},
58
- "healthTeal": {"name": "Healthcare Teal", "h": "#0F766E", "a": "#14B8A6", "g1": "#0F766E", "g2": "#14B8A6"},
59
- }
60
- DEFAULT_THEME = "execBlue"
61
-
62
- # ================= MODELS =================
63
- class ReportRequest(BaseModel):
64
- format_type: str
65
- reportname: str
66
- include_citations: bool
67
- user_id: str
68
- success: bool
69
- list_of_queries: List[str]
70
- theme: str = "execBlue"
71
-
72
- # Description is now mandatory for prompt generation.
73
- # If the user does not provide it, this default text will be used.
74
- Description: str = "Please ensure the report provides a comprehensive analysis of the data, highlighting key findings and actionable recommendations."
75
-
76
- # ================= HELPERS =================
77
- async def fetch_convo(client: httpx.AsyncClient, user_id: str, convo_id: str):
78
- url = f"https://mr-mvp-api-dev.dev.ingenspark.com/Db_store_router/conversations/{convo_id}"
79
- try:
80
- r = await client.get(url, params={"user_id": user_id}, timeout=30)
81
- r.raise_for_status()
82
- data = r.json().get("conversation") or (r.json().get("conversations") or [None])[0]
83
- if not data: return None
84
- resp = data.get("response", {})
85
- return {"query": data.get("user_query", {}).get("text", ""), "response": resp.get("text", "")}
86
- except Exception as e:
87
- print(f"Error fetching conversation {convo_id}: {e}")
88
- return None
89
-
90
- async def fetch_all(user_id: str, convo_ids: List[str]):
91
- async with httpx.AsyncClient() as client:
92
- tasks = [fetch_convo(client, user_id, cid) for cid in convo_ids]
93
- results = await asyncio.gather(*tasks)
94
- return [r for r in results if r]
95
-
96
- def render_vega(spec: dict) -> str | None:
97
- try:
98
- chart = alt.Chart.from_dict(spec)
99
- path = os.path.join(CHARTS_DIR, f"tmp_{datetime.now().timestamp()}.png")
100
- chart.save(path)
101
- with open(path, "rb") as f:
102
- b64 = base64.b64encode(f.read()).decode()
103
- os.remove(path)
104
- return f"data:image/png;base64,{b64}"
105
- except Exception as e:
106
- print(f"Error rendering Vega chart: {e}")
107
- return None
108
-
109
- # ================= GEMINI (All via one prompt) =================
110
- def ask_gemini(system: str, user: str) -> str:
111
- try:
112
- resp = MODEL.generate_content([system, user], generation_config={"temperature": 0.2, "max_output_tokens": 8192})
113
- return resp.text.strip()
114
- except Exception as e:
115
- print(f"Gemini error: {e}")
116
- return '{"error": "Gemini failed"}'
117
-
118
- def generate_report_content(conversations: List[dict], title: str, theme_id: str, description: str):
119
- # Get the theme configuration, fallback to default if theme_id is invalid
120
- theme_cfg = THEMES.get(theme_id, THEMES[DEFAULT_THEME])
121
-
122
- # Always add the description to the prompt, as it's now guaranteed to have a value.
123
- description_text = f"\nUser's specific instructions/description: {description}\n"
124
-
125
- prompt = f"""
126
- You are an expert analyst. Generate a full professional report titled "{title}" from these conversations.{description_text}
127
-
128
- Data:
129
- {json.dumps([{'query': c['query'], 'response': c['response']} for c in conversations], indent=2)}
130
-
131
- Do ALL: Analyze data, generate overview/summary/TOC, clean sections (remove narration/tags/suggestions, convert markdown to HTML).
132
- If the user description mentions specific sections to *focus on* (e.g., 'only graphs and overview'), provide rich and detailed content for those sections. For sections implicitly or explicitly excluded, provide minimal, generic, or empty content but ensure the JSON structure remains valid.
133
-
134
- Output ONLY valid JSON (no extra text/markdown):
135
- {{
136
- "overview": "3-4 specific paragraphs (use real numbers/categories from data; formal tone)",
137
- "summary": "3-4 paragraphs with actionable recommendations (specific, data-driven)",
138
- "toc": [{{"title": "string", "page": int}}], // Start pages at 4
139
- "sections": [ // One per conversation
140
- {{
141
- "title": "Extracted chart title or query summary",
142
- "content_html": "Cleaned <p>text</p><strong>bold</strong><ul><li>bullets</li></ul> (no AI chit-chat)",
143
- "chart": {{"title": "string or null", "vega_spec": {{...}} or null, "caption": "string or null"}}
144
- }}
145
- ]
146
- }}
147
-
148
- Rules: Specific (e.g., '68% in Engineering'); clean aggressively; if no vega_spec, use null.
149
- """
150
-
151
- raw = ask_gemini("Precise JSON generator only. No explanations.", prompt)
152
-
153
- # Robust JSON parse
154
- try:
155
- start = raw.find("{")
156
- end = raw.rfind("}") + 1
157
- if start == -1: raise ValueError("No JSON")
158
- data = json.loads(raw[start:end])
159
- required = ["overview", "summary", "toc", "sections"]
160
- if not all(k in data for k in required):
161
- raise ValueError("Missing keys")
162
- return data
163
- except Exception as e:
164
- print(f"JSON failed: {e}\nRaw: {raw[:1000]}")
165
- # Return a fallback structure with minimal content to prevent further errors
166
- return {
167
- "overview": "Analysis overview based on provided data.",
168
- "summary": "Key findings and recommendations.",
169
- "toc": [{"title": "Overview", "page": 4}],
170
- "sections": [{"title": "Section 1", "content_html": "<p>Sample content.</p>", "chart": {"title": None, "vega_spec": None, "caption": None}}]
171
- }
172
-
173
- # ================= HTML TEMPLATE (Unchanged Structure) =================
174
- HTML_TEMPLATE = """<!DOCTYPE html>
175
- <html><head><meta charset="UTF-8"><title>{title}</title>
176
- <style>
177
- body {{font-family:Georgia,serif;margin:0;background:#fff;color:#374151;line-height:1.6;}}
178
- .page {{width:210mm;min-height:297mm;padding:20mm;background:white;box-shadow:0 4px 8px rgba(0,0,0,0.1);margin:20px auto;position:relative;}}
179
- .cover-page {{background:linear-gradient(135deg,{g1} 0%,{g2} 100%);color:white;text-align:center;display:flex;justify-content:center;align-items:center;flex-direction:column;}}
180
- .cover-title {{font-size:36px;font-weight:bold;margin-bottom:20px;text-transform:uppercase;letter-spacing:2px;}}
181
- .cover-meta {{font-size:16px;margin-top:60px;padding-top:40px;border-top:2px solid rgba(255,255,255,0.3);}}
182
- .section-title {{font-size:28px;color:{h};margin-bottom:30px;padding-bottom:15px;border-bottom:3px solid {a};}}
183
- .chart-container {{margin:30px 0;text-align:center;}}
184
- .chart-container img {{max-width:100%;height:auto;border:1px solid #ddd;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1);}}
185
- .chart-caption {{margin-top:15px;font-style:italic;color:#555;font-size:14px;}}
186
- .toc-list {{list-style:none;padding:0;margin-top:40px;}}
187
- .toc-item {{display:flex;justify-content:space-between;padding:15px 0;border-bottom:1px dotted #bdc3c7;font-size:18px;}}
188
- .page-footer {{position:absolute;bottom:15mm;left:20mm;right:20mm;display:flex;justify-content:space-between;align-items:center;font-size:12px;color:#7f8c8d;border-top:1px solid #ecf0f1;padding-top:10px;}}
189
- @media print {{body {{background:white;}}.page {{box-shadow:none;margin:0;page-break-after:always;}}}}
190
- </style></head><body><div class="report-container" style="display:flex;flex-direction:column;gap:20px;padding:20px;max-width:1200px;margin:0 auto;">{body}</div></body></html>"""
191
-
192
- # ================= ENDPOINT =================
193
- @Report_Generation_Router_pdf_v2.post("/generate_report")
194
- async def generate_report(req: ReportRequest):
195
- try:
196
- print(f"🚀 Report: {req.reportname} | User: {req.user_id} | Convos: {len(req.list_of_queries)} | Citations: {req.include_citations} | Success: {req.success} | Description: {req.Description}")
197
-
198
- data = await fetch_all(req.user_id, req.list_of_queries)
199
- if not data:
200
- raise HTTPException(404, "No valid conversations found")
201
-
202
- print(f"✅ Fetched {len(data)} conversations")
203
-
204
- # Normalize the theme from the request here, before passing it around
205
- normalized_theme_id = req.theme
206
- if "orang" in req.theme.lower():
207
- normalized_theme_id = "sunsetOrange"
208
- elif normalized_theme_id not in THEMES:
209
- normalized_theme_id = DEFAULT_THEME
210
-
211
- print(f"🎨 Applying theme: {normalized_theme_id}")
212
-
213
- # Pass the normalized theme_id AND the description to generate_report_content
214
- content = generate_report_content(data, req.reportname, normalized_theme_id, req.Description)
215
-
216
- # Use the normalized theme for HTML formatting
217
- selected_theme_config = THEMES.get(normalized_theme_id, THEMES[DEFAULT_THEME])
218
-
219
- # Render charts
220
- for sec in content.get("sections", []):
221
- # Ensure 'sec' is a dictionary before attempting to access its keys
222
- if not isinstance(sec, dict):
223
- print(f"⚠️ Warning: Non-dictionary item found in 'sections' array, skipping: {sec}")
224
- continue # Skip to the next item if it's not a dictionary
225
-
226
- # Retrieve the 'chart' data. If 'chart' key is missing or its value is not a dict (e.g., null),
227
- # default it to an empty dictionary to safely call .get() on it.
228
- chart_config = sec.get("chart")
229
- if not isinstance(chart_config, dict):
230
- chart_config = {} # Ensure it's always a dict for safe .get() calls
231
-
232
- if chart_config.get("vega_spec"):
233
- b64 = render_vega(chart_config["vega_spec"])
234
- chart_config["image_b64"] = b64 if b64 else ""
235
- sec["chart"] = chart_config # Update the 'chart' entry in 'sec' with the modified config
236
-
237
- # --- NEW: Parse description to determine what sections to include ---
238
- description_lower = req.Description.lower()
239
-
240
- # Default inclusions (before explicit "no need of" or "only" directives)
241
- include_coverpage = True
242
- include_toc = True
243
- include_overview = True
244
- include_sections = True
245
- include_summary = True
246
-
247
- # Check for explicit exclusions
248
- if "no need of coverpage" in description_lower or "no coverpage" in description_lower:
249
- include_coverpage = False
250
- if "no need of toc" in description_lower or "no table of contents" in description_lower:
251
- include_toc = False
252
- if "no need of overview" in description_lower or "no executive overview" in description_lower:
253
- include_overview = False
254
- if "no need of graphs" in description_lower or "no need of sections" in description_lower: # "graphs" implies sections for charts
255
- include_sections = False
256
- if "no need of summary" in description_lower or "no need of conclusion" in description_lower:
257
- include_summary = False
258
-
259
- # If "only" is present, it overrides previous inclusions to keep only specified parts.
260
- # This uses a regex to find "only " followed by a keyword, allowing for more flexible phrasing.
261
- only_match = re.search(r"only\s+(.*?)(?:in report|\.|$)", description_lower)
262
- if only_match:
263
- # Set all to False first, then selectively enable based on "only" keywords found
264
- include_coverpage = False
265
- include_toc = False
266
- include_overview = False
267
- include_sections = False
268
- include_summary = False
269
-
270
- only_keywords = only_match.group(1)
271
-
272
- if "coverpage" in only_keywords:
273
- include_coverpage = True
274
- if "toc" in only_keywords or "table of contents" in only_keywords:
275
- include_toc = True
276
- if "overview" in only_keywords or "executive overview" in only_keywords:
277
- include_overview = True
278
- if "graphs" in only_keywords or "sections" in only_keywords: # "graphs" implies sections for charts
279
- include_sections = True
280
- if "summary" in only_keywords or "conclusion" in only_keywords:
281
- include_summary = True
282
-
283
-
284
- # Build pages conditionally
285
- body_parts = []
286
- current_page_number = 1 # Start page numbering from 1
287
-
288
- # Cover
289
- if include_coverpage:
290
- current_date = datetime.now().strftime('%B %Y')
291
- body_parts.append(f'''
292
- <div class="page cover-page">
293
- <h1 class="cover-title">{req.reportname}</h1>
294
- <p class="cover-subtitle">Comprehensive Data Analysis Report</p>
295
- <div class="cover-meta">
296
- <div class="cover-company">ACCUSAGA</div>
297
- <div>{current_date}</div>
298
- </div>
299
- </div>''')
300
- current_page_number += 1
301
-
302
- # TOC
303
- if include_toc:
304
- # Note: TOC page numbers in content["toc"] might not align with actual rendered page numbers.
305
- # This is a limitation when dynamically building HTML pages and relying on fixed page numbers from LLM.
306
- toc_html = ''.join(f'<li class="toc-item"><span>{item["title"]}</span><span>Page {item["page"]}</span></li>' for item in content.get("toc", []))
307
- body_parts.append(f'''
308
- <div class="page">
309
- <h2 class="section-title">Table of Contents</h2>
310
- <ul class="toc-list">{toc_html}</ul>
311
- <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
312
- </div>''')
313
- current_page_number += 1
314
-
315
- # Overview
316
- if include_overview:
317
- body_parts.append(f'''
318
- <div class="page">
319
- <h2 class="section-title">Executive Overview</h2>
320
- <div class="content-area" style="margin-top:20px;line-height:1.8;">{content["overview"]}</div>
321
- <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
322
- </div>''')
323
- current_page_number += 1
324
-
325
- # Sections
326
- if include_sections:
327
- for sec in content.get("sections", []):
328
- # Ensure 'sec' is a dictionary before processing
329
- if not isinstance(sec, dict):
330
- continue
331
-
332
- chart_html = ''
333
- # Retrieve chart configuration robustly
334
- ch = sec.get("chart")
335
- if not isinstance(ch, dict):
336
- ch = {} # Ensure ch is a dictionary for safe access
337
-
338
- if ch.get("image_b64"):
339
- chart_html = f'''
340
- <div class="chart-container">
341
- <h3>{ch.get("title", "Visualization")}</h3>
342
- <img src="{ch["image_b64"]}" alt="{ch.get("title")}">
343
- <p class="chart-caption">{ch.get("caption", "")}</p>
344
- </div>'''
345
- body_parts.append(f'''
346
- <div class="page">
347
- <h2 class="section-title">{sec["title"]}</h2>
348
- <div class="content-area" style="margin-top:20px;line-height:1.8;">{sec["content_html"]}{chart_html}</div>
349
- <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
350
- </div>''')
351
- current_page_number += 1
352
-
353
- # Summary
354
- if include_summary:
355
- body_parts.append(f'''
356
- <div class="page">
357
- <h2 class="section-title">Summary and Conclusions</h2>
358
- <div class="content-area" style="margin-top:20px;line-height:1.8;">{content["summary"]}</div>
359
- <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
360
- </div>''')
361
- current_page_number += 1
362
-
363
- # If no parts are included, return a minimal HTML response
364
- if not body_parts:
365
- return HTMLResponse(content="<html><head><title>No Report Content</title></head><body><h1>No report content generated or included based on your description.</h1></body></html>")
366
-
367
- html = HTML_TEMPLATE.format(title=req.reportname,
368
- g1=selected_theme_config["g1"],
369
- g2=selected_theme_config["g2"],
370
- h=selected_theme_config["h"],
371
- a=selected_theme_config["a"],
372
- body=''.join(body_parts))
373
-
374
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
375
- filename = f"report_{req.user_id}_{timestamp}.html"
376
- path = os.path.join(OUTPUT_DIR, filename)
377
- with open(path, 'w', encoding='utf-8') as f:
378
- f.write(html)
379
-
380
- print(f"✅ Report saved: {path}")
381
- return HTMLResponse(content=html)
382
-
383
- except HTTPException:
384
- raise
385
- except Exception as e:
386
- print(f"❌ Error: {e}")
387
- import traceback
388
- traceback.print_exc()
389
- raise HTTPException(500, str(e))
390
-
391
- @Report_Generation_Router_pdf_v2.get("/")
392
- async def root():
393
- return {"message": "Gemini Report Generator API v3 (Fixed for 2025)", "model": MODEL._model_name, "themes": list(THEMES.keys())}
394
-
 
1
+
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.responses import HTMLResponse
4
+ from pydantic import BaseModel
5
+ from typing import List, Dict
6
+ import os
7
+ import json
8
+ import httpx
9
+ import asyncio
10
+ from datetime import datetime
11
+ import google.generativeai as genai
12
+ from fastapi import APIRouter
13
+
14
+ import base64
15
+ import altair as alt
16
+ import re
17
+
18
+ # #app = FastAPI(title="Gemini Report Generator API")
19
+ # Report_Generation_Router_ppt_v2 = APIRouter(prefix="/report_generation_ppt_v2", tags=["Report_Generation_ppt_v2"])
20
+ # # ================= CONFIG =================
21
+ # GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyBWACJwKQVwEwACVcoVANgYOXXinwuPNFw")
22
+ # genai.configure(api_key=GEMINI_API_KEY)
23
+
24
+
25
+ # === CONFIG ===
26
+ genai.configure(api_key="AIzaSyBWACJwKQVwEwACVcoVANgYOXXinwuPNFw")
27
+ MODEL_NAME = "gemini-2.5-flash" # CORRECT MODEL
28
+
29
+ # app = FastAPI(
30
+ # title="AI Stroke Report Generator",
31
+ # description="Receive user input → Extract Q&A → Generate 9-page HTML report",
32
+ # version="1.0"
33
+ # )
34
+ Report_Generation_Router_pdf_v2 = APIRouter(prefix="/report_generation_pdf_v2", tags=["Report_Generation_pdf_v2"])
35
+ # Dynamic model: Get current valid one (fixes 404; adapts to 2025+)
36
+ def get_valid_model():
37
+ try:
38
+ models = [m.name for m in genai.list_models() if 'generate_content' in m.supported_generation_methods]
39
+ flash_model = next((m for m in models if 'gemini-2.0-flash' in m or 'gemini-2.5-flash' in m), models[0] if models else 'gemini-2.0-flash')
40
+ print(f"✅ Using model: {flash_model}")
41
+ return genai.GenerativeModel(flash_model)
42
+ except Exception as e:
43
+ print(f"⚠️ Model fetch failed ({e}); fallback to gemini-2.0-flash")
44
+ return genai.GenerativeModel('gemini-2.0-flash')
45
+
46
+ MODEL = get_valid_model()
47
+
48
+ OUTPUT_DIR = "pdf_reports"
49
+ CHARTS_DIR = "temp_charts"
50
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
51
+ os.makedirs(CHARTS_DIR, exist_ok=True)
52
+
53
+ # Themes (minimal)
54
+ THEMES = {
55
+ "execBlue": {"name": "Executive Blue", "h": "#1B56A6", "a": "#06B6D4", "g1": "#1B56A6", "g2": "#06B6D4"},
56
+ "sunsetOrange": {"name": "Sunset Orange", "h": "#FF7A1A", "a": "#F59E0B", "g1": "#FF7A1A", "g2": "#F59E0B"},
57
+ "indigoExec": {"name": "Indigo Executive", "h": "#4C1D95", "a": "#A78BFA", "g1": "#4C1D95", "g2": "#A78BFA"},
58
+ "healthTeal": {"name": "Healthcare Teal", "h": "#0F766E", "a": "#14B8A6", "g1": "#0F766E", "g2": "#14B8A6"},
59
+ }
60
+ DEFAULT_THEME = "execBlue"
61
+
62
+ # ================= MODELS =================
63
+ class ReportRequest(BaseModel):
64
+ format_type: str
65
+ reportname: str
66
+ include_citations: bool
67
+ user_id: str
68
+ success: bool
69
+ list_of_queries: List[str]
70
+ theme: str = "execBlue"
71
+
72
+ # Description is now mandatory for prompt generation.
73
+ # If the user does not provide it, this default text will be used.
74
+ Description: str = "Please ensure the report provides a comprehensive analysis of the data, highlighting key findings and actionable recommendations."
75
+
76
+ # ================= HELPERS =================
77
+ async def fetch_convo(client: httpx.AsyncClient, user_id: str, convo_id: str):
78
+ url = f"https://mr-mvp-api-dev.dev.ingenspark.com/Db_store_router/conversations/{convo_id}"
79
+ try:
80
+ r = await client.get(url, params={"user_id": user_id}, timeout=30)
81
+ r.raise_for_status()
82
+ data = r.json().get("conversation") or (r.json().get("conversations") or [None])[0]
83
+ if not data: return None
84
+ resp = data.get("response", {})
85
+ return {"query": data.get("user_query", {}).get("text", ""), "response": resp.get("text", "")}
86
+ except Exception as e:
87
+ print(f"Error fetching conversation {convo_id}: {e}")
88
+ return None
89
+
90
+ async def fetch_all(user_id: str, convo_ids: List[str]):
91
+ async with httpx.AsyncClient() as client:
92
+ tasks = [fetch_convo(client, user_id, cid) for cid in convo_ids]
93
+ results = await asyncio.gather(*tasks)
94
+ return [r for r in results if r]
95
+
96
+ def render_vega(spec: dict) -> str | None:
97
+ try:
98
+ chart = alt.Chart.from_dict(spec)
99
+ path = os.path.join(CHARTS_DIR, f"tmp_{datetime.now().timestamp()}.png")
100
+ chart.save(path)
101
+ with open(path, "rb") as f:
102
+ b64 = base64.b64encode(f.read()).decode()
103
+ os.remove(path)
104
+ return f"data:image/png;base64,{b64}"
105
+ except Exception as e:
106
+ print(f"Error rendering Vega chart: {e}")
107
+ return None
108
+
109
+ # ================= GEMINI (All via one prompt) =================
110
+ def ask_gemini(system: str, user: str) -> str:
111
+ try:
112
+ resp = MODEL.generate_content([system, user], generation_config={"temperature": 0.2, "max_output_tokens": 8192})
113
+ return resp.text.strip()
114
+ except Exception as e:
115
+ print(f"Gemini error: {e}")
116
+ return '{"error": "Gemini failed"}'
117
+
118
+ def generate_report_content(conversations: List[dict], title: str, theme_id: str, description: str):
119
+ # Get the theme configuration, fallback to default if theme_id is invalid
120
+ theme_cfg = THEMES.get(theme_id, THEMES[DEFAULT_THEME])
121
+
122
+ # Always add the description to the prompt, as it's now guaranteed to have a value.
123
+ description_text = f"\nUser's specific instructions/description: {description}\n"
124
+
125
+ prompt = f"""
126
+ You are an expert analyst. Generate a full professional report titled "{title}" from these conversations.{description_text}
127
+
128
+ Data:
129
+ {json.dumps([{'query': c['query'], 'response': c['response']} for c in conversations], indent=2)}
130
+
131
+ Do ALL: Analyze data, generate overview/summary/TOC, clean sections (remove narration/tags/suggestions, convert markdown to HTML).
132
+ If the user description mentions specific sections to *focus on* (e.g., 'only graphs and overview'), provide rich and detailed content for those sections. For sections implicitly or explicitly excluded, provide minimal, generic, or empty content but ensure the JSON structure remains valid.
133
+
134
+ Output ONLY valid JSON (no extra text/markdown):
135
+ {{
136
+ "overview": "3-4 specific paragraphs (use real numbers/categories from data; formal tone)",
137
+ "summary": "3-4 paragraphs with actionable recommendations (specific, data-driven)",
138
+ "toc": [{{"title": "string", "page": int}}], // Start pages at 4
139
+ "sections": [ // One per conversation
140
+ {{
141
+ "title": "Extracted chart title or query summary",
142
+ "content_html": "Cleaned <p>text</p><strong>bold</strong><ul><li>bullets</li></ul> (no AI chit-chat)",
143
+ "chart": {{"title": "string or null", "vega_spec": {{...}} or null, "caption": "string or null"}}
144
+ }}
145
+ ]
146
+ }}
147
+
148
+ Rules: Specific (e.g., '68% in Engineering'); clean aggressively; if no vega_spec, use null.
149
+ """
150
+
151
+ raw = ask_gemini("Precise JSON generator only. No explanations.", prompt)
152
+
153
+ # Robust JSON parse
154
+ try:
155
+ start = raw.find("{")
156
+ end = raw.rfind("}") + 1
157
+ if start == -1: raise ValueError("No JSON")
158
+ data = json.loads(raw[start:end])
159
+ required = ["overview", "summary", "toc", "sections"]
160
+ if not all(k in data for k in required):
161
+ raise ValueError("Missing keys")
162
+ return data
163
+ except Exception as e:
164
+ print(f"JSON failed: {e}\nRaw: {raw[:1000]}")
165
+ # Return a fallback structure with minimal content to prevent further errors
166
+ return {
167
+ "overview": "Analysis overview based on provided data.",
168
+ "summary": "Key findings and recommendations.",
169
+ "toc": [{"title": "Overview", "page": 4}],
170
+ "sections": [{"title": "Section 1", "content_html": "<p>Sample content.</p>", "chart": {"title": None, "vega_spec": None, "caption": None}}]
171
+ }
172
+
173
+ # ================= HTML TEMPLATE (Unchanged Structure) =================
174
+ HTML_TEMPLATE = """<!DOCTYPE html>
175
+ <html><head><meta charset="UTF-8"><title>{title}</title>
176
+ <style>
177
+ body {{font-family:Georgia,serif;margin:0;background:#fff;color:#374151;line-height:1.6;}}
178
+ .page {{width:210mm;min-height:297mm;padding:20mm;background:white;box-shadow:0 4px 8px rgba(0,0,0,0.1);margin:20px auto;position:relative;}}
179
+ .cover-page {{background:linear-gradient(135deg,{g1} 0%,{g2} 100%);color:white;text-align:center;display:flex;justify-content:center;align-items:center;flex-direction:column;}}
180
+ .cover-title {{font-size:36px;font-weight:bold;margin-bottom:20px;text-transform:uppercase;letter-spacing:2px;}}
181
+ .cover-meta {{font-size:16px;margin-top:60px;padding-top:40px;border-top:2px solid rgba(255,255,255,0.3);}}
182
+ .section-title {{font-size:28px;color:{h};margin-bottom:30px;padding-bottom:15px;border-bottom:3px solid {a};}}
183
+ .chart-container {{margin:30px 0;text-align:center;}}
184
+ .chart-container img {{max-width:100%;height:auto;border:1px solid #ddd;border-radius:8px;box-shadow:0 2px 4px rgba(0,0,0,0.1);}}
185
+ .chart-caption {{margin-top:15px;font-style:italic;color:#555;font-size:14px;}}
186
+ .toc-list {{list-style:none;padding:0;margin-top:40px;}}
187
+ .toc-item {{display:flex;justify-content:space-between;padding:15px 0;border-bottom:1px dotted #bdc3c7;font-size:18px;}}
188
+ .page-footer {{position:absolute;bottom:15mm;left:20mm;right:20mm;display:flex;justify-content:space-between;align-items:center;font-size:12px;color:#7f8c8d;border-top:1px solid #ecf0f1;padding-top:10px;}}
189
+ @media print {{body {{background:white;}}.page {{box-shadow:none;margin:0;page-break-after:always;}}}}
190
+ </style></head><body><div class="report-container" style="display:flex;flex-direction:column;gap:20px;padding:20px;max-width:1200px;margin:0 auto;">{body}</div></body></html>"""
191
+
192
+ # ================= ENDPOINT =================
193
+ @Report_Generation_Router_pdf_v2.post("/generate_report")
194
+ async def generate_report(req: ReportRequest):
195
+ try:
196
+ print(f"🚀 Report: {req.reportname} | User: {req.user_id} | Convos: {len(req.list_of_queries)} | Citations: {req.include_citations} | Success: {req.success} | Description: {req.Description}")
197
+
198
+ data = await fetch_all(req.user_id, req.list_of_queries)
199
+ if not data:
200
+ raise HTTPException(404, "No valid conversations found")
201
+
202
+ print(f"✅ Fetched {len(data)} conversations")
203
+
204
+ # Normalize the theme from the request here, before passing it around
205
+ normalized_theme_id = req.theme
206
+ if "orang" in req.theme.lower():
207
+ normalized_theme_id = "sunsetOrange"
208
+ elif normalized_theme_id not in THEMES:
209
+ normalized_theme_id = DEFAULT_THEME
210
+
211
+ print(f"🎨 Applying theme: {normalized_theme_id}")
212
+
213
+ # Pass the normalized theme_id AND the description to generate_report_content
214
+ content = generate_report_content(data, req.reportname, normalized_theme_id, req.Description)
215
+
216
+ # Use the normalized theme for HTML formatting
217
+ selected_theme_config = THEMES.get(normalized_theme_id, THEMES[DEFAULT_THEME])
218
+
219
+ # Render charts
220
+ for sec in content.get("sections", []):
221
+ # Ensure 'sec' is a dictionary before attempting to access its keys
222
+ if not isinstance(sec, dict):
223
+ print(f"⚠️ Warning: Non-dictionary item found in 'sections' array, skipping: {sec}")
224
+ continue # Skip to the next item if it's not a dictionary
225
+
226
+ # Retrieve the 'chart' data. If 'chart' key is missing or its value is not a dict (e.g., null),
227
+ # default it to an empty dictionary to safely call .get() on it.
228
+ chart_config = sec.get("chart")
229
+ if not isinstance(chart_config, dict):
230
+ chart_config = {} # Ensure it's always a dict for safe .get() calls
231
+
232
+ if chart_config.get("vega_spec"):
233
+ b64 = render_vega(chart_config["vega_spec"])
234
+ chart_config["image_b64"] = b64 if b64 else ""
235
+ sec["chart"] = chart_config # Update the 'chart' entry in 'sec' with the modified config
236
+
237
+ # --- NEW: Parse description to determine what sections to include ---
238
+ description_lower = req.Description.lower()
239
+
240
+ # Default inclusions (before explicit "no need of" or "only" directives)
241
+ include_coverpage = True
242
+ include_toc = True
243
+ include_overview = True
244
+ include_sections = True
245
+ include_summary = True
246
+
247
+ # Check for explicit exclusions
248
+ if "no need of coverpage" in description_lower or "no coverpage" in description_lower:
249
+ include_coverpage = False
250
+ if "no need of toc" in description_lower or "no table of contents" in description_lower:
251
+ include_toc = False
252
+ if "no need of overview" in description_lower or "no executive overview" in description_lower:
253
+ include_overview = False
254
+ if "no need of graphs" in description_lower or "no need of sections" in description_lower: # "graphs" implies sections for charts
255
+ include_sections = False
256
+ if "no need of summary" in description_lower or "no need of conclusion" in description_lower:
257
+ include_summary = False
258
+
259
+ # If "only" is present, it overrides previous inclusions to keep only specified parts.
260
+ # This uses a regex to find "only " followed by a keyword, allowing for more flexible phrasing.
261
+ only_match = re.search(r"only\s+(.*?)(?:in report|\.|$)", description_lower)
262
+ if only_match:
263
+ # Set all to False first, then selectively enable based on "only" keywords found
264
+ include_coverpage = False
265
+ include_toc = False
266
+ include_overview = False
267
+ include_sections = False
268
+ include_summary = False
269
+
270
+ only_keywords = only_match.group(1)
271
+
272
+ if "coverpage" in only_keywords:
273
+ include_coverpage = True
274
+ if "toc" in only_keywords or "table of contents" in only_keywords:
275
+ include_toc = True
276
+ if "overview" in only_keywords or "executive overview" in only_keywords:
277
+ include_overview = True
278
+ if "graphs" in only_keywords or "sections" in only_keywords: # "graphs" implies sections for charts
279
+ include_sections = True
280
+ if "summary" in only_keywords or "conclusion" in only_keywords:
281
+ include_summary = True
282
+
283
+
284
+ # Build pages conditionally
285
+ body_parts = []
286
+ current_page_number = 1 # Start page numbering from 1
287
+
288
+ # Cover
289
+ if include_coverpage:
290
+ current_date = datetime.now().strftime('%B %Y')
291
+ body_parts.append(f'''
292
+ <div class="page cover-page">
293
+ <h1 class="cover-title">{req.reportname}</h1>
294
+ <p class="cover-subtitle">Comprehensive Data Analysis Report</p>
295
+ <div class="cover-meta">
296
+ <div class="cover-company">ACCUSAGA</div>
297
+ <div>{current_date}</div>
298
+ </div>
299
+ </div>''')
300
+ current_page_number += 1
301
+
302
+ # TOC
303
+ if include_toc:
304
+ # Note: TOC page numbers in content["toc"] might not align with actual rendered page numbers.
305
+ # This is a limitation when dynamically building HTML pages and relying on fixed page numbers from LLM.
306
+ toc_html = ''.join(f'<li class="toc-item"><span>{item["title"]}</span><span>Page {item["page"]}</span></li>' for item in content.get("toc", []))
307
+ body_parts.append(f'''
308
+ <div class="page">
309
+ <h2 class="section-title">Table of Contents</h2>
310
+ <ul class="toc-list">{toc_html}</ul>
311
+ <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
312
+ </div>''')
313
+ current_page_number += 1
314
+
315
+ # Overview
316
+ if include_overview:
317
+ body_parts.append(f'''
318
+ <div class="page">
319
+ <h2 class="section-title">Executive Overview</h2>
320
+ <div class="content-area" style="margin-top:20px;line-height:1.8;">{content["overview"]}</div>
321
+ <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
322
+ </div>''')
323
+ current_page_number += 1
324
+
325
+ # Sections
326
+ if include_sections:
327
+ for sec in content.get("sections", []):
328
+ # Ensure 'sec' is a dictionary before processing
329
+ if not isinstance(sec, dict):
330
+ continue
331
+
332
+ chart_html = ''
333
+ # Retrieve chart configuration robustly
334
+ ch = sec.get("chart")
335
+ if not isinstance(ch, dict):
336
+ ch = {} # Ensure ch is a dictionary for safe access
337
+
338
+ if ch.get("image_b64"):
339
+ chart_html = f'''
340
+ <div class="chart-container">
341
+ <h3>{ch.get("title", "Visualization")}</h3>
342
+ <img src="{ch["image_b64"]}" alt="{ch.get("title")}">
343
+ <p class="chart-caption">{ch.get("caption", "")}</p>
344
+ </div>'''
345
+ body_parts.append(f'''
346
+ <div class="page">
347
+ <h2 class="section-title">{sec["title"]}</h2>
348
+ <div class="content-area" style="margin-top:20px;line-height:1.8;">{sec["content_html"]}{chart_html}</div>
349
+ <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
350
+ </div>''')
351
+ current_page_number += 1
352
+
353
+ # Summary
354
+ if include_summary:
355
+ body_parts.append(f'''
356
+ <div class="page">
357
+ <h2 class="section-title">Summary and Conclusions</h2>
358
+ <div class="content-area" style="margin-top:20px;line-height:1.8;">{content["summary"]}</div>
359
+ <div class="page-footer"><div>contact@accusaga.com</div><div>Page {current_page_number}</div></div>
360
+ </div>''')
361
+ current_page_number += 1
362
+
363
+ # If no parts are included, return a minimal HTML response
364
+ if not body_parts:
365
+ return HTMLResponse(content="<html><head><title>No Report Content</title></head><body><h1>No report content generated or included based on your description.</h1></body></html>")
366
+
367
+ html = HTML_TEMPLATE.format(title=req.reportname,
368
+ g1=selected_theme_config["g1"],
369
+ g2=selected_theme_config["g2"],
370
+ h=selected_theme_config["h"],
371
+ a=selected_theme_config["a"],
372
+ body=''.join(body_parts))
373
+
374
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
375
+ filename = f"report_{req.user_id}_{timestamp}.html"
376
+ path = os.path.join(OUTPUT_DIR, filename)
377
+ with open(path, 'w', encoding='utf-8') as f:
378
+ f.write(html)
379
+
380
+ print(f"✅ Report saved: {path}")
381
+ return HTMLResponse(content=html)
382
+
383
+ except HTTPException:
384
+ raise
385
+ except Exception as e:
386
+ print(f"❌ Error: {e}")
387
+ import traceback
388
+ traceback.print_exc()
389
+ raise HTTPException(500, str(e))
390
+
391
+ @Report_Generation_Router_pdf_v2.get("/")
392
+ async def root():
393
+ return {"message": "Gemini Report Generator API v3 (Fixed for 2025)", "model": MODEL._model_name, "themes": list(THEMES.keys())}
394
+
requirements.txt CHANGED
@@ -1,91 +1,91 @@
1
- fastapi
2
- uvicorn
3
- pandas
4
- sqlparse
5
- duckdb
6
- ipykernel
7
- numpy
8
- openai
9
- holoviews
10
- nbformat
11
- nbclient
12
- panel
13
- bokeh
14
- plotly
15
- altair
16
- redis
17
- IPython
18
- langchain_experimental
19
- google-generativeai
20
- tabulate
21
- redis_client
22
- langchain
23
- boto3
24
- python-multipart
25
- dotenv
26
- psycopg2-binary
27
- langchain_qdrant
28
- langchain_openai
29
- langchain_community
30
- langextract
31
- fastembed-gpu
32
- redis
33
- autoviz
34
- openai-agents
35
- # openpyxl
36
- # xlrd
37
- # odfpy
38
- openpyxl>=3.0.0
39
- xlrd>=2.0.1
40
- odfpy>=1.4.1
41
- # cairosvg
42
- fastapi
43
- uvicorn
44
- pyarrow
45
- pandas
46
- duckdb
47
- sqlparse
48
- pydantic
49
- requests
50
- plotly>=5.0.0
51
- boto3
52
- python-dotenv
53
- jupyter
54
- nbformat
55
- nbconvert
56
-
57
-
58
-
59
-
60
- #==================login==============
61
- annotated-types==0.7.0
62
- anyio==4.11.0
63
- bcrypt==4.2.0
64
- cffi==2.0.0
65
- click==8.3.0
66
- colorama==0.4.6
67
- cryptography==46.0.3
68
- dnspython==2.8.0
69
- ecdsa==0.19.1
70
- email_validator==2.2.0
71
- fastapi==0.119.1
72
- greenlet==3.2.4
73
- h11==0.16.0
74
- idna==3.11
75
- passlib==1.7.4
76
- psycopg2-binary==2.9.11
77
- pyasn1==0.6.1
78
- pycparser==2.23
79
- pydantic==2.12.3
80
- pydantic_core==2.41.4
81
- python-dotenv==1.1.1
82
- python-jose==3.3.0
83
- python-multipart==0.0.20
84
- rsa==4.9.1
85
- six==1.17.0
86
- sniffio==1.3.1
87
- SQLAlchemy==2.0.44
88
- starlette==0.48.0
89
- typing-inspection==0.4.2
90
- typing_extensions==4.15.0
91
- uvicorn==0.38.0
 
1
+ fastapi
2
+ uvicorn
3
+ pandas
4
+ sqlparse
5
+ duckdb
6
+ ipykernel
7
+ numpy
8
+ openai
9
+ holoviews
10
+ nbformat
11
+ nbclient
12
+ panel
13
+ bokeh
14
+ plotly
15
+ redis
16
+ IPython
17
+ langchain_experimental
18
+ google-generativeai
19
+ tabulate
20
+ redis_client
21
+ langchain
22
+ boto3
23
+ python-multipart
24
+ dotenv
25
+ psycopg2-binary
26
+ langchain_qdrant
27
+ langchain_openai
28
+ langchain_community
29
+ langextract
30
+ fastembed-gpu
31
+ redis
32
+ autoviz
33
+ openai-agents
34
+ # openpyxl
35
+ # xlrd
36
+ # odfpy
37
+ openpyxl>=3.0.0
38
+ xlrd>=2.0.1
39
+ odfpy>=1.4.1
40
+ # cairosvg
41
+ fastapi
42
+ uvicorn
43
+ altair
44
+ pyarrow
45
+ pandas
46
+ duckdb
47
+ sqlparse
48
+ pydantic
49
+ requests
50
+ plotly>=5.0.0
51
+ boto3
52
+ python-dotenv
53
+ jupyter
54
+ nbformat
55
+ nbconvert
56
+
57
+
58
+
59
+
60
+ #==================login==============
61
+ annotated-types==0.7.0
62
+ anyio==4.11.0
63
+ bcrypt==4.2.0
64
+ cffi==2.0.0
65
+ click==8.3.0
66
+ colorama==0.4.6
67
+ cryptography==46.0.3
68
+ dnspython==2.8.0
69
+ ecdsa==0.19.1
70
+ email_validator==2.2.0
71
+ fastapi==0.119.1
72
+ greenlet==3.2.4
73
+ h11==0.16.0
74
+ idna==3.11
75
+ passlib==1.7.4
76
+ psycopg2-binary==2.9.11
77
+ pyasn1==0.6.1
78
+ pycparser==2.23
79
+ pydantic==2.12.3
80
+ pydantic_core==2.41.4
81
+ python-dotenv==1.1.1
82
+ python-jose==3.3.0
83
+ python-multipart==0.0.20
84
+ rsa==4.9.1
85
+ six==1.17.0
86
+ sniffio==1.3.1
87
+ SQLAlchemy==2.0.44
88
+ starlette==0.48.0
89
+ typing-inspection==0.4.2
90
+ typing_extensions==4.15.0
91
+ uvicorn==0.38.0
s3/__pycache__/create_dataset_graphs.cpython-311.pyc ADDED
Binary file (11.4 kB). View file
 
s3/__pycache__/r4.cpython-311.pyc CHANGED
Binary files a/s3/__pycache__/r4.cpython-311.pyc and b/s3/__pycache__/r4.cpython-311.pyc differ
 
s3/__pycache__/r6.cpython-311.pyc ADDED
Binary file (28 kB). View file
 
s3/__pycache__/read_files.cpython-311.pyc CHANGED
Binary files a/s3/__pycache__/read_files.cpython-311.pyc and b/s3/__pycache__/read_files.cpython-311.pyc differ
 
s3/create_dataset_graphs.py CHANGED
@@ -1,295 +1,295 @@
1
- import pandas as pd
2
- import numpy as np
3
- import json
4
- from typing import Dict, List, Any, Union
5
-
6
- def convert_to_serializable(obj):
7
- """Convert non-serializable objects to JSON-compatible types"""
8
- if pd.isna(obj):
9
- return None
10
- elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
11
- return obj.isoformat()
12
- elif isinstance(obj, (np.integer, np.int64, np.int32, np.int16, np.int8)):
13
- return int(obj)
14
- elif isinstance(obj, (np.floating, np.float64, np.float32, np.float16)):
15
- return float(obj)
16
- elif isinstance(obj, np.ndarray):
17
- return obj.tolist()
18
- elif isinstance(obj, (np.bool_, bool)):
19
- return bool(obj)
20
- elif hasattr(obj, 'isoformat'):
21
- return obj.isoformat()
22
- else:
23
- return str(obj)
24
-
25
-
26
- def analyze_column(series: pd.Series, max_rows: int = 200) -> Dict[str, Any]:
27
- """
28
- Analyze a single column and return chart-ready data
29
-
30
- Args:
31
- series: Pandas Series to analyze
32
- max_rows: Maximum rows to analyze (default: 200)
33
-
34
- Returns:
35
- Dictionary with column analysis including type and visualization data
36
- """
37
- series = series.head(max_rows).dropna()
38
-
39
- if len(series) == 0:
40
- return {
41
- 'name': series.name,
42
- 'type': 'empty',
43
- 'sampleCount': 0,
44
- 'distinctCount': 0,
45
- 'sampleValues': []
46
- }
47
-
48
- # Try numeric analysis
49
- numeric_series = pd.to_numeric(series, errors='coerce')
50
- valid_nums = numeric_series.dropna()
51
- numeric_ratio = len(valid_nums) / len(series) if len(series) > 0 else 0
52
-
53
- if numeric_ratio >= 0.6 and len(valid_nums) > 2:
54
- # Numeric column - compute histogram data
55
- min_val = float(valid_nums.min())
56
- max_val = float(valid_nums.max())
57
- range_val = max_val - min_val if max_val != min_val else 1
58
-
59
- # Create 12 bins
60
- bins = np.zeros(12, dtype=int)
61
- bin_edges = np.linspace(min_val, max_val, 13)
62
-
63
- for val in valid_nums:
64
- bin_idx = int((val - min_val) / range_val * 12)
65
- if bin_idx >= 12:
66
- bin_idx = 11
67
- bins[bin_idx] += 1
68
-
69
- bin_ranges = [[float(bin_edges[i]), float(bin_edges[i+1])] for i in range(12)]
70
-
71
- return {
72
- 'name': series.name,
73
- 'type': 'numeric',
74
- 'sampleCount': int(len(series)),
75
- 'distinctCount': int(series.nunique()),
76
- 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
77
- 'numericStats': {
78
- 'min': min_val,
79
- 'max': max_val,
80
- 'avg': float(valid_nums.mean()),
81
- 'median': float(valid_nums.median()),
82
- 'std': float(valid_nums.std()),
83
- 'bins': bins.tolist(), # For histogram bars
84
- 'binRanges': bin_ranges # For x-axis labels
85
- }
86
- }
87
-
88
- # Check if categorical
89
- unique_ratio = series.nunique() / len(series)
90
- if unique_ratio <= 0.5:
91
- # Categorical column
92
- value_counts = series.value_counts()
93
- top_values = value_counts.head(2) # Top 2 for the cards
94
-
95
- categorical_stats = []
96
- for value, count in top_values.items():
97
- percentage = (count / len(series)) * 100
98
- categorical_stats.append({
99
- 'value': convert_to_serializable(value),
100
- 'count': int(count),
101
- 'percentage': round(float(percentage), 1)
102
- })
103
-
104
- # Calculate "Other"
105
- top_sum = sum(top_values.values)
106
- other_count = len(series) - top_sum
107
- other_distinct = series.nunique() - len(top_values)
108
- other_percentage = (other_count / len(series)) * 100 if len(series) > 0 else 0
109
-
110
- categorical_other = None
111
- if other_count > 0:
112
- categorical_other = {
113
- 'count': int(other_count),
114
- 'distinct': int(other_distinct),
115
- 'percentage': round(float(other_percentage), 1)
116
- }
117
-
118
- return {
119
- 'name': series.name,
120
- 'type': 'categorical',
121
- 'sampleCount': int(len(series)),
122
- 'distinctCount': int(series.nunique()),
123
- 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
124
- 'categoricalStats': categorical_stats, # For bar charts
125
- 'categoricalOther': categorical_other
126
- }
127
-
128
- # Text column
129
- return {
130
- 'name': series.name,
131
- 'type': 'text',
132
- 'sampleCount': int(len(series)),
133
- 'distinctCount': int(series.nunique()),
134
- 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
135
- 'textStats': {
136
- 'avgLength': round(float(series.astype(str).str.len().mean()), 1),
137
- 'maxLength': int(series.astype(str).str.len().max()),
138
- 'minLength': int(series.astype(str).str.len().min())
139
- }
140
- }
141
-
142
-
143
- def create_data_set_graphs(df: pd.DataFrame, max_rows: int = 200) -> str:
144
- """
145
- Analyze a DataFrame and generate JSON for frontend column visualization cards
146
-
147
- Args:
148
- df: Pandas DataFrame to analyze
149
- max_rows: Maximum rows to analyze per column (default: 200)
150
-
151
- Returns:
152
- JSON string containing column summaries with visualization data
153
-
154
- Example:
155
- >>> df = pd.read_csv('data.csv')
156
- >>> json_output = create_data_set_graphs(df)
157
- >>> print(json_output)
158
-
159
- JSON Structure:
160
- {
161
- "columnSummaries": [
162
- {
163
- "name": "column_name",
164
- "type": "numeric|categorical|text|empty",
165
- "sampleCount": 1000,
166
- "distinctCount": 45,
167
- "sampleValues": [...],
168
- "numericStats": {...}, // For numeric columns
169
- "categoricalStats": [...], // For categorical columns
170
- "textStats": {...} // For text columns
171
- }
172
- ]
173
- }
174
- """
175
- try:
176
- # Analyze all columns
177
- column_summaries = []
178
- for col in df.columns:
179
- analysis = analyze_column(df[col], max_rows=max_rows)
180
- column_summaries.append(analysis)
181
-
182
- # Create the JSON output
183
- json_output = {
184
- "columnSummaries": column_summaries
185
- }
186
-
187
- # Convert to JSON string
188
- return json.dumps(json_output, indent=2)
189
-
190
- except Exception as e:
191
- # Return error as JSON
192
- error_output = {
193
- "error": str(e),
194
- "columnSummaries": []
195
- }
196
- return json.dumps(error_output, indent=2)
197
-
198
-
199
- def create_data_set_graphs_dict(df: pd.DataFrame, max_rows: int = 200) -> Dict[str, Any]:
200
- """
201
- Analyze a DataFrame and generate a dictionary for frontend column visualization cards
202
-
203
- Args:
204
- df: Pandas DataFrame to analyze
205
- max_rows: Maximum rows to analyze per column (default: 200)
206
-
207
- Returns:
208
- Dictionary containing column summaries with visualization data
209
- """
210
- try:
211
- # Analyze all columns
212
- column_summaries = []
213
- for col in df.columns:
214
- analysis = analyze_column(df[col], max_rows=max_rows)
215
- column_summaries.append(analysis)
216
-
217
- # Return as dictionary
218
- return {
219
- "columnSummaries": column_summaries
220
- }
221
-
222
- except Exception as e:
223
- # Return error as dict
224
- return {
225
- "error": str(e),
226
- "columnSummaries": []
227
- }
228
-
229
-
230
- # # Example usage
231
- # if __name__ == "__main__":
232
- # # Example 1: Create sample data
233
- # sample_df = pd.DataFrame({
234
- # 'id': range(1, 101),
235
- # 'name': [f'Person {i}' for i in range(1, 101)],
236
- # 'age': np.random.randint(20, 60, 100),
237
- # 'salary': np.random.uniform(30000, 150000, 100),
238
- # 'department': np.random.choice(['Sales', 'Engineering', 'Marketing', 'HR'], 100),
239
- # 'email': [f'person{i}@example.com' for i in range(1, 101)],
240
- # 'join_date': pd.date_range('2020-01-01', periods=100, freq='3D')
241
- # })
242
-
243
- # print("=" * 80)
244
- # print("Example 1: Generate JSON string")
245
- # print("=" * 80)
246
-
247
- # # Generate JSON string
248
- # json_output = create_data_set_graphs(sample_df)
249
- # print(json_output)
250
-
251
- # # Save to file
252
- # with open('column_insights.json', 'w') as f:
253
- # f.write(json_output)
254
- # print("\n✅ JSON saved to 'column_insights.json'")
255
-
256
- # print("\n" + "=" * 80)
257
- # print("Example 2: Generate dictionary")
258
- # print("=" * 80)
259
-
260
- # # Generate dictionary
261
- # dict_output = create_data_set_graphs_dict(sample_df)
262
- # print(f"Found {len(dict_output['columnSummaries'])} columns")
263
-
264
- # # Print first column summary
265
- # if dict_output['columnSummaries']:
266
- # print("\nFirst column summary:")
267
- # print(json.dumps(dict_output['columnSummaries'][0], indent=2))
268
-
269
- # print("\n" + "=" * 80)
270
- # print("Example 3: Load from CSV file")
271
- # print("=" * 80)
272
-
273
- # # If you have a CSV file
274
- # # df = pd.read_csv('your_file.csv')
275
- # # json_output = create_data_set_graphs(df)
276
- # # print(json_output)
277
-
278
- # print("""
279
- # Usage in your code:
280
-
281
- # # From CSV
282
- # df = pd.read_csv('data.csv')
283
- # json_result = create_data_set_graphs(df)
284
-
285
- # # From Excel
286
- # df = pd.read_excel('data.xlsx')
287
- # json_result = create_data_set_graphs(df)
288
-
289
- # # As dictionary (for further processing)
290
- # dict_result = create_data_set_graphs_dict(df)
291
-
292
- # # Access column summaries
293
- # for column in dict_result['columnSummaries']:
294
- # print(f"Column: {column['name']}, Type: {column['type']}")
295
  # """)
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import json
4
+ from typing import Dict, List, Any, Union
5
+
6
+ def convert_to_serializable(obj):
7
+ """Convert non-serializable objects to JSON-compatible types"""
8
+ if pd.isna(obj):
9
+ return None
10
+ elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
11
+ return obj.isoformat()
12
+ elif isinstance(obj, (np.integer, np.int64, np.int32, np.int16, np.int8)):
13
+ return int(obj)
14
+ elif isinstance(obj, (np.floating, np.float64, np.float32, np.float16)):
15
+ return float(obj)
16
+ elif isinstance(obj, np.ndarray):
17
+ return obj.tolist()
18
+ elif isinstance(obj, (np.bool_, bool)):
19
+ return bool(obj)
20
+ elif hasattr(obj, 'isoformat'):
21
+ return obj.isoformat()
22
+ else:
23
+ return str(obj)
24
+
25
+
26
+ def analyze_column(series: pd.Series, max_rows: int = 200) -> Dict[str, Any]:
27
+ """
28
+ Analyze a single column and return chart-ready data
29
+
30
+ Args:
31
+ series: Pandas Series to analyze
32
+ max_rows: Maximum rows to analyze (default: 200)
33
+
34
+ Returns:
35
+ Dictionary with column analysis including type and visualization data
36
+ """
37
+ series = series.head(max_rows).dropna()
38
+
39
+ if len(series) == 0:
40
+ return {
41
+ 'name': series.name,
42
+ 'type': 'empty',
43
+ 'sampleCount': 0,
44
+ 'distinctCount': 0,
45
+ 'sampleValues': []
46
+ }
47
+
48
+ # Try numeric analysis
49
+ numeric_series = pd.to_numeric(series, errors='coerce')
50
+ valid_nums = numeric_series.dropna()
51
+ numeric_ratio = len(valid_nums) / len(series) if len(series) > 0 else 0
52
+
53
+ if numeric_ratio >= 0.6 and len(valid_nums) > 2:
54
+ # Numeric column - compute histogram data
55
+ min_val = float(valid_nums.min())
56
+ max_val = float(valid_nums.max())
57
+ range_val = max_val - min_val if max_val != min_val else 1
58
+
59
+ # Create 12 bins
60
+ bins = np.zeros(12, dtype=int)
61
+ bin_edges = np.linspace(min_val, max_val, 13)
62
+
63
+ for val in valid_nums:
64
+ bin_idx = int((val - min_val) / range_val * 12)
65
+ if bin_idx >= 12:
66
+ bin_idx = 11
67
+ bins[bin_idx] += 1
68
+
69
+ bin_ranges = [[float(bin_edges[i]), float(bin_edges[i+1])] for i in range(12)]
70
+
71
+ return {
72
+ 'name': series.name,
73
+ 'type': 'numeric',
74
+ 'sampleCount': int(len(series)),
75
+ 'distinctCount': int(series.nunique()),
76
+ 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
77
+ 'numericStats': {
78
+ 'min': min_val,
79
+ 'max': max_val,
80
+ 'avg': float(valid_nums.mean()),
81
+ 'median': float(valid_nums.median()),
82
+ 'std': float(valid_nums.std()),
83
+ 'bins': bins.tolist(), # For histogram bars
84
+ 'binRanges': bin_ranges # For x-axis labels
85
+ }
86
+ }
87
+
88
+ # Check if categorical
89
+ unique_ratio = series.nunique() / len(series)
90
+ if unique_ratio <= 0.5:
91
+ # Categorical column
92
+ value_counts = series.value_counts()
93
+ top_values = value_counts.head(2) # Top 2 for the cards
94
+
95
+ categorical_stats = []
96
+ for value, count in top_values.items():
97
+ percentage = (count / len(series)) * 100
98
+ categorical_stats.append({
99
+ 'value': convert_to_serializable(value),
100
+ 'count': int(count),
101
+ 'percentage': round(float(percentage), 1)
102
+ })
103
+
104
+ # Calculate "Other"
105
+ top_sum = sum(top_values.values)
106
+ other_count = len(series) - top_sum
107
+ other_distinct = series.nunique() - len(top_values)
108
+ other_percentage = (other_count / len(series)) * 100 if len(series) > 0 else 0
109
+
110
+ categorical_other = None
111
+ if other_count > 0:
112
+ categorical_other = {
113
+ 'count': int(other_count),
114
+ 'distinct': int(other_distinct),
115
+ 'percentage': round(float(other_percentage), 1)
116
+ }
117
+
118
+ return {
119
+ 'name': series.name,
120
+ 'type': 'categorical',
121
+ 'sampleCount': int(len(series)),
122
+ 'distinctCount': int(series.nunique()),
123
+ 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
124
+ 'categoricalStats': categorical_stats, # For bar charts
125
+ 'categoricalOther': categorical_other
126
+ }
127
+
128
+ # Text column
129
+ return {
130
+ 'name': series.name,
131
+ 'type': 'text',
132
+ 'sampleCount': int(len(series)),
133
+ 'distinctCount': int(series.nunique()),
134
+ 'sampleValues': [convert_to_serializable(v) for v in series.head(5).tolist()],
135
+ 'textStats': {
136
+ 'avgLength': round(float(series.astype(str).str.len().mean()), 1),
137
+ 'maxLength': int(series.astype(str).str.len().max()),
138
+ 'minLength': int(series.astype(str).str.len().min())
139
+ }
140
+ }
141
+
142
+
143
+ def create_data_set_graphs(df: pd.DataFrame, max_rows: int = 200) -> str:
144
+ """
145
+ Analyze a DataFrame and generate JSON for frontend column visualization cards
146
+
147
+ Args:
148
+ df: Pandas DataFrame to analyze
149
+ max_rows: Maximum rows to analyze per column (default: 200)
150
+
151
+ Returns:
152
+ JSON string containing column summaries with visualization data
153
+
154
+ Example:
155
+ >>> df = pd.read_csv('data.csv')
156
+ >>> json_output = create_data_set_graphs(df)
157
+ >>> print(json_output)
158
+
159
+ JSON Structure:
160
+ {
161
+ "columnSummaries": [
162
+ {
163
+ "name": "column_name",
164
+ "type": "numeric|categorical|text|empty",
165
+ "sampleCount": 1000,
166
+ "distinctCount": 45,
167
+ "sampleValues": [...],
168
+ "numericStats": {...}, // For numeric columns
169
+ "categoricalStats": [...], // For categorical columns
170
+ "textStats": {...} // For text columns
171
+ }
172
+ ]
173
+ }
174
+ """
175
+ try:
176
+ # Analyze all columns
177
+ column_summaries = []
178
+ for col in df.columns:
179
+ analysis = analyze_column(df[col], max_rows=max_rows)
180
+ column_summaries.append(analysis)
181
+
182
+ # Create the JSON output
183
+ json_output = {
184
+ "columnSummaries": column_summaries
185
+ }
186
+
187
+ # Convert to JSON string
188
+ return json.dumps(json_output, indent=2)
189
+
190
+ except Exception as e:
191
+ # Return error as JSON
192
+ error_output = {
193
+ "error": str(e),
194
+ "columnSummaries": []
195
+ }
196
+ return json.dumps(error_output, indent=2)
197
+
198
+
199
+ def create_data_set_graphs_dict(df: pd.DataFrame, max_rows: int = 200) -> Dict[str, Any]:
200
+ """
201
+ Analyze a DataFrame and generate a dictionary for frontend column visualization cards
202
+
203
+ Args:
204
+ df: Pandas DataFrame to analyze
205
+ max_rows: Maximum rows to analyze per column (default: 200)
206
+
207
+ Returns:
208
+ Dictionary containing column summaries with visualization data
209
+ """
210
+ try:
211
+ # Analyze all columns
212
+ column_summaries = []
213
+ for col in df.columns:
214
+ analysis = analyze_column(df[col], max_rows=max_rows)
215
+ column_summaries.append(analysis)
216
+
217
+ # Return as dictionary
218
+ return {
219
+ "columnSummaries": column_summaries
220
+ }
221
+
222
+ except Exception as e:
223
+ # Return error as dict
224
+ return {
225
+ "error": str(e),
226
+ "columnSummaries": []
227
+ }
228
+
229
+
230
+ # # Example usage
231
+ # if __name__ == "__main__":
232
+ # # Example 1: Create sample data
233
+ # sample_df = pd.DataFrame({
234
+ # 'id': range(1, 101),
235
+ # 'name': [f'Person {i}' for i in range(1, 101)],
236
+ # 'age': np.random.randint(20, 60, 100),
237
+ # 'salary': np.random.uniform(30000, 150000, 100),
238
+ # 'department': np.random.choice(['Sales', 'Engineering', 'Marketing', 'HR'], 100),
239
+ # 'email': [f'person{i}@example.com' for i in range(1, 101)],
240
+ # 'join_date': pd.date_range('2020-01-01', periods=100, freq='3D')
241
+ # })
242
+
243
+ # print("=" * 80)
244
+ # print("Example 1: Generate JSON string")
245
+ # print("=" * 80)
246
+
247
+ # # Generate JSON string
248
+ # json_output = create_data_set_graphs(sample_df)
249
+ # print(json_output)
250
+
251
+ # # Save to file
252
+ # with open('column_insights.json', 'w') as f:
253
+ # f.write(json_output)
254
+ # print("\n✅ JSON saved to 'column_insights.json'")
255
+
256
+ # print("\n" + "=" * 80)
257
+ # print("Example 2: Generate dictionary")
258
+ # print("=" * 80)
259
+
260
+ # # Generate dictionary
261
+ # dict_output = create_data_set_graphs_dict(sample_df)
262
+ # print(f"Found {len(dict_output['columnSummaries'])} columns")
263
+
264
+ # # Print first column summary
265
+ # if dict_output['columnSummaries']:
266
+ # print("\nFirst column summary:")
267
+ # print(json.dumps(dict_output['columnSummaries'][0], indent=2))
268
+
269
+ # print("\n" + "=" * 80)
270
+ # print("Example 3: Load from CSV file")
271
+ # print("=" * 80)
272
+
273
+ # # If you have a CSV file
274
+ # # df = pd.read_csv('your_file.csv')
275
+ # # json_output = create_data_set_graphs(df)
276
+ # # print(json_output)
277
+
278
+ # print("""
279
+ # Usage in your code:
280
+
281
+ # # From CSV
282
+ # df = pd.read_csv('data.csv')
283
+ # json_result = create_data_set_graphs(df)
284
+
285
+ # # From Excel
286
+ # df = pd.read_excel('data.xlsx')
287
+ # json_result = create_data_set_graphs(df)
288
+
289
+ # # As dictionary (for further processing)
290
+ # dict_result = create_data_set_graphs_dict(df)
291
+
292
+ # # Access column summaries
293
+ # for column in dict_result['columnSummaries']:
294
+ # print(f"Column: {column['name']}, Type: {column['type']}")
295
  # """)
s3/r4.py CHANGED
@@ -1,1057 +1,853 @@
1
- # from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
2
- # from fastapi.responses import JSONResponse
3
- # import pandas as pd
4
- # from autoviz.AutoViz_Class import AutoViz_Class
5
- # import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
6
- # matplotlib.use('Agg')
7
- # import matplotlib.pyplot as plt
8
- # import sys
9
- # from pathlib import Path
10
- # from typing import List
11
- # import httpx
12
-
13
- # # --- Project Root Setup ---
14
- # PROJECT_ROOT = Path(__file__).resolve().parents[1]
15
- # if str(PROJECT_ROOT) not in sys.path:
16
- # sys.path.insert(0, str(PROJECT_ROOT))
17
-
18
- # from retrieve_secret import *
19
- # from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
20
- # from s3.create_dataset_graphs import create_data_set_graphs_dict
21
-
22
- # # --- File Validation ---
23
- # MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
24
- # MAX_ROWS_ALLOWED = 1_000_000
25
- # ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
26
-
27
- # # --- AWS S3 Config ---
28
- # print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
29
- # ACCESS_KEY = AWS_S3_CREDS_KEY_ID
30
- # SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
31
- # BUCKET_NAME = BUCKET_NAME
32
- # REGION_NAME = "us-east-1"
33
-
34
- # s3 = boto3.client(
35
- # "s3",
36
- # aws_access_key_id=ACCESS_KEY,
37
- # aws_secret_access_key=SECRET_KEY,
38
- # region_name=REGION_NAME
39
- # )
40
-
41
- # ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
42
-
43
- # # --- FastAPI Router ---
44
- # s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
45
-
46
-
47
- # # --- Helper: S3 Key ---
48
- # def make_key(path: str, filename: str) -> str:
49
- # return f"{path.strip('/')}/{filename}" if path else filename
50
-
51
-
52
- # # --- Sanitize for JSON ---
53
- # def sanitize_for_json(obj):
54
- # if isinstance(obj, dict):
55
- # return {k: sanitize_for_json(v) for k, v in obj.items()}
56
- # elif isinstance(obj, list):
57
- # return [sanitize_for_json(item) for item in obj]
58
- # elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
59
- # return str(obj)
60
- # elif pd.isna(obj):
61
- # return None
62
- # elif isinstance(obj, (int, float, str, bool, type(None))):
63
- # return obj
64
- # else:
65
- # return str(obj)
66
-
67
-
68
- # # --- Vector DB Placeholders ---
69
- # def check_vdb(user_id: str):
70
- # print(f"Checking VDB for user: {user_id}")
71
-
72
- # async def add_metadata_only(collection_name: str, metadata: dict):
73
- # print(f"Adding metadata to collection: {collection_name}")
74
- # return {"status": "success", "collection": collection_name}
75
-
76
-
77
- # # --- AutoViz HTML ---
78
- # def run_autoviz_html(
79
- # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
80
- # max_rows_analyzed=150000, max_cols_analyzed=30
81
- # ):
82
- # tmp_dir = tempfile.mkdtemp()
83
- # AV = AutoViz_Class()
84
- # AV.AutoViz(
85
- # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
86
- # verbose=verbose, lowess=lowess, chart_format="html",
87
- # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
88
- # save_plot_dir=tmp_dir
89
- # )
90
- # print(f"HTML plots saved to: {tmp_dir}")
91
- # return tmp_dir
92
-
93
-
94
- # # --- AutoViz Bokeh ---
95
- # def run_autoviz_bokeh(
96
- # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
97
- # max_rows_analyzed=150000, max_cols_analyzed=30
98
- # ):
99
- # tmp_dir = tempfile.mkdtemp()
100
- # print(f"Bokeh temp directory: {tmp_dir}")
101
- # try:
102
- # matplotlib.use('Agg')
103
- # plt.ioff()
104
- # AV = AutoViz_Class()
105
- # AV.AutoViz(
106
- # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
107
- # verbose=verbose, lowess=lowess, chart_format="bokeh",
108
- # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
109
- # save_plot_dir=tmp_dir
110
- # )
111
- # plt.close('all')
112
- # print(f"Bokeh plots saved to: {tmp_dir}")
113
- # return tmp_dir
114
- # except Exception as e:
115
- # print(f"Error in Bokeh: {e}")
116
- # import traceback
117
- # traceback.print_exc()
118
- # raise
119
-
120
-
121
- # # --- Fallback Matplotlib ---
122
- # def generate_matplotlib_plots(df, user_id, base_filename):
123
- # print("\nGenerating matplotlib fallback plots...")
124
- # tmp_dir = tempfile.mkdtemp()
125
- # plot_metadata = []
126
- # try:
127
- # matplotlib.use('Agg')
128
- # numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
129
- # if len(numeric_cols) == 0:
130
- # print("No numeric columns for plotting")
131
- # return []
132
-
133
- # # Correlation heatmap
134
- # if len(numeric_cols) > 1:
135
- # plt.figure(figsize=(10, 8))
136
- # import seaborn as sns
137
- # corr = df[numeric_cols].corr()
138
- # sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
139
- # plt.title('Correlation Heatmap')
140
- # plt.tight_layout()
141
- # png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
142
- # plt.savefig(png_path, dpi=100, bbox_inches='tight')
143
- # plt.close()
144
-
145
- # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
146
- # with open(png_path, 'rb') as f:
147
- # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
148
- # plot_metadata.append({
149
- # "file_name": "correlation_heatmap.png",
150
- # "s3_path": s3_key,
151
- # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
152
- # "type": "png",
153
- # "plot_type": "heatmap"
154
- # })
155
-
156
- # # Histograms
157
- # for col in numeric_cols[:5]:
158
- # plt.figure(figsize=(10, 6))
159
- # df[col].hist(bins=30, edgecolor='black')
160
- # plt.title(f'Distribution of {col}')
161
- # plt.xlabel(col)
162
- # plt.ylabel('Frequency')
163
- # plt.tight_layout()
164
- # png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
165
- # plt.savefig(png_path, dpi=100, bbox_inches='tight')
166
- # plt.close()
167
-
168
- # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
169
- # with open(png_path, 'rb') as f:
170
- # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
171
- # plot_metadata.append({
172
- # "file_name": f"distribution_{col}.png",
173
- # "s3_path": s3_key,
174
- # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
175
- # "type": "png",
176
- # "plot_type": "histogram"
177
- # })
178
- # return plot_metadata
179
- # except Exception as e:
180
- # print(f"Matplotlib failed: {e}")
181
- # return []
182
- # finally:
183
- # shutil.rmtree(tmp_dir, ignore_errors=True)
184
-
185
-
186
- # # --- Upload Viz Files ---
187
- # def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
188
- # patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
189
- # files = []
190
- # for p in patterns:
191
- # files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
192
- # files = list(set(files))
193
- # if not files:
194
- # print(f"No {viz_type} files found")
195
- # return []
196
-
197
- # folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
198
- # metadata = []
199
- # content_type_map = {
200
- # 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
201
- # 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
202
- # }
203
-
204
- # for file_path in files:
205
- # name = os.path.basename(file_path)
206
- # ext = os.path.splitext(name)[1][1:]
207
- # s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
208
- # with open(file_path, "rb") as f:
209
- # body = f.read()
210
- # s3.put_object(
211
- # Bucket=BUCKET_NAME, Key=s3_key, Body=body,
212
- # ContentType=content_type_map.get(ext, 'application/octet-stream')
213
- # )
214
- # metadata.append({
215
- # "file_name": name, "s3_path": s3_key,
216
- # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
217
- # "type": viz_type, "size": len(body)
218
- # })
219
- # print(f"{viz_type.upper()} uploaded: {s3_key}")
220
- # return metadata
221
-
222
-
223
- # # --- Convert to Parquet ---
224
- # def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
225
- # buffer = io.BytesIO()
226
- # df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
227
- # buffer.seek(0)
228
- # return buffer
229
-
230
-
231
- # # --- Check File Hash Exists (FIXED) ---
232
- # async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
233
- # """
234
- # Check if a file hash already exists for a user.
235
- # Returns a dict with 'success', 'exists', and optional 'data' keys.
236
- # """
237
- # url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
238
- # headers = {"accept": "application/json"}
239
- # async with httpx.AsyncClient(timeout=10.0) as client:
240
- # try:
241
- # r = await client.get(url, headers=headers)
242
- # r.raise_for_status()
243
- # result = r.json()
244
-
245
- # # Log the actual response for debugging
246
- # print(f"Hash check API response: {result}")
247
-
248
- # # Check for 'exists' field first, then parse message if not present
249
- # exists = result.get("exists", None)
250
-
251
- # # If 'exists' field is not present, parse the message
252
- # if exists is None and "message" in result:
253
- # message_lower = result["message"].lower()
254
- # # Check if message indicates file already exists
255
- # exists = "already existed" in message_lower or "already exists" in message_lower or "duplicate" in message_lower
256
-
257
- # # Default to False if still None
258
- # if exists is None:
259
- # exists = False
260
-
261
- # return {
262
- # "success": True,
263
- # "exists": exists,
264
- # "data": result
265
- # }
266
- # except httpx.HTTPStatusError as e:
267
- # print(f"Hash check HTTP error: {e.response.status_code} - {e.response.text}")
268
- # return {
269
- # "success": False,
270
- # "exists": False, # Assume not exists on error to be safe
271
- # "message": f"HTTP {e.response.status_code}",
272
- # "error_detail": e.response.text
273
- # }
274
- # except Exception as e:
275
- # print(f"Hash check exception: {str(e)}")
276
- # return {
277
- # "success": False,
278
- # "exists": False, # Assume not exists on error
279
- # "message": f"Request failed: {str(e)}"
280
- # }
281
-
282
-
283
- # # --- PostgreSQL Metadata Upload with file_hash ---
284
- # async def user_metadata_upload_pg(
285
- # user_id: str,
286
- # user_metadata: str,
287
- # path: str,
288
- # url: str,
289
- # filename: str,
290
- # file_type: str,
291
- # file_size_bytes: int,
292
- # file_hash: str,
293
- # timeout: float = 10.0
294
- # ):
295
- # payload = {
296
- # "user_id": user_id,
297
- # "user_metadata": user_metadata,
298
- # "path": path,
299
- # "url": url,
300
- # "filename": filename,
301
- # "file_type": file_type,
302
- # "file_size_bytes": file_size_bytes,
303
- # "file_hash": file_hash
304
- # }
305
- # print(f"PostgreSQL payload file_hash: {payload['file_hash']}")
306
- # async with httpx.AsyncClient() as client:
307
- # try:
308
- # r = await client.post(
309
- # "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
310
- # json=payload,
311
- # timeout=timeout
312
- # )
313
- # r.raise_for_status()
314
- # result = r.json()
315
- # result["file_hash"] = file_hash
316
- # print(f"PostgreSQL result file_hash: {result['file_hash']}")
317
- # return {"success": True, "data": result}
318
- # except httpx.HTTPStatusError as e:
319
- # return {
320
- # "success": False,
321
- # "error": "HTTP error",
322
- # "status_code": e.response.status_code,
323
- # "detail": e.response.text,
324
- # "file_hash": file_hash
325
- # }
326
- # except Exception as e:
327
- # return {
328
- # "success": False,
329
- # "error": "Request failed",
330
- # "detail": str(e),
331
- # "file_hash": file_hash
332
- # }
333
-
334
-
335
- # # --- DEBUG ENDPOINT (Optional - for troubleshooting) ---
336
- # @s3_bucket_router1.get("/debug/check_hash/{user_id}/{file_hash}")
337
- # async def debug_check_hash(user_id: str, file_hash: str):
338
- # """Debug endpoint to test hash checking"""
339
- # result = await check_file_hash_exists(user_id, file_hash)
340
- # return {
341
- # "raw_result": result,
342
- # "exists_value": result.get("exists"),
343
- # "exists_type": type(result.get("exists")).__name__,
344
- # "success_value": result.get("success"),
345
- # "interpretation": "File exists" if result.get("exists") is True else "File does not exist or check failed"
346
- # }
347
-
348
-
349
- # # --- MAIN ENDPOINT WITH FIXED HASH CHECK AND DATASET GRAPHS ---
350
- # @s3_bucket_router1.post("/upload_datasets_v3/")
351
- # async def upload_file(
352
- # file: UploadFile = File(...),
353
- # user_id: str = Query(..., description="User ID"),
354
- # path: str = Query("", description="Optional subpath")
355
- # ):
356
- # html_tmp_dir = None
357
- # bokeh_tmp_dir = None
358
- # file_content = None
359
-
360
- # try:
361
- # # 1. Validate extension
362
- # file_ext = os.path.splitext(file.filename)[1].lower()
363
- # if file_ext not in ALLOWED_EXTENSIONS:
364
- # raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
365
-
366
- # # 2. Read file
367
- # file_content = await file.read()
368
- # if not file_content:
369
- # raise HTTPException(status_code=400, detail="Empty file")
370
- # if len(file_content) > MAX_FILE_SIZE_BYTES:
371
- # raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
372
-
373
- # # 3. Generate hash
374
- # file_hash = hashlib.sha256(file_content).hexdigest()
375
- # print(f"Generated file hash: {file_hash}")
376
-
377
- # # 4. Check hash via API (FIXED LOGIC)
378
- # hash_result = await check_file_hash_exists(user_id, file_hash)
379
-
380
- # # Enhanced logging
381
- # print(f"Hash check result: {hash_result}")
382
-
383
- # # Check if the API call was successful first
384
- # if not hash_result.get("success", False):
385
- # print(f"⚠️ Warning: Hash check API failed: {hash_result.get('message')}")
386
- # # You can decide to fail here if hash check is critical
387
- # # raise HTTPException(status_code=503, detail="Hash check service unavailable")
388
-
389
- # # Check if file exists
390
- # if hash_result.get("exists") is True:
391
- # print(f"🚫 Duplicate file detected: {file_hash}")
392
- # return JSONResponse(
393
- # status_code=409,
394
- # content={
395
- # "message": "File already uploaded.",
396
- # "reason": "Duplicate file detected via SHA-256 hash.",
397
- # "file_hash": file_hash,
398
- # "user_id": user_id,
399
- # "filename": file.filename,
400
- # "action": "skipped",
401
- # "existing_file_info": hash_result.get("data")
402
- # }
403
- # )
404
-
405
- # print("✅ Hash check passed. New file - proceeding with upload.")
406
-
407
- # # 5. Load DataFrame
408
- # try:
409
- # if file_ext == ".csv":
410
- # df = pd.read_csv(io.BytesIO(file_content))
411
- # elif file_ext in {".xlsx", ".xls"}:
412
- # engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
413
- # df = pd.read_excel(io.BytesIO(file_content), engine=engine)
414
- # elif file_ext == ".ods":
415
- # df = pd.read_excel(io.BytesIO(file_content), engine='odf')
416
- # except Exception as e:
417
- # raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
418
-
419
- # if len(df) > MAX_ROWS_ALLOWED:
420
- # raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
421
-
422
- # # 6. Convert to Parquet
423
- # parquet_buffer = convert_df_to_parquet(df)
424
- # parquet_size = parquet_buffer.getbuffer().nbytes
425
-
426
- # # 7. Upload Parquet to S3
427
- # base_filename = os.path.splitext(file.filename)[0]
428
- # parquet_filename = f"{base_filename}.parquet"
429
- # file_key = f"{user_id}/files/datasets/{parquet_filename}"
430
- # file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
431
-
432
- # s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
433
- # ExtraArgs={'ContentType': 'application/octet-stream'})
434
- # print(f"Uploaded Parquet: {file_url}")
435
-
436
- # # 8. Generate metadata
437
- # metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
438
- # metadata.update({
439
- # "user_id": user_id,
440
- # "s3_path": file_key,
441
- # "s3_url": file_url,
442
- # "source_file": file.filename,
443
- # "source_file_type": file_ext[1:],
444
- # "file_type": "parquet",
445
- # "original_file_size_bytes": len(file_content),
446
- # "parquet_file_size_bytes": parquet_size,
447
- # "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
448
- # "file_hash": file_hash
449
- # })
450
-
451
- # # 8.5 Generate dataset preview graphs
452
- # print("Generating dataset preview graphs...")
453
- # try:
454
- # dataset_graphs = create_data_set_graphs_dict(df, max_rows=200)
455
- # metadata["data_sets_preview_graph"] = dataset_graphs
456
- # print(f"✅ Generated graphs for {len(dataset_graphs.get('columnSummaries', []))} columns")
457
- # except Exception as e:
458
- # print(f"⚠️ Failed to generate dataset graphs: {e}")
459
- # import traceback
460
- # traceback.print_exc()
461
- # metadata["data_sets_preview_graph"] = {
462
- # "error": str(e),
463
- # "columnSummaries": []
464
- # }
465
-
466
- # safe_metadata = sanitize_for_json(metadata)
467
-
468
- # # 9. Vector DB
469
- # check_vdb(user_id)
470
- # vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
471
- # vdb_success = vdb_res.get("status") == "success"
472
- # print(f"VDB upload success: {vdb_success}")
473
-
474
- # # 10. PostgreSQL Metadata + file_hash
475
- # pg_result = await user_metadata_upload_pg(
476
- # user_id=user_id,
477
- # user_metadata=json.dumps(safe_metadata),
478
- # path=file_key,
479
- # url=file_url,
480
- # filename=parquet_filename,
481
- # file_type="parquet",
482
- # file_size_bytes=parquet_size,
483
- # file_hash=file_hash
484
- # )
485
- # print(f"PostgreSQL upload result: {pg_result}")
486
- # pg_success = pg_result.get("success", False)
487
-
488
- # graphs_count = len(safe_metadata.get("data_sets_preview_graph", {}).get("columnSummaries", []))
489
- # print(f"Graphs generated: {graphs_count}")
490
-
491
- # # 11. Return success
492
- # return {
493
- # "message": "Upload successful.",
494
- # "filename": parquet_filename,
495
- # "original_filename": file.filename,
496
- # "user_id": user_id,
497
- # "file_path": file_key,
498
- # "file_url": file_url,
499
- # "file_hash": file_hash,
500
- # "source_file_type": file_ext[1:],
501
- # "file_type": "parquet",
502
- # "original_file_size_bytes": len(file_content),
503
- # "parquet_file_size_bytes": parquet_size,
504
- # "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
505
- # "rows": len(df),
506
- # "columns": len(df.columns),
507
- # "metadata": safe_metadata,
508
- # "upload_dataset_vdb": vdb_success,
509
- # "upload_dataset_pg": pg_success,
510
- # "pg_details": pg_result,
511
- # "graphs_generated": graphs_count
512
- # }
513
-
514
- # except HTTPException:
515
- # raise
516
- # except Exception as e:
517
- # print(f"Unexpected error: {e}")
518
- # import traceback
519
- # traceback.print_exc()
520
- # raise HTTPException(status_code=500, detail=str(e))
521
- # finally:
522
- # # Clean up temp directories
523
- # for d in (html_tmp_dir, bokeh_tmp_dir):
524
- # if d and os.path.exists(d):
525
- # shutil.rmtree(d, ignore_errors=True)
526
-
527
-
528
-
529
-
530
-
531
- from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
532
- from fastapi.responses import JSONResponse
533
- import pandas as pd
534
- from autoviz.AutoViz_Class import AutoViz_Class
535
- import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
536
- matplotlib.use('Agg')
537
- import matplotlib.pyplot as plt
538
- import sys
539
- from pathlib import Path
540
- from typing import List
541
- import httpx
542
-
543
- # --- Project Root Setup ---
544
- PROJECT_ROOT = Path(__file__).resolve().parents[1]
545
- if str(PROJECT_ROOT) not in sys.path:
546
- sys.path.insert(0, str(PROJECT_ROOT))
547
-
548
- from retrieve_secret import *
549
- from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
550
- from s3.create_dataset_graphs import create_data_set_graphs_dict
551
-
552
- # --- File Validation ---
553
- MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
554
- MAX_ROWS_ALLOWED = 1_000_000
555
- ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
556
-
557
- # --- AWS S3 Config ---
558
- print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
559
- ACCESS_KEY = AWS_S3_CREDS_KEY_ID
560
- SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
561
- BUCKET_NAME = BUCKET_NAME
562
- REGION_NAME = "us-east-1"
563
-
564
- s3 = boto3.client(
565
- "s3",
566
- aws_access_key_id=ACCESS_KEY,
567
- aws_secret_access_key=SECRET_KEY,
568
- region_name=REGION_NAME
569
- )
570
-
571
- ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
572
-
573
- # --- FastAPI Router ---
574
- s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
575
-
576
-
577
- # --- Helper: S3 Key ---
578
- def make_key(path: str, filename: str) -> str:
579
- return f"{path.strip('/')}/{filename}" if path else filename
580
-
581
-
582
- # --- Sanitize for JSON ---
583
- def sanitize_for_json(obj):
584
- if isinstance(obj, dict):
585
- return {k: sanitize_for_json(v) for k, v in obj.items()}
586
- elif isinstance(obj, list):
587
- return [sanitize_for_json(item) for item in obj]
588
- elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
589
- return str(obj)
590
- elif pd.isna(obj):
591
- return None
592
- elif isinstance(obj, (int, float, str, bool, type(None))):
593
- return obj
594
- else:
595
- return str(obj)
596
-
597
-
598
- # --- Vector DB Placeholders ---
599
- def check_vdb(user_id: str):
600
- print(f"Checking VDB for user: {user_id}")
601
-
602
- async def add_metadata_only(collection_name: str, metadata: dict):
603
- print(f"Adding metadata to collection: {collection_name}")
604
- return {"status": "success", "collection": collection_name}
605
-
606
-
607
- # --- AutoViz HTML ---
608
- def run_autoviz_html(
609
- dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
610
- max_rows_analyzed=150000, max_cols_analyzed=30
611
- ):
612
- tmp_dir = tempfile.mkdtemp()
613
- AV = AutoViz_Class()
614
- AV.AutoViz(
615
- filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
616
- verbose=verbose, lowess=lowess, chart_format="html",
617
- max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
618
- save_plot_dir=tmp_dir
619
- )
620
- print(f"HTML plots saved to: {tmp_dir}")
621
- return tmp_dir
622
-
623
-
624
- # --- AutoViz Bokeh ---
625
- def run_autoviz_bokeh(
626
- dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
627
- max_rows_analyzed=150000, max_cols_analyzed=30
628
- ):
629
- tmp_dir = tempfile.mkdtemp()
630
- print(f"Bokeh temp directory: {tmp_dir}")
631
- try:
632
- matplotlib.use('Agg')
633
- plt.ioff()
634
- AV = AutoViz_Class()
635
- AV.AutoViz(
636
- filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
637
- verbose=verbose, lowess=lowess, chart_format="bokeh",
638
- max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
639
- save_plot_dir=tmp_dir
640
- )
641
- plt.close('all')
642
- print(f"Bokeh plots saved to: {tmp_dir}")
643
- return tmp_dir
644
- except Exception as e:
645
- print(f"Error in Bokeh: {e}")
646
- import traceback
647
- traceback.print_exc()
648
- raise
649
-
650
-
651
- # --- Fallback Matplotlib ---
652
- def generate_matplotlib_plots(df, user_id, base_filename):
653
- print("\nGenerating matplotlib fallback plots...")
654
- tmp_dir = tempfile.mkdtemp()
655
- plot_metadata = []
656
- try:
657
- matplotlib.use('Agg')
658
- numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
659
- if len(numeric_cols) == 0:
660
- print("No numeric columns for plotting")
661
- return []
662
-
663
- # Correlation heatmap
664
- if len(numeric_cols) > 1:
665
- plt.figure(figsize=(10, 8))
666
- import seaborn as sns
667
- corr = df[numeric_cols].corr()
668
- sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
669
- plt.title('Correlation Heatmap')
670
- plt.tight_layout()
671
- png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
672
- plt.savefig(png_path, dpi=100, bbox_inches='tight')
673
- plt.close()
674
-
675
- s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
676
- with open(png_path, 'rb') as f:
677
- s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
678
- plot_metadata.append({
679
- "file_name": "correlation_heatmap.png",
680
- "s3_path": s3_key,
681
- "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
682
- "type": "png",
683
- "plot_type": "heatmap"
684
- })
685
-
686
- # Histograms
687
- for col in numeric_cols[:5]:
688
- plt.figure(figsize=(10, 6))
689
- df[col].hist(bins=30, edgecolor='black')
690
- plt.title(f'Distribution of {col}')
691
- plt.xlabel(col)
692
- plt.ylabel('Frequency')
693
- plt.tight_layout()
694
- png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
695
- plt.savefig(png_path, dpi=100, bbox_inches='tight')
696
- plt.close()
697
-
698
- s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
699
- with open(png_path, 'rb') as f:
700
- s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
701
- plot_metadata.append({
702
- "file_name": f"distribution_{col}.png",
703
- "s3_path": s3_key,
704
- "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
705
- "type": "png",
706
- "plot_type": "histogram"
707
- })
708
- return plot_metadata
709
- except Exception as e:
710
- print(f"Matplotlib failed: {e}")
711
- return []
712
- finally:
713
- shutil.rmtree(tmp_dir, ignore_errors=True)
714
-
715
-
716
- # --- Upload Viz Files ---
717
- def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
718
- patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
719
- files = []
720
- for p in patterns:
721
- files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
722
- files = list(set(files))
723
- if not files:
724
- print(f"No {viz_type} files found")
725
- return []
726
-
727
- folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
728
- metadata = []
729
- content_type_map = {
730
- 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
731
- 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
732
- }
733
-
734
- for file_path in files:
735
- name = os.path.basename(file_path)
736
- ext = os.path.splitext(name)[1][1:]
737
- s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
738
- with open(file_path, "rb") as f:
739
- body = f.read()
740
- s3.put_object(
741
- Bucket=BUCKET_NAME, Key=s3_key, Body=body,
742
- ContentType=content_type_map.get(ext, 'application/octet-stream')
743
- )
744
- metadata.append({
745
- "file_name": name, "s3_path": s3_key,
746
- "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
747
- "type": viz_type, "size": len(body)
748
- })
749
- print(f"{viz_type.upper()} uploaded: {s3_key}")
750
- return metadata
751
-
752
-
753
- # --- Convert to Parquet ---
754
- def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
755
- buffer = io.BytesIO()
756
- df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
757
- buffer.seek(0)
758
- return buffer
759
-
760
-
761
- # --- Check File Hash Exists (FIXED) ---
762
- async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
763
- """
764
- Check if a file hash already exists for a user.
765
- Returns a dict with 'success', 'exists', and optional 'data' keys.
766
- """
767
- url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
768
- headers = {"accept": "application/json"}
769
- async with httpx.AsyncClient(timeout=10.0) as client:
770
- try:
771
- r = await client.get(url, headers=headers)
772
- r.raise_for_status()
773
- result = r.json()
774
-
775
- # Log the actual response for debugging
776
- print(f"Hash check API response: {result}")
777
-
778
- # Check for 'exists' field first, then parse message if not present
779
- exists = result.get("exists", None)
780
-
781
- # If 'exists' field is not present, parse the message
782
- if exists is None and "message" in result:
783
- message_lower = result["message"].lower()
784
- # Check if message indicates file already exists
785
- exists = "already existed" in message_lower or "already exists" in message_lower or "duplicate" in message_lower
786
-
787
- # Default to False if still None
788
- if exists is None:
789
- exists = False
790
-
791
- return {
792
- "success": True,
793
- "exists": exists,
794
- "data": result
795
- }
796
- except httpx.HTTPStatusError as e:
797
- print(f"Hash check HTTP error: {e.response.status_code} - {e.response.text}")
798
- return {
799
- "success": False,
800
- "exists": False, # Assume not exists on error to be safe
801
- "message": f"HTTP {e.response.status_code}",
802
- "error_detail": e.response.text
803
- }
804
- except Exception as e:
805
- print(f"Hash check exception: {str(e)}")
806
- return {
807
- "success": False,
808
- "exists": False, # Assume not exists on error
809
- "message": f"Request failed: {str(e)}"
810
- }
811
-
812
-
813
- # --- PostgreSQL Metadata Upload with file_hash ---
814
- async def user_metadata_upload_pg(
815
- user_id: str,
816
- user_metadata: str,
817
- path: str,
818
- url: str,
819
- filename: str,
820
- file_type: str,
821
- file_size_bytes: int,
822
- file_hash: str,
823
- timeout: float = 10.0
824
- ):
825
- payload = {
826
- "user_id": user_id,
827
- "user_metadata": user_metadata,
828
- "path": path,
829
- "url": url,
830
- "filename": filename,
831
- "file_type": file_type,
832
- "file_size_bytes": file_size_bytes,
833
- "file_hash": file_hash
834
- }
835
- print(f"PostgreSQL payload file_hash: {payload['file_hash']}")
836
- async with httpx.AsyncClient() as client:
837
- try:
838
- r = await client.post(
839
- "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
840
- json=payload,
841
- timeout=timeout
842
- )
843
- r.raise_for_status()
844
- result = r.json()
845
- result["file_hash"] = file_hash
846
- print(f"PostgreSQL result file_hash: {result['file_hash']}")
847
- return {"success": True, "data": result}
848
- except httpx.HTTPStatusError as e:
849
- return {
850
- "success": False,
851
- "error": "HTTP error",
852
- "status_code": e.response.status_code,
853
- "detail": e.response.text,
854
- "file_hash": file_hash
855
- }
856
- except Exception as e:
857
- return {
858
- "success": False,
859
- "error": "Request failed",
860
- "detail": str(e),
861
- "file_hash": file_hash
862
- }
863
-
864
-
865
- # --- DEBUG ENDPOINT (Optional - for troubleshooting) ---
866
- @s3_bucket_router1.get("/debug/check_hash/{user_id}/{file_hash}")
867
- async def debug_check_hash(user_id: str, file_hash: str):
868
- """Debug endpoint to test hash checking"""
869
- result = await check_file_hash_exists(user_id, file_hash)
870
- return {
871
- "raw_result": result,
872
- "exists_value": result.get("exists"),
873
- "exists_type": type(result.get("exists")).__name__,
874
- "success_value": result.get("success"),
875
- "interpretation": "File exists" if result.get("exists") is True else "File does not exist or check failed"
876
- }
877
-
878
-
879
- # --- MAIN ENDPOINT WITH FIXED HASH CHECK AND DATASET GRAPHS ---
880
- @s3_bucket_router1.post("/upload_datasets_v3/")
881
- async def upload_file(
882
- file: UploadFile = File(...),
883
- user_id: str = Query(..., description="User ID"),
884
- path: str = Query("", description="Optional subpath")
885
- ):
886
- html_tmp_dir = None
887
- bokeh_tmp_dir = None
888
- file_content = None
889
-
890
- try:
891
- # 1. Validate extension
892
- file_ext = os.path.splitext(file.filename)[1].lower()
893
- if file_ext not in ALLOWED_EXTENSIONS:
894
- raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
895
-
896
- # 2. Read file
897
- file_content = await file.read()
898
- if not file_content:
899
- raise HTTPException(status_code=400, detail="Empty file")
900
- if len(file_content) > MAX_FILE_SIZE_BYTES:
901
- raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
902
-
903
- # 3. Generate hash
904
- file_hash = hashlib.sha256(file_content).hexdigest()
905
- print(f"Generated file hash: {file_hash}")
906
-
907
- # 4. Check hash via API (FIXED LOGIC)
908
- hash_result = await check_file_hash_exists(user_id, file_hash)
909
-
910
- # Enhanced logging
911
- print(f"Hash check result: {hash_result}")
912
-
913
- # Check if the API call was successful first
914
- if not hash_result.get("success", False):
915
- print(f"⚠️ Warning: Hash check API failed: {hash_result.get('message')}")
916
- # You can decide to fail here if hash check is critical
917
- # raise HTTPException(status_code=503, detail="Hash check service unavailable")
918
-
919
- # Check if file exists
920
- if hash_result.get("exists") is True:
921
- print(f"🚫 Duplicate file detected: {file_hash}")
922
- return JSONResponse(
923
- status_code=409,
924
- content={
925
- "message": "File already uploaded.",
926
- "reason": "Duplicate file detected via SHA-256 hash.",
927
- "file_hash": file_hash,
928
- "user_id": user_id,
929
- "filename": file.filename,
930
- "action": "skipped",
931
- "existing_file_info": hash_result.get("data")
932
- }
933
- )
934
-
935
- print("✅ Hash check passed. New file - proceeding with upload.")
936
-
937
- # 5. Load DataFrame
938
- try:
939
- if file_ext == ".csv":
940
- df = pd.read_csv(io.BytesIO(file_content))
941
- elif file_ext in {".xlsx", ".xls"}:
942
- engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
943
- df = pd.read_excel(io.BytesIO(file_content), engine=engine)
944
- elif file_ext == ".ods":
945
- df = pd.read_excel(io.BytesIO(file_content), engine='odf')
946
- except Exception as e:
947
- raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
948
-
949
- if len(df) > MAX_ROWS_ALLOWED:
950
- raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
951
-
952
- # 6. Convert to Parquet
953
- parquet_buffer = convert_df_to_parquet(df)
954
- parquet_size = parquet_buffer.getbuffer().nbytes
955
-
956
- # 7. Upload Parquet to S3
957
- base_filename = os.path.splitext(file.filename)[0]
958
- parquet_filename = f"{base_filename}.parquet"
959
- file_key = f"{user_id}/files/datasets/{parquet_filename}"
960
- file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
961
-
962
- s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
963
- ExtraArgs={'ContentType': 'application/octet-stream'})
964
- print(f"Uploaded Parquet: {file_url}")
965
-
966
- # 8. Generate metadata
967
- metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
968
- metadata.update({
969
- "user_id": user_id,
970
- "s3_path": file_key,
971
- "s3_url": file_url,
972
- "source_file": file.filename,
973
- "source_file_type": file_ext[1:],
974
- "file_type": "parquet",
975
- "original_file_size_bytes": len(file_content),
976
- "parquet_file_size_bytes": parquet_size,
977
- "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
978
- "file_hash": file_hash
979
- })
980
-
981
- # 8.5 Generate dataset preview graphs (separate from metadata)
982
- print("Generating dataset preview graphs...")
983
- dataset_graphs = None
984
- try:
985
- dataset_graphs = create_data_set_graphs_dict(df, max_rows=200)
986
- print(f"✅ Generated graphs for {len(dataset_graphs.get('columnSummaries', []))} columns")
987
- except Exception as e:
988
- print(f"⚠️ Failed to generate dataset graphs: {e}")
989
- import traceback
990
- traceback.print_exc()
991
- dataset_graphs = {
992
- "error": str(e),
993
- "columnSummaries": []
994
- }
995
-
996
- safe_metadata = sanitize_for_json(metadata)
997
- safe_dataset_graphs = sanitize_for_json(dataset_graphs) if dataset_graphs else {"columnSummaries": []}
998
-
999
- # 9. Vector DB
1000
- check_vdb(user_id)
1001
- vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
1002
- vdb_success = vdb_res.get("status") == "success"
1003
- print(f"VDB upload success: {vdb_success}")
1004
-
1005
- # 10. PostgreSQL Metadata + file_hash
1006
- pg_result = await user_metadata_upload_pg(
1007
- user_id=user_id,
1008
- user_metadata=json.dumps(safe_metadata),
1009
- path=file_key,
1010
- url=file_url,
1011
- filename=parquet_filename,
1012
- file_type="parquet",
1013
- file_size_bytes=parquet_size,
1014
- file_hash=file_hash
1015
- )
1016
- print(f"PostgreSQL upload result: {pg_result}")
1017
- pg_success = pg_result.get("success", False)
1018
-
1019
- graphs_count = len(safe_dataset_graphs.get("columnSummaries", []))
1020
- print(f"Graphs generated: {graphs_count}")
1021
-
1022
- # 11. Return success
1023
- return {
1024
- "message": "Upload successful.",
1025
- "filename": parquet_filename,
1026
- "original_filename": file.filename,
1027
- "user_id": user_id,
1028
- "file_path": file_key,
1029
- "file_url": file_url,
1030
- "file_hash": file_hash,
1031
- "source_file_type": file_ext[1:],
1032
- "file_type": "parquet",
1033
- "original_file_size_bytes": len(file_content),
1034
- "parquet_file_size_bytes": parquet_size,
1035
- "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
1036
- "rows": len(df),
1037
- "columns": len(df.columns),
1038
- "metadata": safe_metadata,
1039
- "data_sets_preview_graph": safe_dataset_graphs, # Separate key at root level
1040
- "upload_dataset_vdb": vdb_success,
1041
- "upload_dataset_pg": pg_success,
1042
- "pg_details": pg_result,
1043
- "graphs_generated": graphs_count
1044
- }
1045
-
1046
- except HTTPException:
1047
- raise
1048
- except Exception as e:
1049
- print(f"Unexpected error: {e}")
1050
- import traceback
1051
- traceback.print_exc()
1052
- raise HTTPException(status_code=500, detail=str(e))
1053
- finally:
1054
- # Clean up temp directories
1055
- for d in (html_tmp_dir, bokeh_tmp_dir):
1056
- if d and os.path.exists(d):
1057
  shutil.rmtree(d, ignore_errors=True)
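Before the rewritten version of r4.py below, a standalone sketch of the two mechanisms the upload endpoint above leans on: hashing the raw upload bytes with SHA-256 for duplicate detection, and converting the DataFrame to an in-memory Parquet buffer whose size drives the reported compression ratio. The CSV bytes are invented, and the FastAPI endpoint and the external hash-check service are not reproduced here:

import hashlib
import io

import pandas as pd

file_content = b"id,score\n1,10\n2,20\n3,30\n"   # hypothetical upload payload (what UploadFile.read() returns)

file_hash = hashlib.sha256(file_content).hexdigest()   # duplicate-detection key
df = pd.read_csv(io.BytesIO(file_content))

buffer = io.BytesIO()
df.to_parquet(buffer, engine="pyarrow", compression="snappy", index=False)
buffer.seek(0)
parquet_size = buffer.getbuffer().nbytes

print(file_hash[:12], len(file_content), parquet_size)
# Tiny inputs like this carry Parquet overhead, so the ratio can go negative;
# realistic datasets usually shrink.
print(f"compression_ratio: {(1 - parquet_size / len(file_content)) * 100:.1f}%")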
 
1
+ # from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
2
+ # from fastapi.responses import JSONResponse
3
+ # import pandas as pd
4
+ # from autoviz.AutoViz_Class import AutoViz_Class
5
+ # import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
6
+ # matplotlib.use('Agg')
7
+ # import matplotlib.pyplot as plt
8
+ # import sys
9
+ # from pathlib import Path
10
+ # from typing import List
11
+
12
+ # # --- Project Root Setup ---
13
+ # PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
+ # if str(PROJECT_ROOT) not in sys.path:
15
+ # sys.path.insert(0, str(PROJECT_ROOT))
16
+
17
+ # from retrieve_secret import *
18
+ # import httpx
19
+
20
+ # # --- File Validation ---
21
+ # MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
22
+ # MAX_ROWS_ALLOWED = 1_000_000
23
+ # ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
24
+
25
+ # # --- AWS S3 Config ---
26
+ # print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
27
+ # ACCESS_KEY = AWS_S3_CREDS_KEY_ID
28
+ # SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
29
+ # BUCKET_NAME = BUCKET_NAME
30
+ # REGION_NAME = "us-east-1"
31
+
32
+ # s3 = boto3.client(
33
+ # "s3",
34
+ # aws_access_key_id=ACCESS_KEY,
35
+ # aws_secret_access_key=SECRET_KEY,
36
+ # region_name=REGION_NAME
37
+ # )
38
+
39
+ # ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
40
+
41
+ # # --- FastAPI Router ---
42
+ # s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
43
+
44
+
45
+ # # --- Helper: S3 Key ---
46
+ # def make_key(path: str, filename: str) -> str:
47
+ # return f"{path.strip('/')}/{filename}" if path else filename
48
+
49
+
50
+ # # --- Sanitize for JSON ---
51
+ # def sanitize_for_json(obj):
52
+ # if isinstance(obj, dict):
53
+ # return {k: sanitize_for_json(v) for k, v in obj.items()}
54
+ # elif isinstance(obj, list):
55
+ # return [sanitize_for_json(item) for item in obj]
56
+ # elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
57
+ # return str(obj)
58
+ # elif pd.isna(obj):
59
+ # return None
60
+ # elif isinstance(obj, (int, float, str, bool, type(None))):
61
+ # return obj
62
+ # else:
63
+ # return str(obj)
64
+
65
+
66
+ # # --- Read CSV from S3 (optional) ---
67
+ # def read_csv_from_s3(key: str) -> pd.DataFrame:
68
+ # obj = s3.get_object(Bucket=BUCKET_NAME, Key=key)
69
+ # return pd.read_csv(io.BytesIO(obj['Body'].read()))
70
+
71
+
72
+ # # --- Vector DB Placeholders ---
73
+ # def check_vdb(user_id: str):
74
+ # print(f"Checking VDB for user: {user_id}")
75
+
76
+ # async def add_metadata_only(collection_name: str, metadata: dict):
77
+ # print(f"Adding metadata to collection: {collection_name}")
78
+ # return {"status": "success", "collection": collection_name}
79
+
80
+
81
+ # # --- AutoViz HTML ---
82
+ # def run_autoviz_html(
83
+ # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
84
+ # max_rows_analyzed=150000, max_cols_analyzed=30
85
+ # ):
86
+ # tmp_dir = tempfile.mkdtemp()
87
+ # AV = AutoViz_Class()
88
+ # AV.AutoViz(
89
+ # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
90
+ # verbose=verbose, lowess=lowess, chart_format="html",
91
+ # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
92
+ # save_plot_dir=tmp_dir
93
+ # )
94
+ # print(f"HTML plots saved to: {tmp_dir}")
95
+ # return tmp_dir
96
+
97
+
98
+ # # --- AutoViz Bokeh ---
99
+ # def run_autoviz_bokeh(
100
+ # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
101
+ # max_rows_analyzed=150000, max_cols_analyzed=30
102
+ # ):
103
+ # tmp_dir = tempfile.mkdtemp()
104
+ # print(f"Bokeh temp directory: {tmp_dir}")
105
+ # try:
106
+ # matplotlib.use('Agg')
107
+ # plt.ioff()
108
+ # AV = AutoViz_Class()
109
+ # AV.AutoViz(
110
+ # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
111
+ # verbose=verbose, lowess=lowess, chart_format="bokeh",
112
+ # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
113
+ # save_plot_dir=tmp_dir
114
+ # )
115
+ # plt.close('all')
116
+ # print(f"Bokeh plots saved to: {tmp_dir}")
117
+ # return tmp_dir
118
+ # except Exception as e:
119
+ # print(f"Error in Bokeh: {e}")
120
+ # import traceback
121
+ # traceback.print_exc()
122
+ # raise
123
+
124
+
125
+ # # --- Fallback Matplotlib ---
126
+ # def generate_matplotlib_plots(df, user_id, base_filename):
127
+ # print("\nGenerating matplotlib fallback plots...")
128
+ # tmp_dir = tempfile.mkdtemp()
129
+ # plot_metadata = []
130
+ # try:
131
+ # matplotlib.use('Agg')
132
+ # numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
133
+ # if len(numeric_cols) == 0:
134
+ # print("No numeric columns for plotting")
135
+ # return []
136
+
137
+ # # Correlation heatmap
138
+ # if len(numeric_cols) > 1:
139
+ # plt.figure(figsize=(10, 8))
140
+ # import seaborn as sns
141
+ # corr = df[numeric_cols].corr()
142
+ # sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
143
+ # plt.title('Correlation Heatmap')
144
+ # plt.tight_layout()
145
+ # png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
146
+ # plt.savefig(png_path, dpi=100, bbox_inches='tight')
147
+ # plt.close()
148
+
149
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
150
+ # with open(png_path, 'rb') as f:
151
+ # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
152
+ # plot_metadata.append({
153
+ # "file_name": "correlation_heatmap.png",
154
+ # "s3_path": s3_key,
155
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
156
+ # "type": "png",
157
+ # "plot_type": "heatmap"
158
+ # })
159
+
160
+ # # Histograms
161
+ # for col in numeric_cols[:5]:
162
+ # plt.figure(figsize=(10, 6))
163
+ # df[col].hist(bins=30, edgecolor='black')
164
+ # plt.title(f'Distribution of {col}')
165
+ # plt.xlabel(col)
166
+ # plt.ylabel('Frequency')
167
+ # plt.tight_layout()
168
+ # png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
169
+ # plt.savefig(png_path, dpi=100, bbox_inches='tight')
170
+ # plt.close()
171
+
172
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
173
+ # with open(png_path, 'rb') as f:
174
+ # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
175
+ # plot_metadata.append({
176
+ # "file_name": f"distribution_{col}.png",
177
+ # "s3_path": s3_key,
178
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
179
+ # "type": "png",
180
+ # "plot_type": "histogram"
181
+ # })
182
+ # return plot_metadata
183
+ # except Exception as e:
184
+ # print(f"Matplotlib failed: {e}")
185
+ # return []
186
+ # finally:
187
+ # shutil.rmtree(tmp_dir, ignore_errors=True)
188
+
189
+
190
+ # # --- Upload Viz Files ---
191
+ # def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
192
+ # patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
193
+ # files = []
194
+ # for p in patterns:
195
+ # files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
196
+ # files = list(set(files))
197
+ # if not files:
198
+ # print(f"No {viz_type} files found")
199
+ # return []
200
+
201
+ # folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
202
+ # metadata = []
203
+ # content_type_map = {
204
+ # 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
205
+ # 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
206
+ # }
207
+
208
+ # for file_path in files:
209
+ # name = os.path.basename(file_path)
210
+ # ext = os.path.splitext(name)[1][1:]
211
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
212
+ # with open(file_path, "rb") as f:
213
+ # body = f.read()
214
+ # s3.put_object(
215
+ # Bucket=BUCKET_NAME, Key=s3_key, Body=body,
216
+ # ContentType=content_type_map.get(ext, 'application/octet-stream')
217
+ # )
218
+ # metadata.append({
219
+ # "file_name": name, "s3_path": s3_key,
220
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
221
+ # "type": viz_type, "size": len(body)
222
+ # })
223
+ # print(f"{viz_type.upper()} uploaded: {s3_key}")
224
+ # return metadata
225
+
226
+
227
+ # # --- Convert to Parquet ---
228
+ # def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
229
+ # buffer = io.BytesIO()
230
+ # df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
231
+ # buffer.seek(0)
232
+ # return buffer
233
+
234
+
235
+ # # --- NEW: Check File Hash ---
236
+ # async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
237
+ # url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
238
+ # headers = {"accept": "application/json"}
239
+ # async with httpx.AsyncClient(timeout=10.0) as client:
240
+ # try:
241
+ # r = await client.get(url, headers=headers)
242
+ # r.raise_for_status()
243
+ # return r.json()
244
+ # except httpx.HTTPStatusError as e:
245
+ # return {"success": False, "message": f"HTTP {e.response.status_code}"}
246
+ # except Exception as e:
247
+ # return {"success": False, "message": f"Request failed: {str(e)}"}
248
+
249
+
250
+ # # --- PostgreSQL Metadata Upload ---
251
+ # async def user_metadata_upload_pg(
252
+ # user_id: str, user_metadata: str, path: str, url: str,
253
+ # filename: str, file_type: str, file_size_bytes: int, timeout: float = 10.0
254
+ # ):
255
+ # payload = {
256
+ # "user_id": user_id, "user_metadata": user_metadata, "path": path,
257
+ # "url": url, "filename": filename, "file_type": file_type,
258
+ # "file_size_bytes": file_size_bytes
259
+ # }
260
+ # async with httpx.AsyncClient() as client:
261
+ # try:
262
+ # r = await client.post(
263
+ # "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
264
+ # json=payload, timeout=timeout
265
+ # )
266
+ # r.raise_for_status()
267
+ # return {"success": True, "data": r.json()}
268
+ # except httpx.HTTPStatusError as e:
269
+ # return {"success": False, "error": "HTTP error", "status_code": e.response.status_code, "detail": e.response.text}
270
+ # except Exception as e:
271
+ # return {"success": False, "error": "Request failed", "detail": str(e)}
272
+
273
+
274
+ # # --- MAIN ENDPOINT ---
275
+ # @s3_bucket_router1.post("/upload_datasets_v3/")
276
+ # async def upload_file(
277
+ # file: UploadFile = File(...),
278
+ # user_id: str = Query(..., description="User ID"),
279
+ # path: str = Query("", description="Optional subpath")
280
+ # ):
281
+ # html_tmp_dir = bokeh_tmp_dir = None
282
+ # file_content = None
283
+
284
+ # try:
285
+ # # 1. Validate extension
286
+ # file_ext = os.path.splitext(file.filename)[1].lower()
287
+ # if file_ext not in ALLOWED_EXTENSIONS:
288
+ # raise HTTPException(status_code=400, detail=f"Unsupported: {file_ext}")
289
+
290
+ # # 2. Read file
291
+ # file_content = await file.read()
292
+ # if not file_content:
293
+ # raise HTTPException(status_code=400, detail="Empty file")
294
+ # if len(file_content) > MAX_FILE_SIZE_BYTES:
295
+ # raise HTTPException(status_code=413, detail=f"File > 100 MiB")
296
+
297
+ # # 3. Generate hash
298
+ # file_hash = hashlib.sha256(file_content).hexdigest()
299
+ # print(f"Generated hash: {file_hash}")
300
+
301
+ # # 4. Check hash via API
302
+ # hash_result = await check_file_hash_exists(user_id, file_hash)
303
+ # if not hash_result.get("success", False):
304
+ # return JSONResponse(
305
+ # status_code=409,
306
+ # content={
307
+ # "message": "File already uploaded.",
308
+ # "reason": hash_result.get("message", "Hash exists."),
309
+ # "file_hash": file_hash,
310
+ # "user_id": user_id,
311
+ # "action": "skipped"
312
+ # }
313
+ # )
314
+ # print("Hash check passed.")
315
+
316
+ # # 5. Load DataFrame
317
+ # try:
318
+ # if file_ext == ".csv":
319
+ # df = pd.read_csv(io.BytesIO(file_content))
320
+ # elif file_ext in {".xlsx", ".xls"}:
321
+ # engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
322
+ # df = pd.read_excel(io.BytesIO(file_content), engine=engine)
323
+ # elif file_ext == ".ods":
324
+ # df = pd.read_excel(io.BytesIO(file_content), engine='odf')
325
+ # except Exception as e:
326
+ # raise HTTPException(status_code=400, detail=f"Parse error: {e}")
327
+
328
+ # if len(df) > MAX_ROWS_ALLOWED:
329
+ # raise HTTPException(status_code=413, detail=f"Rows > {MAX_ROWS_ALLOWED:,}")
330
+
331
+ # # 6. Convert to Parquet
332
+ # parquet_buffer = convert_df_to_parquet(df)
333
+ # parquet_size = parquet_buffer.getbuffer().nbytes
334
+
335
+ # # 7. Upload to S3
336
+ # base_filename = os.path.splitext(file.filename)[0]
337
+ # parquet_filename = f"{base_filename}.parquet"
338
+ # file_key = f"{user_id}/files/datasets/{parquet_filename}"
339
+ # file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
340
+
341
+ # s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
342
+ # ExtraArgs={'ContentType': 'application/octet-stream'})
343
+ # print(f"Uploaded: {file_url}")
344
+
345
+ # # 8. Metadata
346
+ # from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
347
+ # metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
348
+ # metadata.update({
349
+ # "user_id": user_id, "s3_path": file_key, "s3_url": file_url,
350
+ # "source_file": file.filename, "source_file_type": file_ext[1:],
351
+ # "file_type": "parquet", "original_file_size_bytes": len(file_content),
352
+ # "parquet_file_size_bytes": parquet_size,
353
+ # "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
354
+ # "file_hash": file_hash
355
+ # })
356
+ # print(f"Metadata: {metadata}")
357
+ # # 9. Vector DB
358
+ # check_vdb(user_id)
359
+ # vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
360
+ # vdb_success = vdb_res.get("status") == "success"
361
+ # #safe_metadata
362
+ # safe_metadata = sanitize_for_json(metadata)
363
+ # # 10. PostgreSQL
364
+ # # pg_result = await user_metadata_upload_pg(
365
+ # # user_id=user_id, user_metadata=json.dumps(metadata),
366
+ # # path=file_key, url=file_url, filename=parquet_filename,
367
+ # # file_type="parquet", file_size_bytes=parquet_size
368
+ # # )
369
+ # pg_result = await user_metadata_upload_pg(
370
+ # user_id=user_id, user_metadata=json.dumps(safe_metadata),
371
+ # path=file_key, url=file_url, filename=parquet_filename,
372
+ # file_type="parquet", file_size_bytes=parquet_size
373
+ # )
374
+ # # Inject hash into pg_result
375
+ # if isinstance(pg_result, dict):
376
+ # pg_result["file_hash"] = file_hash
377
+ # pg_success = pg_result.get("success", False)
378
+
379
+ # # 11. Return
380
+ # return {
381
+ # "message": "Upload successful.",
382
+ # "filename": parquet_filename,
383
+ # "original_filename": file.filename,
384
+ # "user_id": user_id,
385
+ # "file_path": file_key,
386
+ # "file_url": file_url,
387
+ # "file_hash": file_hash,
388
+ # "source_file_type": file_ext[1:],
389
+ # "file_type": "parquet",
390
+ # "original_file_size_bytes": len(file_content),
391
+ # "parquet_file_size_bytes": parquet_size,
392
+ # "rows": len(df), "columns": len(df.columns),
393
+ # # "metadata": sanitize_for_json(metadata),
394
+ # "metadata": safe_metadata,
395
+ # "upload_dataset_vdb": vdb_success,
396
+ # "upload_dataset_pg": pg_success,
397
+ # "pg_details": pg_result
398
+ # }
399
+
400
+ # except HTTPException:
401
+ # raise
402
+ # except Exception as e:
403
+ # print(f"Error: {e}")
404
+ # import traceback; traceback.print_exc()
405
+ # raise HTTPException(status_code=500, detail=str(e))
406
+ # finally:
407
+ # for d in (html_tmp_dir, bokeh_tmp_dir):
408
+ # if d and os.path.exists(d):
409
+ # shutil.rmtree(d, ignore_errors=True)
410
+
411
+
412
+
413
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
414
+ from fastapi.responses import JSONResponse
415
+ import pandas as pd
416
+ from autoviz.AutoViz_Class import AutoViz_Class
417
+ import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
418
+ matplotlib.use('Agg')
419
+ import matplotlib.pyplot as plt
420
+ import sys
421
+ from pathlib import Path
422
+ from typing import List
423
+ import httpx
424
+
425
+ # --- Project Root Setup ---
426
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
427
+ if str(PROJECT_ROOT) not in sys.path:
428
+ sys.path.insert(0, str(PROJECT_ROOT))
429
+
430
+ from retrieve_secret import *
431
+ from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
432
+
433
+ # --- File Validation ---
434
+ MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
435
+ MAX_ROWS_ALLOWED = 1_000_000
436
+ ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
437
+
438
+ # --- AWS S3 Config ---
439
+ print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
440
+ ACCESS_KEY = AWS_S3_CREDS_KEY_ID
441
+ SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
442
+ BUCKET_NAME = BUCKET_NAME
443
+ REGION_NAME = "us-east-1"
444
+
445
+ s3 = boto3.client(
446
+ "s3",
447
+ aws_access_key_id=ACCESS_KEY,
448
+ aws_secret_access_key=SECRET_KEY,
449
+ region_name=REGION_NAME
450
+ )
451
+
452
+ ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
453
+
454
+ # --- FastAPI Router ---
455
+ s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
456
+
457
+
458
+ # --- Helper: S3 Key ---
459
+ def make_key(path: str, filename: str) -> str:
460
+ return f"{path.strip('/')}/{filename}" if path else filename
461
+
462
+
463
+ # --- Sanitize for JSON ---
464
+ def sanitize_for_json(obj):
465
+ if isinstance(obj, dict):
466
+ return {k: sanitize_for_json(v) for k, v in obj.items()}
467
+ elif isinstance(obj, list):
468
+ return [sanitize_for_json(item) for item in obj]
469
+ elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
470
+ return str(obj)
471
+ elif pd.isna(obj):
472
+ return None
473
+ elif isinstance(obj, (int, float, str, bool, type(None))):
474
+ return obj
475
+ else:
476
+ return str(obj)
477
+
478
+
479
+ # --- Vector DB Placeholders ---
480
+ def check_vdb(user_id: str):
481
+ print(f"Checking VDB for user: {user_id}")
482
+
483
+ async def add_metadata_only(collection_name: str, metadata: dict):
484
+ print(f"Adding metadata to collection: {collection_name}")
485
+ return {"status": "success", "collection": collection_name}
486
+
487
+
488
+ # --- AutoViz HTML ---
489
+ def run_autoviz_html(
490
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
491
+ max_rows_analyzed=150000, max_cols_analyzed=30
492
+ ):
493
+ tmp_dir = tempfile.mkdtemp()
494
+ AV = AutoViz_Class()
495
+ AV.AutoViz(
496
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
497
+ verbose=verbose, lowess=lowess, chart_format="html",
498
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
499
+ save_plot_dir=tmp_dir
500
+ )
501
+ print(f"HTML plots saved to: {tmp_dir}")
502
+ return tmp_dir
503
+
504
+
505
+ # --- AutoViz Bokeh ---
506
+ def run_autoviz_bokeh(
507
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
508
+ max_rows_analyzed=150000, max_cols_analyzed=30
509
+ ):
510
+ tmp_dir = tempfile.mkdtemp()
511
+ print(f"Bokeh temp directory: {tmp_dir}")
512
+ try:
513
+ matplotlib.use('Agg')
514
+ plt.ioff()
515
+ AV = AutoViz_Class()
516
+ AV.AutoViz(
517
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
518
+ verbose=verbose, lowess=lowess, chart_format="bokeh",
519
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
520
+ save_plot_dir=tmp_dir
521
+ )
522
+ plt.close('all')
523
+ print(f"Bokeh plots saved to: {tmp_dir}")
524
+ return tmp_dir
525
+ except Exception as e:
526
+ print(f"Error in Bokeh: {e}")
527
+ import traceback
528
+ traceback.print_exc()
529
+ raise
530
+
531
+
532
+ # --- Fallback Matplotlib ---
533
+ def generate_matplotlib_plots(df, user_id, base_filename):
534
+ print("\nGenerating matplotlib fallback plots...")
535
+ tmp_dir = tempfile.mkdtemp()
536
+ plot_metadata = []
537
+ try:
538
+ matplotlib.use('Agg')
539
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
540
+ if len(numeric_cols) == 0:
541
+ print("No numeric columns for plotting")
542
+ return []
543
+
544
+ # Correlation heatmap
545
+ if len(numeric_cols) > 1:
546
+ plt.figure(figsize=(10, 8))
547
+ import seaborn as sns
548
+ corr = df[numeric_cols].corr()
549
+ sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
550
+ plt.title('Correlation Heatmap')
551
+ plt.tight_layout()
552
+ png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
553
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
554
+ plt.close()
555
+
556
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
557
+ with open(png_path, 'rb') as f:
558
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
559
+ plot_metadata.append({
560
+ "file_name": "correlation_heatmap.png",
561
+ "s3_path": s3_key,
562
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
563
+ "type": "png",
564
+ "plot_type": "heatmap"
565
+ })
566
+
567
+ # Histograms
568
+ for col in numeric_cols[:5]:
569
+ plt.figure(figsize=(10, 6))
570
+ df[col].hist(bins=30, edgecolor='black')
571
+ plt.title(f'Distribution of {col}')
572
+ plt.xlabel(col)
573
+ plt.ylabel('Frequency')
574
+ plt.tight_layout()
575
+ png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
576
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
577
+ plt.close()
578
+
579
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
580
+ with open(png_path, 'rb') as f:
581
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
582
+ plot_metadata.append({
583
+ "file_name": f"distribution_{col}.png",
584
+ "s3_path": s3_key,
585
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
586
+ "type": "png",
587
+ "plot_type": "histogram"
588
+ })
589
+ return plot_metadata
590
+ except Exception as e:
591
+ print(f"Matplotlib failed: {e}")
592
+ return []
593
+ finally:
594
+ shutil.rmtree(tmp_dir, ignore_errors=True)
595
+
596
+
597
+ # --- Upload Viz Files ---
598
+ def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
599
+ patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
600
+ files = []
601
+ for p in patterns:
602
+ files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
603
+ files = list(set(files))
604
+ if not files:
605
+ print(f"No {viz_type} files found")
606
+ return []
607
+
608
+ folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
609
+ metadata = []
610
+ content_type_map = {
611
+ 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
612
+ 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
613
+ }
614
+
615
+ for file_path in files:
616
+ name = os.path.basename(file_path)
617
+ ext = os.path.splitext(name)[1][1:]
618
+ s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
619
+ with open(file_path, "rb") as f:
620
+ body = f.read()
621
+ s3.put_object(
622
+ Bucket=BUCKET_NAME, Key=s3_key, Body=body,
623
+ ContentType=content_type_map.get(ext, 'application/octet-stream')
624
+ )
625
+ metadata.append({
626
+ "file_name": name, "s3_path": s3_key,
627
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
628
+ "type": viz_type, "size": len(body)
629
+ })
630
+ print(f"{viz_type.upper()} uploaded: {s3_key}")
631
+ return metadata
632
+
633
+
634
+ # --- Convert to Parquet ---
635
+ def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
636
+ buffer = io.BytesIO()
637
+ df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
638
+ buffer.seek(0)
639
+ return buffer
640
+
641
+
642
+ # --- Check File Hash Exists ---
643
+ async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
644
+ url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
645
+ headers = {"accept": "application/json"}
646
+ async with httpx.AsyncClient(timeout=10.0) as client:
647
+ try:
648
+ r = await client.get(url, headers=headers)
649
+ r.raise_for_status()
650
+ return r.json()
651
+ except httpx.HTTPStatusError as e:
652
+ return {"success": False, "message": f"HTTP {e.response.status_code}"}
653
+ except Exception as e:
654
+ return {"success": False, "message": f"Request failed: {str(e)}"}
655
+
656
+
657
+ # --- FIXED: PostgreSQL Metadata Upload with file_hash ---
658
+ async def user_metadata_upload_pg(
659
+ user_id: str,
660
+ user_metadata: str,
661
+ path: str,
662
+ url: str,
663
+ filename: str,
664
+ file_type: str,
665
+ file_size_bytes: int,
666
+ file_hash: str, # Now accepted
667
+ timeout: float = 10.0
668
+ ):
669
+ payload = {
670
+ "user_id": user_id,
671
+ "user_metadata": user_metadata,
672
+ "path": path,
673
+ "url": url,
674
+ "filename": filename,
675
+ "file_type": file_type,
676
+ "file_size_bytes": file_size_bytes,
677
+ "file_hash": file_hash # This goes into DB now!
678
+ }
679
+ print("payload[file_hash]", payload["file_hash"])
680
+ async with httpx.AsyncClient() as client:
681
+
682
+ try:
683
+
684
+ r = await client.post(
685
+ "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
686
+ json=payload,
687
+ timeout=timeout
688
+ )
689
+ r.raise_for_status()
690
+ result = r.json()
691
+ result["file_hash"] = file_hash
692
+ print("result[file_hash]", result["file_hash"])
693
+ return {"success": True, "data": result}
694
+ except httpx.HTTPStatusError as e:
695
+ return {
696
+ "success": False,
697
+ "error": "HTTP error",
698
+ "status_code": e.response.status_code,
699
+ "detail": e.response.text,
700
+ "file_hash": file_hash
701
+ }
702
+ except Exception as e:
703
+ return {
704
+ "success": False,
705
+ "error": "Request failed",
706
+ "detail": str(e),
707
+ "file_hash": file_hash
708
+ }
709
+
710
+
711
+ # --- MAIN ENDPOINT: FULLY FIXED ---
712
+ @s3_bucket_router1.post("/upload_datasets_v3/")
713
+ async def upload_file(
714
+ file: UploadFile = File(...),
715
+ user_id: str = Query(..., description="User ID"),
716
+ path: str = Query("", description="Optional subpath")
717
+ ):
718
+ html_tmp_dir = None
719
+ bokeh_tmp_dir = None
720
+ file_content = None
721
+
722
+ try:
723
+ # 1. Validate extension
724
+ file_ext = os.path.splitext(file.filename)[1].lower()
725
+ if file_ext not in ALLOWED_EXTENSIONS:
726
+ raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
727
+
728
+ # 2. Read file
729
+ file_content = await file.read()
730
+ if not file_content:
731
+ raise HTTPException(status_code=400, detail="Empty file")
732
+ if len(file_content) > MAX_FILE_SIZE_BYTES:
733
+ raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
734
+
735
+ # 3. Generate hash
736
+ file_hash = hashlib.sha256(file_content).hexdigest()
737
+ print(f"Generated hash: {file_hash}")
738
+
739
+ # 4. Check hash via API
740
+ hash_result = await check_file_hash_exists(user_id, file_hash)
741
+ if hash_result.get("exists") is True:
742
+ return JSONResponse(
743
+ status_code=409,
744
+ content={
745
+ "message": "File already uploaded.",
746
+ "reason": "Duplicate file detected via SHA-256 hash.",
747
+ "file_hash": file_hash,
748
+ "user_id": user_id,
749
+ "action": "skipped"
750
+ }
751
+ )
752
+ print("Hash check passed. New file.")
753
+
754
+ # 5. Load DataFrame
755
+ try:
756
+ if file_ext == ".csv":
757
+ df = pd.read_csv(io.BytesIO(file_content))
758
+ elif file_ext in {".xlsx", ".xls"}:
759
+ engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
760
+ df = pd.read_excel(io.BytesIO(file_content), engine=engine)
761
+ elif file_ext == ".ods":
762
+ df = pd.read_excel(io.BytesIO(file_content), engine='odf')
763
+ except Exception as e:
764
+ raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
765
+
766
+ if len(df) > MAX_ROWS_ALLOWED:
767
+ raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
768
+
769
+ # 6. Convert to Parquet
770
+ parquet_buffer = convert_df_to_parquet(df)
771
+ parquet_size = parquet_buffer.getbuffer().nbytes
772
+
773
+ # 7. Upload Parquet to S3
774
+ base_filename = os.path.splitext(file.filename)[0]
775
+ parquet_filename = f"{base_filename}.parquet"
776
+ file_key = f"{user_id}/files/datasets/{parquet_filename}"
777
+ file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
778
+
779
+ s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
780
+ ExtraArgs={'ContentType': 'application/octet-stream'})
781
+ print(f"Uploaded Parquet: {file_url}")
782
+
783
+ # 8. Generate metadata
784
+ metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
785
+ metadata.update({
786
+ "user_id": user_id,
787
+ "s3_path": file_key,
788
+ "s3_url": file_url,
789
+ "source_file": file.filename,
790
+ "source_file_type": file_ext[1:],
791
+ "file_type": "parquet",
792
+ "original_file_size_bytes": len(file_content),
793
+ "parquet_file_size_bytes": parquet_size,
794
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
795
+ "file_hash": file_hash
796
+ })
797
+
798
+ safe_metadata = sanitize_for_json(metadata)
799
+
800
+ # 9. Vector DB
801
+ check_vdb(user_id)
802
+ vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
803
+ vdb_success = vdb_res.get("status") == "success"
804
+ print(f"vdb_success: {vdb_success}")
805
+
806
+ # 10. PostgreSQL Metadata + file_hash (FIXED!)
807
+ pg_result = await user_metadata_upload_pg(
808
+ user_id=user_id,
809
+ user_metadata=json.dumps(safe_metadata),
810
+ path=file_key,
811
+ url=file_url,
812
+ filename=parquet_filename,
813
+ file_type="parquet",
814
+ file_size_bytes=parquet_size,
815
+ file_hash=file_hash # This was missing before!
816
+ )
817
+ print(f"pg_result: {pg_result}")
818
+ pg_success = pg_result.get("success", False)
819
+
820
+ # 11. Return success
821
+ return {
822
+ "message": "Upload successful.",
823
+ "filename": parquet_filename,
824
+ "original_filename": file.filename,
825
+ "user_id": user_id,
826
+ "file_path": file_key,
827
+ "file_url": file_url,
828
+ "file_hash": file_hash,
829
+ "source_file_type": file_ext[1:],
830
+ "file_type": "parquet",
831
+ "original_file_size_bytes": len(file_content),
832
+ "parquet_file_size_bytes": parquet_size,
833
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
834
+ "rows": len(df),
835
+ "columns": len(df.columns),
836
+ "metadata": safe_metadata,
837
+ "upload_dataset_vdb": vdb_success,
838
+ "upload_dataset_pg": pg_success,
839
+ "pg_details": pg_result
840
+ }
841
+
842
+ except HTTPException:
843
+ raise
844
+ except Exception as e:
845
+ print(f"Unexpected error: {e}")
846
+ import traceback
847
+ traceback.print_exc()
848
+ raise HTTPException(status_code=500, detail=str(e))
849
+ finally:
850
+ # Clean up temp directories even on error
851
+ for d in (html_tmp_dir, bokeh_tmp_dir):
852
+ if d and os.path.exists(d):
853
  shutil.rmtree(d, ignore_errors=True)
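
For reference, a minimal client-side sketch of calling the /s3/v3/upload_datasets_v3/ endpoint defined above. It assumes s3_bucket_router1 has been included in a FastAPI app reachable at http://localhost:8000 and that a small sample.csv exists locally; both names are illustrative placeholders, not part of this commit.

import httpx

def upload_dataset(csv_path: str = "sample.csv", user_id: str = "demo-user") -> dict:
    # Send the file as multipart form data; user_id is passed as a query parameter,
    # matching the Query(...) declaration on the endpoint.
    with open(csv_path, "rb") as fh:
        resp = httpx.post(
            "http://localhost:8000/s3/v3/upload_datasets_v3/",
            params={"user_id": user_id},
            files={"file": (csv_path, fh, "text/csv")},
            timeout=60.0,
        )
    # A 409 response means the SHA-256 duplicate check rejected the upload.
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    result = upload_dataset()
    print(result["file_hash"], result["upload_dataset_pg"])
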
s3/r5.py ADDED
@@ -0,0 +1,457 @@
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
2
+ from fastapi.responses import JSONResponse
3
+ import pandas as pd
4
+ from autoviz.AutoViz_Class import AutoViz_Class
5
+ import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
6
+ matplotlib.use('Agg')
7
+ import matplotlib.pyplot as plt
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import List
11
+ import httpx
12
+
13
+ # --- Project Root Setup ---
14
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
15
+ if str(PROJECT_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(PROJECT_ROOT))
17
+
18
+ from retrieve_secret import *
19
+ from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
20
+ from s3.create_dataset_graphs import create_data_set_graphs_dict
21
+
22
+ # --- File Validation ---
23
+ MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
24
+ MAX_ROWS_ALLOWED = 1_000_000
25
+ ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
26
+
27
+ # --- AWS S3 Config ---
28
+ print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
29
+ ACCESS_KEY = AWS_S3_CREDS_KEY_ID
30
+ SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
31
+ BUCKET_NAME = BUCKET_NAME
32
+ REGION_NAME = "us-east-1"
33
+
34
+ s3 = boto3.client(
35
+ "s3",
36
+ aws_access_key_id=ACCESS_KEY,
37
+ aws_secret_access_key=SECRET_KEY,
38
+ region_name=REGION_NAME
39
+ )
40
+
41
+ ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
42
+
43
+ # --- FastAPI Router ---
44
+ s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
45
+
46
+
47
+ # --- Helper: S3 Key ---
48
+ def make_key(path: str, filename: str) -> str:
49
+ return f"{path.strip('/')}/{filename}" if path else filename
50
+
51
+
52
+ # --- Sanitize for JSON ---
53
+ def sanitize_for_json(obj):
54
+ if isinstance(obj, dict):
55
+ return {k: sanitize_for_json(v) for k, v in obj.items()}
56
+ elif isinstance(obj, list):
57
+ return [sanitize_for_json(item) for item in obj]
58
+ elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
59
+ return str(obj)
60
+ elif pd.isna(obj):
61
+ return None
62
+ elif isinstance(obj, (int, float, str, bool, type(None))):
63
+ return obj
64
+ else:
65
+ return str(obj)
66
+
67
+
68
+ # --- Vector DB Placeholders ---
69
+ def check_vdb(user_id: str):
70
+ print(f"Checking VDB for user: {user_id}")
71
+
72
+ async def add_metadata_only(collection_name: str, metadata: dict):
73
+ print(f"Adding metadata to collection: {collection_name}")
74
+ return {"status": "success", "collection": collection_name}
75
+
76
+
77
+ # --- AutoViz HTML ---
78
+ def run_autoviz_html(
79
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
80
+ max_rows_analyzed=150000, max_cols_analyzed=30
81
+ ):
82
+ tmp_dir = tempfile.mkdtemp()
83
+ AV = AutoViz_Class()
84
+ AV.AutoViz(
85
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
86
+ verbose=verbose, lowess=lowess, chart_format="html",
87
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
88
+ save_plot_dir=tmp_dir
89
+ )
90
+ print(f"HTML plots saved to: {tmp_dir}")
91
+ return tmp_dir
92
+
93
+
94
+ # --- AutoViz Bokeh ---
95
+ def run_autoviz_bokeh(
96
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
97
+ max_rows_analyzed=150000, max_cols_analyzed=30
98
+ ):
99
+ tmp_dir = tempfile.mkdtemp()
100
+ print(f"Bokeh temp directory: {tmp_dir}")
101
+ try:
102
+ matplotlib.use('Agg')
103
+ plt.ioff()
104
+ AV = AutoViz_Class()
105
+ AV.AutoViz(
106
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
107
+ verbose=verbose, lowess=lowess, chart_format="bokeh",
108
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
109
+ save_plot_dir=tmp_dir
110
+ )
111
+ plt.close('all')
112
+ print(f"Bokeh plots saved to: {tmp_dir}")
113
+ return tmp_dir
114
+ except Exception as e:
115
+ print(f"Error in Bokeh: {e}")
116
+ import traceback
117
+ traceback.print_exc()
118
+ raise
119
+
120
+
121
+ # --- Fallback Matplotlib ---
122
+ def generate_matplotlib_plots(df, user_id, base_filename):
123
+ print("\nGenerating matplotlib fallback plots...")
124
+ tmp_dir = tempfile.mkdtemp()
125
+ plot_metadata = []
126
+ try:
127
+ matplotlib.use('Agg')
128
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
129
+ if len(numeric_cols) == 0:
130
+ print("No numeric columns for plotting")
131
+ return []
132
+
133
+ # Correlation heatmap
134
+ if len(numeric_cols) > 1:
135
+ plt.figure(figsize=(10, 8))
136
+ import seaborn as sns
137
+ corr = df[numeric_cols].corr()
138
+ sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
139
+ plt.title('Correlation Heatmap')
140
+ plt.tight_layout()
141
+ png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
142
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
143
+ plt.close()
144
+
145
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
146
+ with open(png_path, 'rb') as f:
147
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
148
+ plot_metadata.append({
149
+ "file_name": "correlation_heatmap.png",
150
+ "s3_path": s3_key,
151
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
152
+ "type": "png",
153
+ "plot_type": "heatmap"
154
+ })
155
+
156
+ # Histograms
157
+ for col in numeric_cols[:5]:
158
+ plt.figure(figsize=(10, 6))
159
+ df[col].hist(bins=30, edgecolor='black')
160
+ plt.title(f'Distribution of {col}')
161
+ plt.xlabel(col)
162
+ plt.ylabel('Frequency')
163
+ plt.tight_layout()
164
+ png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
165
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
166
+ plt.close()
167
+
168
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
169
+ with open(png_path, 'rb') as f:
170
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
171
+ plot_metadata.append({
172
+ "file_name": f"distribution_{col}.png",
173
+ "s3_path": s3_key,
174
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
175
+ "type": "png",
176
+ "plot_type": "histogram"
177
+ })
178
+ return plot_metadata
179
+ except Exception as e:
180
+ print(f"Matplotlib failed: {e}")
181
+ return []
182
+ finally:
183
+ shutil.rmtree(tmp_dir, ignore_errors=True)
184
+
185
+
186
+ # --- Upload Viz Files ---
187
+ def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
188
+ patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
189
+ files = []
190
+ for p in patterns:
191
+ files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
192
+ files = list(set(files))
193
+ if not files:
194
+ print(f"No {viz_type} files found")
195
+ return []
196
+
197
+ folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
198
+ metadata = []
199
+ content_type_map = {
200
+ 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
201
+ 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
202
+ }
203
+
204
+ for file_path in files:
205
+ name = os.path.basename(file_path)
206
+ ext = os.path.splitext(name)[1][1:]
207
+ s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
208
+ with open(file_path, "rb") as f:
209
+ body = f.read()
210
+ s3.put_object(
211
+ Bucket=BUCKET_NAME, Key=s3_key, Body=body,
212
+ ContentType=content_type_map.get(ext, 'application/octet-stream')
213
+ )
214
+ metadata.append({
215
+ "file_name": name, "s3_path": s3_key,
216
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
217
+ "type": viz_type, "size": len(body)
218
+ })
219
+ print(f"{viz_type.upper()} uploaded: {s3_key}")
220
+ return metadata
221
+
222
+
223
+ # --- Convert to Parquet ---
224
+ def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
225
+ buffer = io.BytesIO()
226
+ df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
227
+ buffer.seek(0)
228
+ return buffer
229
+
230
+
231
+ # --- Check File Hash Exists ---
232
+ async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
233
+ url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
234
+ headers = {"accept": "application/json"}
235
+ async with httpx.AsyncClient(timeout=10.0) as client:
236
+ try:
237
+ r = await client.get(url, headers=headers)
238
+ r.raise_for_status()
239
+ return r.json()
240
+ except httpx.HTTPStatusError as e:
241
+ return {"success": False, "message": f"HTTP {e.response.status_code}"}
242
+ except Exception as e:
243
+ return {"success": False, "message": f"Request failed: {str(e)}"}
244
+
245
+
246
+ # --- PostgreSQL Metadata Upload with file_hash ---
247
+ async def user_metadata_upload_pg(
248
+ user_id: str,
249
+ user_metadata: str,
250
+ path: str,
251
+ url: str,
252
+ filename: str,
253
+ file_type: str,
254
+ file_size_bytes: int,
255
+ file_hash: str,
256
+ timeout: float = 10.0
257
+ ):
258
+ payload = {
259
+ "user_id": user_id,
260
+ "user_metadata": user_metadata,
261
+ "path": path,
262
+ "url": url,
263
+ "filename": filename,
264
+ "file_type": file_type,
265
+ "file_size_bytes": file_size_bytes,
266
+ "file_hash": file_hash
267
+ }
268
+ print("payload[file_hash]", payload["file_hash"])
269
+ async with httpx.AsyncClient() as client:
270
+ try:
271
+ r = await client.post(
272
+ "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
273
+ json=payload,
274
+ timeout=timeout
275
+ )
276
+ r.raise_for_status()
277
+ result = r.json()
278
+ result["file_hash"] = file_hash
279
+ print("result[file_hash]", result["file_hash"])
280
+ return {"success": True, "data": result}
281
+ except httpx.HTTPStatusError as e:
282
+ return {
283
+ "success": False,
284
+ "error": "HTTP error",
285
+ "status_code": e.response.status_code,
286
+ "detail": e.response.text,
287
+ "file_hash": file_hash
288
+ }
289
+ except Exception as e:
290
+ return {
291
+ "success": False,
292
+ "error": "Request failed",
293
+ "detail": str(e),
294
+ "file_hash": file_hash
295
+ }
296
+
297
+
298
+ # --- MAIN ENDPOINT WITH DATASET GRAPHS ---
299
+ @s3_bucket_router1.post("/upload_datasets_v3/")
300
+ async def upload_file(
301
+ file: UploadFile = File(...),
302
+ user_id: str = Query(..., description="User ID"),
303
+ path: str = Query("", description="Optional subpath")
304
+ ):
305
+ html_tmp_dir = None
306
+ bokeh_tmp_dir = None
307
+ file_content = None
308
+
309
+ try:
310
+ # 1. Validate extension
311
+ file_ext = os.path.splitext(file.filename)[1].lower()
312
+ if file_ext not in ALLOWED_EXTENSIONS:
313
+ raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
314
+
315
+ # 2. Read file
316
+ file_content = await file.read()
317
+ if not file_content:
318
+ raise HTTPException(status_code=400, detail="Empty file")
319
+ if len(file_content) > MAX_FILE_SIZE_BYTES:
320
+ raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
321
+
322
+ # 3. Generate hash
323
+ file_hash = hashlib.sha256(file_content).hexdigest()
324
+ print(f"Generated hash: {file_hash}")
325
+
326
+ # 4. Check hash via API
327
+ hash_result = await check_file_hash_exists(user_id, file_hash)
328
+ if hash_result.get("exists") is True:
329
+ return JSONResponse(
330
+ status_code=409,
331
+ content={
332
+ "message": "File already uploaded.",
333
+ "reason": "Duplicate file detected via SHA-256 hash.",
334
+ "file_hash": file_hash,
335
+ "user_id": user_id,
336
+ "action": "skipped"
337
+ }
338
+ )
339
+ print("Hash check passed. New file.")
340
+
341
+ # 5. Load DataFrame
342
+ try:
343
+ if file_ext == ".csv":
344
+ df = pd.read_csv(io.BytesIO(file_content))
345
+ elif file_ext in {".xlsx", ".xls"}:
346
+ engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
347
+ df = pd.read_excel(io.BytesIO(file_content), engine=engine)
348
+ elif file_ext == ".ods":
349
+ df = pd.read_excel(io.BytesIO(file_content), engine='odf')
350
+ except Exception as e:
351
+ raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
352
+
353
+ if len(df) > MAX_ROWS_ALLOWED:
354
+ raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
355
+
356
+ # 6. Convert to Parquet
357
+ parquet_buffer = convert_df_to_parquet(df)
358
+ parquet_size = parquet_buffer.getbuffer().nbytes
359
+
360
+ # 7. Upload Parquet to S3
361
+ base_filename = os.path.splitext(file.filename)[0]
362
+ parquet_filename = f"{base_filename}.parquet"
363
+ file_key = f"{user_id}/files/datasets/{parquet_filename}"
364
+ file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
365
+
366
+ s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
367
+ ExtraArgs={'ContentType': 'application/octet-stream'})
368
+ print(f"Uploaded Parquet: {file_url}")
369
+
370
+ # 8. Generate metadata
371
+ metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
372
+ metadata.update({
373
+ "user_id": user_id,
374
+ "s3_path": file_key,
375
+ "s3_url": file_url,
376
+ "source_file": file.filename,
377
+ "source_file_type": file_ext[1:],
378
+ "file_type": "parquet",
379
+ "original_file_size_bytes": len(file_content),
380
+ "parquet_file_size_bytes": parquet_size,
381
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
382
+ "file_hash": file_hash
383
+ })
384
+
385
+ # 🆕 8.5 Generate dataset preview graphs
386
+ print("Generating dataset preview graphs...")
387
+ try:
388
+ dataset_graphs = create_data_set_graphs_dict(df, max_rows=200)
389
+ metadata["data_sets_preview_graph"] = dataset_graphs
390
+ print(f"✅ Generated graphs for {len(dataset_graphs.get('columnSummaries', []))} columns")
391
+ except Exception as e:
392
+ print(f"⚠️ Failed to generate dataset graphs: {e}")
393
+ import traceback
394
+ traceback.print_exc()
395
+ metadata["data_sets_preview_graph"] = {
396
+ "error": str(e),
397
+ "columnSummaries": []
398
+ }
399
+
400
+ safe_metadata = sanitize_for_json(metadata)
401
+
402
+ # 9. Vector DB
403
+ check_vdb(user_id)
404
+ vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
405
+ vdb_success = vdb_res.get("status") == "success"
406
+ print(f"vdb_success: {vdb_success}")
407
+
408
+ # 10. PostgreSQL Metadata + file_hash
409
+ pg_result = await user_metadata_upload_pg(
410
+ user_id=user_id,
411
+ user_metadata=json.dumps(safe_metadata),
412
+ path=file_key,
413
+ url=file_url,
414
+ filename=parquet_filename,
415
+ file_type="parquet",
416
+ file_size_bytes=parquet_size,
417
+ file_hash=file_hash
418
+ )
419
+ print(f"pg_result: {pg_result}")
420
+ pg_success = pg_result.get("success", False)
421
+ print("graphs_generated", len(safe_metadata.get("data_sets_preview_graph", {}).get("columnSummaries", [])))
422
+
423
+ # 11. Return success
424
+ return {
425
+ "message": "Upload successful.",
426
+ "filename": parquet_filename,
427
+ "original_filename": file.filename,
428
+ "user_id": user_id,
429
+ "file_path": file_key,
430
+ "file_url": file_url,
431
+ "file_hash": file_hash,
432
+ "source_file_type": file_ext[1:],
433
+ "file_type": "parquet",
434
+ "original_file_size_bytes": len(file_content),
435
+ "parquet_file_size_bytes": parquet_size,
436
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
437
+ "rows": len(df),
438
+ "columns": len(df.columns),
439
+ "metadata": safe_metadata,
440
+ "upload_dataset_vdb": vdb_success,
441
+ "upload_dataset_pg": pg_success,
442
+ "pg_details": pg_result,
443
+ "graphs_generated": len(safe_metadata.get("data_sets_preview_graph", {}).get("columnSummaries", []))
444
+ }
445
+
446
+ except HTTPException:
447
+ raise
448
+ except Exception as e:
449
+ print(f"Unexpected error: {e}")
450
+ import traceback
451
+ traceback.print_exc()
452
+ raise HTTPException(status_code=500, detail=str(e))
453
+ finally:
454
+ # Clean up temp directories
455
+ for d in (html_tmp_dir, bokeh_tmp_dir):
456
+ if d and os.path.exists(d):
457
+ shutil.rmtree(d, ignore_errors=True)
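
As a companion to the duplicate detection in r5.py, here is a small sketch of precomputing the same digest the endpoint derives via hashlib.sha256(file_content).hexdigest(), which can be used to anticipate the 409 duplicate response before uploading. The chunked read and the sample.csv path are illustrative assumptions, not part of this file.

import hashlib
from pathlib import Path

def file_sha256(path: str) -> str:
    # Stream the file in 1 MiB chunks; the hex digest matches what the endpoint
    # computes over the full uploaded body.
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

print(file_sha256("sample.csv"))  # compare with the "file_hash" field returned by earlier uploads
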
s3/r6.py ADDED
@@ -0,0 +1,1065 @@
1
+ # from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
2
+ # from fastapi.responses import JSONResponse
3
+ # import pandas as pd
4
+ # from autoviz.AutoViz_Class import AutoViz_Class
5
+ # import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
6
+ # matplotlib.use('Agg')
7
+ # import matplotlib.pyplot as plt
8
+ # import sys
9
+ # from pathlib import Path
10
+ # from typing import List
11
+ # import httpx
12
+
13
+ # # --- Project Root Setup ---
14
+ # PROJECT_ROOT = Path(__file__).resolve().parents[1]
15
+ # if str(PROJECT_ROOT) not in sys.path:
16
+ # sys.path.insert(0, str(PROJECT_ROOT))
17
+
18
+ # from retrieve_secret import *
19
+ # from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
20
+ # from s3.create_dataset_graphs import create_data_set_graphs_dict
21
+
22
+ # # --- File Validation ---
23
+ # MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
24
+ # MAX_ROWS_ALLOWED = 1_000_000
25
+ # ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
26
+
27
+ # # --- AWS S3 Config ---
28
+ # print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
29
+ # ACCESS_KEY = AWS_S3_CREDS_KEY_ID
30
+ # SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
31
+ # BUCKET_NAME = BUCKET_NAME
32
+ # REGION_NAME = "us-east-1"
33
+
34
+ # s3 = boto3.client(
35
+ # "s3",
36
+ # aws_access_key_id=ACCESS_KEY,
37
+ # aws_secret_access_key=SECRET_KEY,
38
+ # region_name=REGION_NAME
39
+ # )
40
+
41
+ # ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
42
+
43
+ # # --- FastAPI Router ---
44
+ # s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
45
+
46
+
47
+ # # --- Helper: S3 Key ---
48
+ # def make_key(path: str, filename: str) -> str:
49
+ # return f"{path.strip('/')}/{filename}" if path else filename
50
+
51
+
52
+ # # --- Sanitize for JSON ---
53
+ # def sanitize_for_json(obj):
54
+ # if isinstance(obj, dict):
55
+ # return {k: sanitize_for_json(v) for k, v in obj.items()}
56
+ # elif isinstance(obj, list):
57
+ # return [sanitize_for_json(item) for item in obj]
58
+ # elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
59
+ # return str(obj)
60
+ # elif pd.isna(obj):
61
+ # return None
62
+ # elif isinstance(obj, (int, float, str, bool, type(None))):
63
+ # return obj
64
+ # else:
65
+ # return str(obj)
66
+
67
+
68
+ # # --- Vector DB Placeholders ---
69
+ # def check_vdb(user_id: str):
70
+ # print(f"Checking VDB for user: {user_id}")
71
+
72
+ # async def add_metadata_only(collection_name: str, metadata: dict):
73
+ # print(f"Adding metadata to collection: {collection_name}")
74
+ # return {"status": "success", "collection": collection_name}
75
+
76
+
77
+ # # --- AutoViz HTML ---
78
+ # def run_autoviz_html(
79
+ # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
80
+ # max_rows_analyzed=150000, max_cols_analyzed=30
81
+ # ):
82
+ # tmp_dir = tempfile.mkdtemp()
83
+ # AV = AutoViz_Class()
84
+ # AV.AutoViz(
85
+ # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
86
+ # verbose=verbose, lowess=lowess, chart_format="html",
87
+ # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
88
+ # save_plot_dir=tmp_dir
89
+ # )
90
+ # print(f"HTML plots saved to: {tmp_dir}")
91
+ # return tmp_dir
92
+
93
+
94
+ # # --- AutoViz Bokeh ---
95
+ # def run_autoviz_bokeh(
96
+ # dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
97
+ # max_rows_analyzed=150000, max_cols_analyzed=30
98
+ # ):
99
+ # tmp_dir = tempfile.mkdtemp()
100
+ # print(f"Bokeh temp directory: {tmp_dir}")
101
+ # try:
102
+ # matplotlib.use('Agg')
103
+ # plt.ioff()
104
+ # AV = AutoViz_Class()
105
+ # AV.AutoViz(
106
+ # filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
107
+ # verbose=verbose, lowess=lowess, chart_format="bokeh",
108
+ # max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
109
+ # save_plot_dir=tmp_dir
110
+ # )
111
+ # plt.close('all')
112
+ # print(f"Bokeh plots saved to: {tmp_dir}")
113
+ # return tmp_dir
114
+ # except Exception as e:
115
+ # print(f"Error in Bokeh: {e}")
116
+ # import traceback
117
+ # traceback.print_exc()
118
+ # raise
119
+
120
+
121
+ # # --- Fallback Matplotlib ---
122
+ # def generate_matplotlib_plots(df, user_id, base_filename):
123
+ # print("\nGenerating matplotlib fallback plots...")
124
+ # tmp_dir = tempfile.mkdtemp()
125
+ # plot_metadata = []
126
+ # try:
127
+ # matplotlib.use('Agg')
128
+ # numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
129
+ # if len(numeric_cols) == 0:
130
+ # print("No numeric columns for plotting")
131
+ # return []
132
+
133
+ # # Correlation heatmap
134
+ # if len(numeric_cols) > 1:
135
+ # plt.figure(figsize=(10, 8))
136
+ # import seaborn as sns
137
+ # corr = df[numeric_cols].corr()
138
+ # sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
139
+ # plt.title('Correlation Heatmap')
140
+ # plt.tight_layout()
141
+ # png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
142
+ # plt.savefig(png_path, dpi=100, bbox_inches='tight')
143
+ # plt.close()
144
+
145
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
146
+ # with open(png_path, 'rb') as f:
147
+ # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
148
+ # plot_metadata.append({
149
+ # "file_name": "correlation_heatmap.png",
150
+ # "s3_path": s3_key,
151
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
152
+ # "type": "png",
153
+ # "plot_type": "heatmap"
154
+ # })
155
+
156
+ # # Histograms
157
+ # for col in numeric_cols[:5]:
158
+ # plt.figure(figsize=(10, 6))
159
+ # df[col].hist(bins=30, edgecolor='black')
160
+ # plt.title(f'Distribution of {col}')
161
+ # plt.xlabel(col)
162
+ # plt.ylabel('Frequency')
163
+ # plt.tight_layout()
164
+ # png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
165
+ # plt.savefig(png_path, dpi=100, bbox_inches='tight')
166
+ # plt.close()
167
+
168
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
169
+ # with open(png_path, 'rb') as f:
170
+ # s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
171
+ # plot_metadata.append({
172
+ # "file_name": f"distribution_{col}.png",
173
+ # "s3_path": s3_key,
174
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
175
+ # "type": "png",
176
+ # "plot_type": "histogram"
177
+ # })
178
+ # return plot_metadata
179
+ # except Exception as e:
180
+ # print(f"Matplotlib failed: {e}")
181
+ # return []
182
+ # finally:
183
+ # shutil.rmtree(tmp_dir, ignore_errors=True)
184
+
185
+
186
+ # # --- Upload Viz Files ---
187
+ # def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
188
+ # patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
189
+ # files = []
190
+ # for p in patterns:
191
+ # files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
192
+ # files = list(set(files))
193
+ # if not files:
194
+ # print(f"No {viz_type} files found")
195
+ # return []
196
+
197
+ # folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
198
+ # metadata = []
199
+ # content_type_map = {
200
+ # 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
201
+ # 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
202
+ # }
203
+
204
+ # for file_path in files:
205
+ # name = os.path.basename(file_path)
206
+ # ext = os.path.splitext(name)[1][1:]
207
+ # s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
208
+ # with open(file_path, "rb") as f:
209
+ # body = f.read()
210
+ # s3.put_object(
211
+ # Bucket=BUCKET_NAME, Key=s3_key, Body=body,
212
+ # ContentType=content_type_map.get(ext, 'application/octet-stream')
213
+ # )
214
+ # metadata.append({
215
+ # "file_name": name, "s3_path": s3_key,
216
+ # "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
217
+ # "type": viz_type, "size": len(body)
218
+ # })
219
+ # print(f"{viz_type.upper()} uploaded: {s3_key}")
220
+ # return metadata
221
+
222
+
223
+ # # --- Convert to Parquet ---
224
+ # def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
225
+ # buffer = io.BytesIO()
226
+ # df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
227
+ # buffer.seek(0)
228
+ # return buffer
229
+
230
+
231
+ # # --- Check File Hash Exists (FIXED) ---
232
+ # async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
233
+ # """
234
+ # Check if a file hash already exists for a user.
235
+ # Returns a dict with 'success', 'exists', and optional 'data' keys.
236
+ # """
237
+ # url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
238
+ # headers = {"accept": "application/json"}
239
+ # async with httpx.AsyncClient(timeout=10.0) as client:
240
+ # try:
241
+ # r = await client.get(url, headers=headers)
242
+ # r.raise_for_status()
243
+ # result = r.json()
244
+
245
+ # # Log the actual response for debugging
246
+ # print(f"Hash check API response: {result}")
247
+
248
+ # # Check for 'exists' field first, then parse message if not present
249
+ # exists = result.get("exists", None)
250
+
251
+ # # If 'exists' field is not present, parse the message
252
+ # if exists is None and "message" in result:
253
+ # message_lower = result["message"].lower()
254
+ # # Check if message indicates file already exists
255
+ # exists = "already existed" in message_lower or "already exists" in message_lower or "duplicate" in message_lower
256
+
257
+ # # Default to False if still None
258
+ # if exists is None:
259
+ # exists = False
260
+
261
+ # return {
262
+ # "success": True,
263
+ # "exists": exists,
264
+ # "data": result
265
+ # }
266
+ # except httpx.HTTPStatusError as e:
267
+ # print(f"Hash check HTTP error: {e.response.status_code} - {e.response.text}")
268
+ # return {
269
+ # "success": False,
270
+ # "exists": False, # Assume not exists on error to be safe
271
+ # "message": f"HTTP {e.response.status_code}",
272
+ # "error_detail": e.response.text
273
+ # }
274
+ # except Exception as e:
275
+ # print(f"Hash check exception: {str(e)}")
276
+ # return {
277
+ # "success": False,
278
+ # "exists": False, # Assume not exists on error
279
+ # "message": f"Request failed: {str(e)}"
280
+ # }
281
+
282
+
283
+ # # --- PostgreSQL Metadata Upload with file_hash ---
284
+ # async def user_metadata_upload_pg(
285
+ # user_id: str,
286
+ # user_metadata: str,
287
+ # path: str,
288
+ # url: str,
289
+ # filename: str,
290
+ # file_type: str,
291
+ # file_size_bytes: int,
292
+ # file_hash: str,
293
+ # timeout: float = 10.0
294
+ # ):
295
+ # payload = {
296
+ # "user_id": user_id,
297
+ # "user_metadata": user_metadata,
298
+ # "path": path,
299
+ # "url": url,
300
+ # "filename": filename,
301
+ # "file_type": file_type,
302
+ # "file_size_bytes": file_size_bytes,
303
+ # "file_hash": file_hash
304
+ # }
305
+ # print(f"PostgreSQL payload file_hash: {payload['file_hash']}")
306
+ # async with httpx.AsyncClient() as client:
307
+ # try:
308
+ # r = await client.post(
309
+ # "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
310
+ # json=payload,
311
+ # timeout=timeout
312
+ # )
313
+ # r.raise_for_status()
314
+ # result = r.json()
315
+ # result["file_hash"] = file_hash
316
+ # print(f"PostgreSQL result file_hash: {result['file_hash']}")
317
+ # return {"success": True, "data": result}
318
+ # except httpx.HTTPStatusError as e:
319
+ # return {
320
+ # "success": False,
321
+ # "error": "HTTP error",
322
+ # "status_code": e.response.status_code,
323
+ # "detail": e.response.text,
324
+ # "file_hash": file_hash
325
+ # }
326
+ # except Exception as e:
327
+ # return {
328
+ # "success": False,
329
+ # "error": "Request failed",
330
+ # "detail": str(e),
331
+ # "file_hash": file_hash
332
+ # }
333
+
334
+
335
+ # # --- DEBUG ENDPOINT (Optional - for troubleshooting) ---
336
+ # @s3_bucket_router1.get("/debug/check_hash/{user_id}/{file_hash}")
337
+ # async def debug_check_hash(user_id: str, file_hash: str):
338
+ # """Debug endpoint to test hash checking"""
339
+ # result = await check_file_hash_exists(user_id, file_hash)
340
+ # return {
341
+ # "raw_result": result,
342
+ # "exists_value": result.get("exists"),
343
+ # "exists_type": type(result.get("exists")).__name__,
344
+ # "success_value": result.get("success"),
345
+ # "interpretation": "File exists" if result.get("exists") is True else "File does not exist or check failed"
346
+ # }
347
+
348
+
349
+ # # --- MAIN ENDPOINT WITH FIXED HASH CHECK AND DATASET GRAPHS ---
350
+ # @s3_bucket_router1.post("/upload_datasets_v3/")
351
+ # async def upload_file(
352
+ # file: UploadFile = File(...),
353
+ # user_id: str = Query(..., description="User ID"),
354
+ # path: str = Query("", description="Optional subpath")
355
+ # ):
356
+ # html_tmp_dir = None
357
+ # bokeh_tmp_dir = None
358
+ # file_content = None
359
+
360
+ # try:
361
+ # # 1. Validate extension
362
+ # file_ext = os.path.splitext(file.filename)[1].lower()
363
+ # if file_ext not in ALLOWED_EXTENSIONS:
364
+ # raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
365
+
366
+ # # 2. Read file
367
+ # file_content = await file.read()
368
+ # if not file_content:
369
+ # raise HTTPException(status_code=400, detail="Empty file")
370
+ # if len(file_content) > MAX_FILE_SIZE_BYTES:
371
+ # raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
372
+
373
+ # # 3. Generate hash
374
+ # file_hash = hashlib.sha256(file_content).hexdigest()
375
+ # print(f"Generated file hash: {file_hash}")
376
+
377
+ # # 4. Check hash via API (FIXED LOGIC)
378
+ # hash_result = await check_file_hash_exists(user_id, file_hash)
379
+
380
+ # # Enhanced logging
381
+ # print(f"Hash check result: {hash_result}")
382
+
383
+ # # Check if the API call was successful first
384
+ # if not hash_result.get("success", False):
385
+ # print(f"⚠️ Warning: Hash check API failed: {hash_result.get('message')}")
386
+ # # You can decide to fail here if hash check is critical
387
+ # # raise HTTPException(status_code=503, detail="Hash check service unavailable")
388
+
389
+ # # Check if file exists
390
+ # if hash_result.get("exists") is True:
391
+ # print(f"🚫 Duplicate file detected: {file_hash}")
392
+ # return JSONResponse(
393
+ # status_code=409,
394
+ # content={
395
+ # "message": "File already uploaded.",
396
+ # "reason": "Duplicate file detected via SHA-256 hash.",
397
+ # "file_hash": file_hash,
398
+ # "user_id": user_id,
399
+ # "filename": file.filename,
400
+ # "action": "skipped",
401
+ # "existing_file_info": hash_result.get("data")
402
+ # }
403
+ # )
404
+
405
+ # print("✅ Hash check passed. New file - proceeding with upload.")
406
+
407
+ # # 5. Load DataFrame
408
+ # try:
409
+ # if file_ext == ".csv":
410
+ # df = pd.read_csv(io.BytesIO(file_content))
411
+ # elif file_ext in {".xlsx", ".xls"}:
412
+ # engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
413
+ # df = pd.read_excel(io.BytesIO(file_content), engine=engine)
414
+ # elif file_ext == ".ods":
415
+ # df = pd.read_excel(io.BytesIO(file_content), engine='odf')
416
+ # except Exception as e:
417
+ # raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
418
+
419
+ # if len(df) > MAX_ROWS_ALLOWED:
420
+ # raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
421
+
422
+ # # 6. Convert to Parquet
423
+ # parquet_buffer = convert_df_to_parquet(df)
424
+ # parquet_size = parquet_buffer.getbuffer().nbytes
425
+
426
+ # # 7. Upload Parquet to S3
427
+ # base_filename = os.path.splitext(file.filename)[0]
428
+ # parquet_filename = f"{base_filename}.parquet"
429
+ # file_key = f"{user_id}/files/datasets/{parquet_filename}"
430
+ # file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
431
+
432
+ # s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
433
+ # ExtraArgs={'ContentType': 'application/octet-stream'})
434
+ # print(f"Uploaded Parquet: {file_url}")
435
+
436
+ # # 8. Generate metadata
437
+ # metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
438
+ # metadata.update({
439
+ # "user_id": user_id,
440
+ # "s3_path": file_key,
441
+ # "s3_url": file_url,
442
+ # "source_file": file.filename,
443
+ # "source_file_type": file_ext[1:],
444
+ # "file_type": "parquet",
445
+ # "original_file_size_bytes": len(file_content),
446
+ # "parquet_file_size_bytes": parquet_size,
447
+ # "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
448
+ # "file_hash": file_hash
449
+ # })
450
+
451
+ # # 8.5 Generate dataset preview graphs
452
+ # print("Generating dataset preview graphs...")
453
+ # try:
454
+ # dataset_graphs = create_data_set_graphs_dict(df, max_rows=200)
455
+ # metadata["data_sets_preview_graph"] = dataset_graphs
456
+ # print(f"✅ Generated graphs for {len(dataset_graphs.get('columnSummaries', []))} columns")
457
+ # except Exception as e:
458
+ # print(f"⚠️ Failed to generate dataset graphs: {e}")
459
+ # import traceback
460
+ # traceback.print_exc()
461
+ # metadata["data_sets_preview_graph"] = {
462
+ # "error": str(e),
463
+ # "columnSummaries": []
464
+ # }
465
+
466
+ # safe_metadata = sanitize_for_json(metadata)
467
+
468
+ # # 9. Vector DB
469
+ # check_vdb(user_id)
470
+ # vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
471
+ # vdb_success = vdb_res.get("status") == "success"
472
+ # print(f"VDB upload success: {vdb_success}")
473
+
474
+ # # 10. PostgreSQL Metadata + file_hash
475
+ # pg_result = await user_metadata_upload_pg(
476
+ # user_id=user_id,
477
+ # user_metadata=json.dumps(safe_metadata),
478
+ # path=file_key,
479
+ # url=file_url,
480
+ # filename=parquet_filename,
481
+ # file_type="parquet",
482
+ # file_size_bytes=parquet_size,
483
+ # file_hash=file_hash
484
+ # )
485
+ # print(f"PostgreSQL upload result: {pg_result}")
486
+ # pg_success = pg_result.get("success", False)
487
+
488
+ # graphs_count = len(safe_metadata.get("data_sets_preview_graph", {}).get("columnSummaries", []))
489
+ # print(f"Graphs generated: {graphs_count}")
490
+
491
+ # # 11. Return success
492
+ # return {
493
+ # "message": "Upload successful.",
494
+ # "filename": parquet_filename,
495
+ # "original_filename": file.filename,
496
+ # "user_id": user_id,
497
+ # "file_path": file_key,
498
+ # "file_url": file_url,
499
+ # "file_hash": file_hash,
500
+ # "source_file_type": file_ext[1:],
501
+ # "file_type": "parquet",
502
+ # "original_file_size_bytes": len(file_content),
503
+ # "parquet_file_size_bytes": parquet_size,
504
+ # "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
505
+ # "rows": len(df),
506
+ # "columns": len(df.columns),
507
+ # "metadata": safe_metadata,
508
+ # "upload_dataset_vdb": vdb_success,
509
+ # "upload_dataset_pg": pg_success,
510
+ # "pg_details": pg_result,
511
+ # "graphs_generated": graphs_count
512
+ # }
513
+
514
+ # except HTTPException:
515
+ # raise
516
+ # except Exception as e:
517
+ # print(f"Unexpected error: {e}")
518
+ # import traceback
519
+ # traceback.print_exc()
520
+ # raise HTTPException(status_code=500, detail=str(e))
521
+ # finally:
522
+ # # Clean up temp directories
523
+ # for d in (html_tmp_dir, bokeh_tmp_dir):
524
+ # if d and os.path.exists(d):
525
+ # shutil.rmtree(d, ignore_errors=True)
526
+
527
+
528
+
529
+
530
+
531
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Query, APIRouter
532
+ from fastapi.responses import JSONResponse
533
+ import pandas as pd
534
+ from autoviz.AutoViz_Class import AutoViz_Class
535
+ import io, os, boto3, tempfile, glob, matplotlib, json, hashlib, shutil
536
+ matplotlib.use('Agg')
537
+ import matplotlib.pyplot as plt
538
+ import sys
539
+ from pathlib import Path
540
+ from typing import List
541
+ import httpx
542
+
543
+ # --- Project Root Setup ---
544
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
545
+ if str(PROJECT_ROOT) not in sys.path:
546
+ sys.path.insert(0, str(PROJECT_ROOT))
547
+
548
+ from retrieve_secret import *
549
+ from s3.meta_data_creation_from_s3 import create_file_metadata_from_df
550
+ from s3.create_dataset_graphs import create_data_set_graphs_dict
551
+
552
+ # --- File Validation ---
553
+ MAX_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MiB
554
+ MAX_ROWS_ALLOWED = 1_000_000
555
+ ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".ods"}
556
+
557
+ # --- AWS S3 Config ---
558
+ print("AWS S3 config:", AWS_S3_CREDS_KEY_ID, AWS_S3_CREDS_SECRET_KEY, BUCKET_NAME)
559
+ ACCESS_KEY = AWS_S3_CREDS_KEY_ID
560
+ SECRET_KEY = AWS_S3_CREDS_SECRET_KEY
561
+ # BUCKET_NAME is imported as-is from retrieve_secret
562
+ REGION_NAME = "us-east-1"
563
+
564
+ s3 = boto3.client(
565
+ "s3",
566
+ aws_access_key_id=ACCESS_KEY,
567
+ aws_secret_access_key=SECRET_KEY,
568
+ region_name=REGION_NAME
569
+ )
570
+
571
+ ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com"
572
+
573
+ # --- FastAPI Router ---
574
+ s3_bucket_router1 = APIRouter(prefix="/s3/v3", tags=["s3_v3"])
575
+
576
+
577
+ # --- Helper: S3 Key ---
578
+ def make_key(path: str, filename: str) -> str:
579
+ return f"{path.strip('/')}/{filename}" if path else filename
580
+
581
+
582
+ # --- Sanitize for JSON ---
583
+ def sanitize_for_json(obj):
584
+ if isinstance(obj, dict):
585
+ return {k: sanitize_for_json(v) for k, v in obj.items()}
586
+ elif isinstance(obj, list):
587
+ return [sanitize_for_json(item) for item in obj]
588
+ elif isinstance(obj, (pd.Timestamp, pd.DatetimeTZDtype)):
589
+ return str(obj)
590
+ elif pd.api.types.is_scalar(obj) and pd.isna(obj):  # guard: pd.isna on array-likes is ambiguous in a bool context
591
+ return None
592
+ elif isinstance(obj, (int, float, str, bool, type(None))):
593
+ return obj
594
+ else:
595
+ return str(obj)
596
+
597
+
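+ # Illustrative example (not part of the pipeline): sanitize_for_json is applied to metadata
+ # dicts before json.dumps, e.g.
+ # sanitize_for_json({"ts": pd.Timestamp("2024-01-01"), "missing": float("nan"), "n": 3})
+ # -> {"ts": "2024-01-01 00:00:00", "missing": None, "n": 3}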
598
+ # --- Vector DB Placeholders ---
599
+ def check_vdb(user_id: str):
600
+ print(f"Checking VDB for user: {user_id}")
601
+
602
+ async def add_metadata_only(collection_name: str, metadata: dict):
603
+ print(f"Adding metadata to collection: {collection_name}")
604
+ return {"status": "success", "collection": collection_name}
605
+
606
+
607
+ # --- AutoViz HTML ---
608
+ def run_autoviz_html(
609
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
610
+ max_rows_analyzed=150000, max_cols_analyzed=30
611
+ ):
612
+ tmp_dir = tempfile.mkdtemp()
613
+ AV = AutoViz_Class()
614
+ AV.AutoViz(
615
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
616
+ verbose=verbose, lowess=lowess, chart_format="html",
617
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
618
+ save_plot_dir=tmp_dir
619
+ )
620
+ print(f"HTML plots saved to: {tmp_dir}")
621
+ return tmp_dir
622
+
623
+
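+ # Illustrative usage (hypothetical variables): the returned temp directory can be passed to
+ # upload_viz_files_to_s3 and then removed, e.g.
+ # html_dir = run_autoviz_html(df)
+ # html_meta = upload_viz_files_to_s3(html_dir, "*.html", user_id, base_filename, "html")
+ # shutil.rmtree(html_dir, ignore_errors=True)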
624
+ # --- AutoViz Bokeh ---
625
+ def run_autoviz_bokeh(
626
+ dfte, depVar="", sep=",", header=0, verbose=0, lowess=False,
627
+ max_rows_analyzed=150000, max_cols_analyzed=30
628
+ ):
629
+ tmp_dir = tempfile.mkdtemp()
630
+ print(f"Bokeh temp directory: {tmp_dir}")
631
+ try:
632
+ matplotlib.use('Agg')
633
+ plt.ioff()
634
+ AV = AutoViz_Class()
635
+ AV.AutoViz(
636
+ filename="", sep=sep, depVar=depVar, dfte=dfte, header=header,
637
+ verbose=verbose, lowess=lowess, chart_format="bokeh",
638
+ max_rows_analyzed=max_rows_analyzed, max_cols_analyzed=max_cols_analyzed,
639
+ save_plot_dir=tmp_dir
640
+ )
641
+ plt.close('all')
642
+ print(f"Bokeh plots saved to: {tmp_dir}")
643
+ return tmp_dir
644
+ except Exception as e:
645
+ print(f"Error in Bokeh: {e}")
646
+ import traceback
647
+ traceback.print_exc()
648
+ raise
649
+
650
+
651
+ # --- Fallback Matplotlib ---
652
+ def generate_matplotlib_plots(df, user_id, base_filename):
653
+ print("\nGenerating matplotlib fallback plots...")
654
+ tmp_dir = tempfile.mkdtemp()
655
+ plot_metadata = []
656
+ try:
657
+ matplotlib.use('Agg')
658
+ numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
659
+ if len(numeric_cols) == 0:
660
+ print("No numeric columns for plotting")
661
+ return []
662
+
663
+ # Correlation heatmap
664
+ if len(numeric_cols) > 1:
665
+ plt.figure(figsize=(10, 8))
666
+ import seaborn as sns
667
+ corr = df[numeric_cols].corr()
668
+ sns.heatmap(corr, annot=True, cmap='coolwarm', center=0)
669
+ plt.title('Correlation Heatmap')
670
+ plt.tight_layout()
671
+ png_path = os.path.join(tmp_dir, 'correlation_heatmap.png')
672
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
673
+ plt.close()
674
+
675
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png"
676
+ with open(png_path, 'rb') as f:
677
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
678
+ plot_metadata.append({
679
+ "file_name": "correlation_heatmap.png",
680
+ "s3_path": s3_key,
681
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
682
+ "type": "png",
683
+ "plot_type": "heatmap"
684
+ })
685
+
686
+ # Histograms
687
+ for col in numeric_cols[:5]:
688
+ plt.figure(figsize=(10, 6))
689
+ df[col].hist(bins=30, edgecolor='black')
690
+ plt.title(f'Distribution of {col}')
691
+ plt.xlabel(col)
692
+ plt.ylabel('Frequency')
693
+ plt.tight_layout()
694
+ png_path = os.path.join(tmp_dir, f'distribution_{col}.png')
695
+ plt.savefig(png_path, dpi=100, bbox_inches='tight')
696
+ plt.close()
697
+
698
+ s3_key = f"{user_id}/files/datasets/{base_filename}/pngs/distribution_{col}.png"
699
+ with open(png_path, 'rb') as f:
700
+ s3.put_object(Bucket=BUCKET_NAME, Key=s3_key, Body=f.read(), ContentType='image/png')
701
+ plot_metadata.append({
702
+ "file_name": f"distribution_{col}.png",
703
+ "s3_path": s3_key,
704
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
705
+ "type": "png",
706
+ "plot_type": "histogram"
707
+ })
708
+ return plot_metadata
709
+ except Exception as e:
710
+ print(f"Matplotlib failed: {e}")
711
+ return []
712
+ finally:
713
+ shutil.rmtree(tmp_dir, ignore_errors=True)
714
+
715
+
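+ # Shape of each returned entry (illustrative values):
+ # {"file_name": "correlation_heatmap.png",
+ #  "s3_path": f"{user_id}/files/datasets/{base_filename}/pngs/correlation_heatmap.png",
+ #  "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}", "type": "png", "plot_type": "heatmap"}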
716
+ # --- Upload Viz Files ---
717
+ def upload_viz_files_to_s3(tmp_dir, file_pattern, user_id, base_filename, viz_type):
718
+ patterns = [file_pattern, f"*.{viz_type}", f"**/*.{viz_type}"]
719
+ files = []
720
+ for p in patterns:
721
+ files.extend(glob.glob(os.path.join(tmp_dir, p), recursive=True))
722
+ files = list(set(files))
723
+ if not files:
724
+ print(f"No {viz_type} files found")
725
+ return []
726
+
727
+ folder = "htmls" if viz_type == "html" else "svgs" if viz_type == "svg" else "pngs"
728
+ metadata = []
729
+ content_type_map = {
730
+ 'html': 'text/html', 'svg': 'image/svg+xml', 'png': 'image/png',
731
+ 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'
732
+ }
733
+
734
+ for file_path in files:
735
+ name = os.path.basename(file_path)
736
+ ext = os.path.splitext(name)[1][1:]
737
+ s3_key = f"{user_id}/files/datasets/{base_filename}/{folder}/{name}"
738
+ with open(file_path, "rb") as f:
739
+ body = f.read()
740
+ s3.put_object(
741
+ Bucket=BUCKET_NAME, Key=s3_key, Body=body,
742
+ ContentType=content_type_map.get(ext, 'application/octet-stream')
743
+ )
744
+ metadata.append({
745
+ "file_name": name, "s3_path": s3_key,
746
+ "url": f"{ENDPOINT_URL}/{BUCKET_NAME}/{s3_key}",
747
+ "type": viz_type, "size": len(body)
748
+ })
749
+ print(f"{viz_type.upper()} uploaded: {s3_key}")
750
+ return metadata
751
+
752
+
753
+ # --- Convert to Parquet ---
754
+ def convert_df_to_parquet(df: pd.DataFrame) -> io.BytesIO:
755
+ buffer = io.BytesIO()
756
+ df.to_parquet(buffer, engine='pyarrow', compression='snappy', index=False)
757
+ buffer.seek(0)
758
+ return buffer
759
+
760
+
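+ # Sanity check (illustrative, assumes pyarrow is installed):
+ # pd.read_parquet(convert_df_to_parquet(df)) should reproduce df up to pyarrow dtype coercion.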
761
+ # --- Check File Hash Exists (FIXED) ---
762
+ async def check_file_hash_exists(user_id: str, file_hash: str) -> dict:
763
+ """
764
+ Check if a file hash already exists for a user.
765
+ Returns a dict with 'success', 'exists', and optional 'data' keys.
766
+ """
767
+ url = f"https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadata/{user_id}/check_file_hash?file_hash={file_hash}"
768
+ headers = {"accept": "application/json"}
769
+ async with httpx.AsyncClient(timeout=10.0) as client:
770
+ try:
771
+ r = await client.get(url, headers=headers)
772
+ r.raise_for_status()
773
+ result = r.json()
774
+
775
+ # Log the actual response for debugging
776
+ print(f"Hash check API response: {result}")
777
+
778
+ # Check for 'exists' field first, then parse message if not present
779
+ exists = result.get("exists", None)
780
+
781
+ # If 'exists' field is not present, parse the message
782
+ if exists is None and "message" in result:
783
+ message_lower = result["message"].lower()
784
+ # Check if message indicates file already exists
785
+ exists = "already existed" in message_lower or "already exists" in message_lower or "duplicate" in message_lower
786
+
787
+ # Default to False if still None
788
+ if exists is None:
789
+ exists = False
790
+
791
+ return {
792
+ "success": True,
793
+ "exists": exists,
794
+ "data": result
795
+ }
796
+ except httpx.HTTPStatusError as e:
797
+ print(f"Hash check HTTP error: {e.response.status_code} - {e.response.text}")
798
+ return {
799
+ "success": False,
800
+ "exists": False, # Assume not exists on error to be safe
801
+ "message": f"HTTP {e.response.status_code}",
802
+ "error_detail": e.response.text
803
+ }
804
+ except Exception as e:
805
+ print(f"Hash check exception: {str(e)}")
806
+ return {
807
+ "success": False,
808
+ "exists": False, # Assume not exists on error
809
+ "message": f"Request failed: {str(e)}"
810
+ }
811
+
812
+
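+ # Expected result shape (illustrative): {"success": True, "exists": False, "data": {...}} for a
+ # new hash, or {"success": False, "exists": False, "message": "..."} when the metadata service
+ # is unreachable.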
813
+ # --- PostgreSQL Metadata Upload with file_hash ---
814
+ async def user_metadata_upload_pg(
815
+ user_id: str,
816
+ user_metadata: str,
817
+ path: str,
818
+ url: str,
819
+ filename: str,
820
+ file_type: str,
821
+ file_size_bytes: int,
822
+ file_hash: str,
823
+ timeout: float = 10.0,
824
+ data_sets_preview_graph: dict | None = None  # sanitized graph dict; sent as part of the JSON payload, not a string
825
+
826
+ ):
827
+ payload = {
828
+ "user_id": user_id,
829
+ "user_metadata": user_metadata,
830
+ "path": path,
831
+ "url": url,
832
+ "filename": filename,
833
+ "file_type": file_type,
834
+ "file_size_bytes": file_size_bytes,
835
+ "file_hash": file_hash,
836
+ "data_sets_preview_graph": data_sets_preview_graph
837
+ }
838
+ print(f"PostgreSQL payload file_hash: {payload['file_hash']}")
839
+ print("data_sets_preview_graph", data_sets_preview_graph)
840
+ async with httpx.AsyncClient() as client:
841
+ try:
842
+ r = await client.post(
843
+ "https://mr-mvp-api-dev.dev.ingenspark.com/auth/UserMetadataCreate",
844
+ json=payload,
845
+ timeout=timeout
846
+ )
847
+ r.raise_for_status()
848
+ result = r.json()
849
+ result["file_hash"] = file_hash
850
+ print(f"PostgreSQL result file_hash: {result['file_hash']}")
851
+ return {"success": True, "data": result}
852
+ except httpx.HTTPStatusError as e:
853
+ return {
854
+ "success": False,
855
+ "error": "HTTP error",
856
+ "status_code": e.response.status_code,
857
+ "detail": e.response.text,
858
+ "file_hash": file_hash
859
+ }
860
+ except Exception as e:
861
+ return {
862
+ "success": False,
863
+ "error": "Request failed",
864
+ "detail": str(e),
865
+ "file_hash": file_hash
866
+ }
867
+
868
+
869
+ # --- DEBUG ENDPOINT (Optional - for troubleshooting) ---
870
+ @s3_bucket_router1.get("/debug/check_hash/{user_id}/{file_hash}")
871
+ async def debug_check_hash(user_id: str, file_hash: str):
872
+ """Debug endpoint to test hash checking"""
873
+ result = await check_file_hash_exists(user_id, file_hash)
874
+ return {
875
+ "raw_result": result,
876
+ "exists_value": result.get("exists"),
877
+ "exists_type": type(result.get("exists")).__name__,
878
+ "success_value": result.get("success"),
879
+ "interpretation": "File exists" if result.get("exists") is True else "File does not exist or check failed"
880
+ }
881
+
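+ # Example request (assuming the router is mounted at the app root on localhost:8000):
+ # curl "http://localhost:8000/s3/v3/debug/check_hash/<user_id>/<sha256-hex>"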
882
+
883
+ # --- MAIN ENDPOINT WITH FIXED HASH CHECK AND DATASET GRAPHS ---
884
+ @s3_bucket_router1.post("/upload_datasets_v3/")
885
+ async def upload_file(
886
+ file: UploadFile = File(...),
887
+ user_id: str = Query(..., description="User ID"),
888
+ path: str = Query("", description="Optional subpath")  # currently unused; the S3 key below is built from user_id and the filename
889
+ ):
890
+ html_tmp_dir = None
891
+ bokeh_tmp_dir = None
892
+ file_content = None
893
+
894
+ try:
895
+ # 1. Validate extension
896
+ file_ext = os.path.splitext(file.filename)[1].lower()
897
+ if file_ext not in ALLOWED_EXTENSIONS:
898
+ raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")
899
+
900
+ # 2. Read file
901
+ file_content = await file.read()
902
+ if not file_content:
903
+ raise HTTPException(status_code=400, detail="Empty file")
904
+ if len(file_content) > MAX_FILE_SIZE_BYTES:
905
+ raise HTTPException(status_code=413, detail="File exceeds 100 MiB limit")
906
+
907
+ # 3. Generate hash
908
+ file_hash = hashlib.sha256(file_content).hexdigest()
909
+ print(f"Generated file hash: {file_hash}")
910
+
911
+ # 4. Check hash via API (FIXED LOGIC)
912
+ hash_result = await check_file_hash_exists(user_id, file_hash)
913
+
914
+ # Enhanced logging
915
+ print(f"Hash check result: {hash_result}")
916
+
917
+ # Check if the API call was successful first
918
+ if not hash_result.get("success", False):
919
+ print(f"⚠️ Warning: Hash check API failed: {hash_result.get('message')}")
920
+ # You can decide to fail here if hash check is critical
921
+ # raise HTTPException(status_code=503, detail="Hash check service unavailable")
922
+
923
+ # Check if file exists
924
+ if hash_result.get("exists") is True:
925
+ print(f"🚫 Duplicate file detected: {file_hash}")
926
+ return JSONResponse(
927
+ status_code=409,
928
+ content={
929
+ "message": "File already uploaded.",
930
+ "reason": "Duplicate file detected via SHA-256 hash.",
931
+ "file_hash": file_hash,
932
+ "user_id": user_id,
933
+ "filename": file.filename,
934
+ "action": "skipped",
935
+ "existing_file_info": hash_result.get("data")
936
+ }
937
+ )
938
+
939
+ print("✅ Hash check passed. New file - proceeding with upload.")
940
+
941
+ # 5. Load DataFrame
942
+ try:
943
+ if file_ext == ".csv":
944
+ df = pd.read_csv(io.BytesIO(file_content))
945
+ elif file_ext in {".xlsx", ".xls"}:
946
+ engine = 'openpyxl' if file_ext == ".xlsx" else 'xlrd'
947
+ df = pd.read_excel(io.BytesIO(file_content), engine=engine)
948
+ elif file_ext == ".ods":
949
+ df = pd.read_excel(io.BytesIO(file_content), engine='odf')
950
+ except Exception as e:
951
+ raise HTTPException(status_code=400, detail=f"Failed to parse file: {str(e)}")
952
+
953
+ if len(df) > MAX_ROWS_ALLOWED:
954
+ raise HTTPException(status_code=413, detail=f"Too many rows: {len(df):,} > {MAX_ROWS_ALLOWED:,}")
955
+
956
+ # 6. Convert to Parquet
957
+ parquet_buffer = convert_df_to_parquet(df)
958
+ parquet_size = parquet_buffer.getbuffer().nbytes
959
+
960
+ # 7. Upload Parquet to S3
961
+ base_filename = os.path.splitext(file.filename)[0]
962
+ parquet_filename = f"{base_filename}.parquet"
963
+ file_key = f"{user_id}/files/datasets/{parquet_filename}"
964
+ file_url = f"{ENDPOINT_URL}/{BUCKET_NAME}/{file_key}"
965
+
966
+ s3.upload_fileobj(parquet_buffer, BUCKET_NAME, file_key,
967
+ ExtraArgs={'ContentType': 'application/octet-stream'})
968
+ print(f"Uploaded Parquet: {file_url}")
969
+
970
+ # 8. Generate metadata
971
+ metadata = create_file_metadata_from_df(df, parquet_filename, file_key)
972
+ metadata.update({
973
+ "user_id": user_id,
974
+ "s3_path": file_key,
975
+ "s3_url": file_url,
976
+ "source_file": file.filename,
977
+ "source_file_type": file_ext[1:],
978
+ "file_type": "parquet",
979
+ "original_file_size_bytes": len(file_content),
980
+ "parquet_file_size_bytes": parquet_size,
981
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
982
+ "file_hash": file_hash
983
+ })
984
+
985
+ # 8.5 Generate dataset preview graphs (separate from metadata)
986
+ print("Generating dataset preview graphs...")
987
+ dataset_graphs = None
988
+ try:
989
+ dataset_graphs = create_data_set_graphs_dict(df, max_rows=200)
990
+ print(f"✅ Generated graphs for {len(dataset_graphs.get('columnSummaries', []))} columns")
991
+ except Exception as e:
992
+ print(f"⚠️ Failed to generate dataset graphs: {e}")
993
+ import traceback
994
+ traceback.print_exc()
995
+ dataset_graphs = {
996
+ "error": str(e),
997
+ "columnSummaries": []
998
+ }
999
+
1000
+ safe_metadata = sanitize_for_json(metadata)
1001
+ safe_dataset_graphs = sanitize_for_json(dataset_graphs) if dataset_graphs else {"columnSummaries": []}
1002
+
1003
+ print(f"safe_dataset_graphs: {safe_dataset_graphs}")
1004
+
1005
+ # 9. Vector DB
1006
+ check_vdb(user_id)
1007
+ vdb_res = await add_metadata_only("sri_1_files_&_files_metadata", metadata)
1008
+ vdb_success = vdb_res.get("status") == "success"
1009
+ print(f"VDB upload success: {vdb_success}")
1010
+
1011
+ # 10. PostgreSQL Metadata + file_hash
1012
+ pg_result = await user_metadata_upload_pg(
1013
+ user_id=user_id,
1014
+ user_metadata=json.dumps(safe_metadata),
1015
+ path=file_key,
1016
+ url=file_url,
1017
+ filename=parquet_filename,
1018
+ file_type="parquet",
1019
+ file_size_bytes=parquet_size,
1020
+ file_hash=file_hash,
1021
+ data_sets_preview_graph=safe_dataset_graphs
1022
+
1023
+ )
1024
+ print(f"PostgreSQL upload result: {pg_result}")
1025
+ pg_success = pg_result.get("success", False)
1026
+
1027
+ graphs_count = len(safe_dataset_graphs.get("columnSummaries", []))
1028
+ print(f"Graphs generated: {graphs_count}")
1029
+
1030
+ # 11. Return success
1031
+ return {
1032
+ "message": "Upload successful.",
1033
+ "filename": parquet_filename,
1034
+ "original_filename": file.filename,
1035
+ "user_id": user_id,
1036
+ "file_path": file_key,
1037
+ "file_url": file_url,
1038
+ "file_hash": file_hash,
1039
+ "source_file_type": file_ext[1:],
1040
+ "file_type": "parquet",
1041
+ "original_file_size_bytes": len(file_content),
1042
+ "parquet_file_size_bytes": parquet_size,
1043
+ "compression_ratio": f"{(1 - parquet_size/len(file_content))*100:.1f}%",
1044
+ "rows": len(df),
1045
+ "columns": len(df.columns),
1046
+ "metadata": safe_metadata,
1047
+ "data_sets_preview_graph": safe_dataset_graphs, # Separate key at root level
1048
+ "upload_dataset_vdb": vdb_success,
1049
+ "upload_dataset_pg": pg_success,
1050
+ "pg_details": pg_result,
1051
+ "graphs_generated": graphs_count
1052
+ }
1053
+
1054
+ except HTTPException:
1055
+ raise
1056
+ except Exception as e:
1057
+ print(f"Unexpected error: {e}")
1058
+ import traceback
1059
+ traceback.print_exc()
1060
+ raise HTTPException(status_code=500, detail=str(e))
1061
+ finally:
1062
+ # Clean up temp directories
1063
+ for d in (html_tmp_dir, bokeh_tmp_dir):
1064
+ if d and os.path.exists(d):
1065
+ shutil.rmtree(d, ignore_errors=True)
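+ # Example upload request (assuming the app serves this router locally on port 8000; file name is illustrative):
+ # curl -X POST "http://localhost:8000/s3/v3/upload_datasets_v3/?user_id=<user_id>" \
+ #      -F "file=@marketing_data_daily.csv"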
s3/s3.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
s3_viewer_backend/__pycache__/get_data_graphs.cpython-311.pyc ADDED
Binary file (2.96 kB). View file
 
s3_viewer_backend/__pycache__/s3_viewer.cpython-311.pyc CHANGED
Binary files a/s3_viewer_backend/__pycache__/s3_viewer.cpython-311.pyc and b/s3_viewer_backend/__pycache__/s3_viewer.cpython-311.pyc differ
 
s3_viewer_backend/get_data_graphs.py ADDED
@@ -0,0 +1,82 @@
1
+ import psycopg2
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import List
5
+ import httpx
6
+
7
+ # --- Project Root Setup ---
8
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
9
+ if str(PROJECT_ROOT) not in sys.path:
10
+ sys.path.insert(0, str(PROJECT_ROOT))
11
+ from retrieve_secret import *
12
+ def print_table_headers():
13
+ # Connect to PostgreSQL using your decrypted credentials
14
+ conn = psycopg2.connect(
15
+ host=CONNECTIONS_HOST,
16
+ database=CONNECTIONS_DB,
17
+ user=CONNECTIONS_USER,
18
+ password=CONNECTIONS_PASS
19
+ )
20
+
21
+ cursor = conn.cursor()
22
+
23
+ # Query to fetch column names from the table
24
+ cursor.execute("""
25
+ SELECT column_name
26
+ FROM information_schema.columns
27
+ WHERE table_schema = 'public'
28
+ AND table_name = 'user_datasets_metadata'
29
+ ORDER BY ordinal_position;
30
+ """)
31
+
32
+ columns = cursor.fetchall()
33
+
34
+ print("Table Headers:")
35
+ for col in columns:
36
+ print(col[0])
37
+
38
+ cursor.close()
39
+ conn.close()
40
+
41
+
42
+ # Guarded so that importing this module (e.g. from s3_viewer.py) does not open a DB connection:
+ if __name__ == "__main__":
+     print_table_headers()
43
+
44
+
45
+
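+ # The query above lists every column of public.user_datasets_metadata; based on the code in
+ # this repo the output should at least include user_id, path and data_sets_preview_graph.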
46
+ def data_sets_preview_graph_retrieve(user_id, file_path):
47
+ try:
48
+ conn = psycopg2.connect(
49
+ host=CONNECTIONS_HOST,
50
+ database=CONNECTIONS_DB,
51
+ user=CONNECTIONS_USER,
52
+ password=CONNECTIONS_PASS
53
+ )
54
+ cursor = conn.cursor()
55
+
56
+ query = """
57
+ SELECT data_sets_preview_graph
58
+ FROM public.user_datasets_metadata
59
+ WHERE user_id = %s
60
+ AND path = %s
61
+ LIMIT 1;
62
+ """
63
+
64
+ cursor.execute(query, (user_id, file_path))
65
+ result = cursor.fetchone()
66
+
67
+ cursor.close()
68
+ conn.close()
69
+
70
+ if result:
71
+ return result[0] # return the data_sets_preview_graph field
72
+ else:
73
+ return None
74
+
75
+ except Exception as e:
76
+ print(f"Error: {e}")
77
+ return None
78
+
79
+ # user_id = "18171ae4-134c-404a-a3d8-42fa8603dc3f"
80
+ # file_path = "18171ae4-134c-404a-a3d8-42fa8603dc3f/files/datasets/marketing_data_daily.parquet"
81
+ # data_sets_preview_graph = data_sets_preview_graph_retrieve(user_id, file_path)
82
+ # print(data_sets_preview_graph)
s3_viewer_backend/s3_viewer.py CHANGED
@@ -262,6 +262,8 @@ from botocore.exceptions import ClientError
262
  import urllib.parse
263
  from fastapi import APIRouter
264
  from retrieve_secret import *
 
 
265
  # === CONFIG ===
266
  # ENDPOINT_URL = "https://s3.us-west-1.idrivee2.com"
267
  # ACCESS_KEY = "rNuPBAQetemqpEeBospZ"
@@ -295,6 +297,10 @@ s3 = boto3.client(
295
  region_name=REGION_NAME
296
  )
297
 
 
 
 
 
298
  ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com" # For file URL generation
299
 
300
  s3_viewer_router = APIRouter(prefix="/s3_viewer", tags=["s3_viewer"])
@@ -464,13 +470,35 @@ def get_files():
464
  # )
465
  # return result
466
 
 
 
467
  @s3_viewer_router.get("/api/file")
468
- def get_file(file: str, request: Request):
469
- """Return preview content for CSV, text, PDF, images, Parquet, Office, etc."""
 
470
  ext = file.lower().split('.')[-1] if '.' in file else ''
471
- result = {"fileType": None, "fileContent": None, "filePreview": None}
472
 
473
- # === Special handling for Parquet ===
474
  if ext == "parquet":
475
  try:
476
  df = read_parquet_from_s3(file)
@@ -486,32 +514,41 @@ def get_file(file: str, request: Request):
486
  )
487
  result["fileContent"] = f"Rows: {len(df)}, Columns: {len(df.columns)}"
488
  return result
 
489
  except Exception as e:
490
  raise HTTPException(status_code=500, detail=f"Error reading Parquet: {str(e)}")
491
 
492
- # === For all other file types: use raw bytes ===
493
- content = read_file_from_s3(file) # This returns bytes or None
 
 
494
  if not content:
495
  raise HTTPException(status_code=404, detail="File not found")
496
 
497
- # === Office Files (PPTX, DOCX, XLSX) ===
 
 
498
  office_extensions = ["pptx", "docx", "xlsx"]
499
  if ext in office_extensions:
500
  download_url = get_download_url(file, request)
501
- embed_url = f"https://view.officeapps.live.com/op/embed.aspx?src={urllib.parse.quote(download_url)}"
 
 
 
502
 
503
  result["fileType"] = ext
504
  result["filePreview"] = (
505
- f'<iframe src="{embed_url}" '
506
- 'width="100%" height="600px" frameborder="0" '
507
- 'style="border: none; max-width: 100%;"></iframe>'
508
  f'<p style="margin-top: 10px;"><small>'
509
  f'If preview fails, <a href="{download_url}" target="_blank">download {file}</a>.'
510
  f'</small></p>'
511
  )
512
  return result
513
 
514
- # === CSV ===
 
 
515
  if ext == "csv":
516
  try:
517
  df = pd.read_csv(StringIO(content.decode("utf-8")))
@@ -520,54 +557,70 @@ def get_file(file: str, request: Request):
520
  classes="table table-striped table-sm",
521
  index=False,
522
  border=0,
523
- justify="left"
524
  )
525
  except Exception as e:
526
  result["fileType"] = "csv-error"
527
  result["fileContent"] = f"Error reading CSV: {str(e)}"
528
  return result
529
 
530
- # === Text files ===
531
- elif ext in ["txt", "log", "json", "md", "py", "js", "css", "html", "xml"]:
 
 
532
  result["fileType"] = "text"
533
  try:
534
  text_content = content.decode("utf-8", errors="replace")
535
  result["fileContent"] = text_content
 
536
  if len(text_content) < 10000:
537
- result["filePreview"] = f"<pre style='max-height: 500px; overflow: auto; background: #f4f4f4; padding: 10px; border-radius: 5px;'>{text_content}</pre>"
 
 
 
 
538
  except:
539
  result["fileContent"] = "Unable to decode text file."
540
  return result
541
 
542
- # === Images ===
543
- elif ext in ["png", "jpg", "jpeg", "gif", "webp", "bmp", "ico"]:
 
 
544
  b64_data = base64.b64encode(content).decode("utf-8")
545
  mime_type = "jpeg" if ext == "jpg" else ext
 
546
  result["fileType"] = "image"
547
  result["filePreview"] = (
548
  f'<img src="data:image/{mime_type};base64,{b64_data}" '
549
- 'style="max-width:100%; height:auto; border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);"/>'
 
550
  )
551
  return result
552
 
553
- # === SVG ===
554
- elif ext == "svg":
 
 
555
  try:
556
  svg_text = content.decode("utf-8")
557
  if "<script" in svg_text.lower():
558
- result["fileType"] = "svg"
559
- result["filePreview"] = "<p>SVG contains scripts. Preview blocked for security.</p>"
560
  else:
561
- result["fileType"] = "svg"
562
  result["filePreview"] = svg_text
563
- except:
564
  result["fileType"] = "svg"
 
565
  result["filePreview"] = "<p>Invalid SVG content.</p>"
 
566
  return result
567
 
568
- # === PDF ===
569
- elif ext == "pdf":
 
 
570
  b64_data = base64.b64encode(content).decode("utf-8")
 
571
  result["fileType"] = "pdf"
572
  result["filePreview"] = (
573
  f'<iframe src="data:application/pdf;base64,{b64_data}" '
@@ -575,30 +628,34 @@ def get_file(file: str, request: Request):
575
  )
576
  return result
577
 
578
- # === HTML (open in new tab) ===
579
- elif ext in ["html", "htm"]:
 
 
580
  result["fileType"] = "html"
581
  result["filePreview"] = (
582
- f'<p><a href="/s3_viewer/api/html?file={urllib.parse.quote(file)}" target="_blank" '
583
- f'style="color: #007bff; text-decoration: underline;">'
584
  f'Open {file} in new tab</a></p>'
585
  )
586
  return result
587
 
588
- # === Fallback: Other files ===
589
- else:
590
- download_url = get_download_url(file, request)
591
- result["fileType"] = "other"
592
- result["fileContent"] = f"File type: .{ext}\nSize: {len(content):,} bytes"
593
- result["filePreview"] = (
594
- f'<div style="padding: 20px; text-align: center; background: #f8f9fa; border-radius: 8px;">'
595
- f'<p><strong>Preview not available</strong> for <code>.{ext}</code> files.</p>'
596
- f'<a href="{download_url}" target="_blank" download '
597
- f'style="display: inline-block; padding: 10px 20px; background: #007bff; color: white; '
598
- f'text-decoration: none; border-radius: 5px; margin-top: 10px;">'
599
- f'Download {file}</a></div>'
600
- )
601
- return result
 
 
602
 
603
  # @s3_viewer_router.get("/api/html")
604
  # def view_html(file: str):
 
262
  import urllib.parse
263
  from fastapi import APIRouter
264
  from retrieve_secret import *
265
+ from pydantic import BaseModel
266
+
267
  # === CONFIG ===
268
  # ENDPOINT_URL = "https://s3.us-west-1.idrivee2.com"
269
  # ACCESS_KEY = "rNuPBAQetemqpEeBospZ"
 
297
  region_name=REGION_NAME
298
  )
299
 
300
+ class FileRequest(BaseModel):
301
+ file: str
302
+ user_id: str | None = None
303
+
304
  ENDPOINT_URL = f"https://s3.{REGION_NAME}.amazonaws.com" # For file URL generation
305
 
306
  s3_viewer_router = APIRouter(prefix="/s3_viewer", tags=["s3_viewer"])
 
470
  # )
471
  # return result
472
 
473
+ from s3_viewer_backend.get_data_graphs import data_sets_preview_graph_retrieve
474
+
475
  @s3_viewer_router.get("/api/file")
476
+ def get_file(file: str, user_id: str | None = None, request: Request = None):
477
+ """Return preview content for CSV, text, PDF, images, Parquet, Office files — AND dataset preview graph."""
478
+
479
  ext = file.lower().split('.')[-1] if '.' in file else ''
 
480
 
481
+ # Base Response
482
+ result = {
483
+ "user_id": user_id, # <-- returned in output
484
+ "fileType": None,
485
+ "fileContent": None,
486
+ "filePreview": None,
487
+ "data_sets_preview": None # <-- preview graph from DB
488
+ }
489
+
490
+ # ---------------------------------------------------------
491
+ # Get Graph Preview from DB
492
+ # ---------------------------------------------------------
493
+ try:
494
+ if user_id:
495
+ result["data_sets_preview"] = data_sets_preview_graph_retrieve(user_id, file)
496
+ except Exception as e:
497
+ result["data_sets_preview"] = f"Error retrieving graph: {str(e)}"
498
+
499
+ # ---------------------------------------------------------
500
+ # PARQUET
501
+ # ---------------------------------------------------------
502
  if ext == "parquet":
503
  try:
504
  df = read_parquet_from_s3(file)
 
514
  )
515
  result["fileContent"] = f"Rows: {len(df)}, Columns: {len(df.columns)}"
516
  return result
517
+
518
  except Exception as e:
519
  raise HTTPException(status_code=500, detail=f"Error reading Parquet: {str(e)}")
520
 
521
+ # ---------------------------------------------------------
522
+ # Read raw file from S3
523
+ # ---------------------------------------------------------
524
+ content = read_file_from_s3(file)
525
  if not content:
526
  raise HTTPException(status_code=404, detail="File not found")
527
 
528
+ # ---------------------------------------------------------
529
+ # Office Files
530
+ # ---------------------------------------------------------
531
  office_extensions = ["pptx", "docx", "xlsx"]
532
  if ext in office_extensions:
533
  download_url = get_download_url(file, request)
534
+ embed_url = (
535
+ "https://view.officeapps.live.com/op/embed.aspx?src="
536
+ + urllib.parse.quote(download_url)
537
+ )
538
 
539
  result["fileType"] = ext
540
  result["filePreview"] = (
541
+ f'<iframe src="{embed_url}" width="100%" height="600px" '
542
+ 'frameborder="0" style="border: none; max-width: 100%;"></iframe>'
 
543
  f'<p style="margin-top: 10px;"><small>'
544
  f'If preview fails, <a href="{download_url}" target="_blank">download {file}</a>.'
545
  f'</small></p>'
546
  )
547
  return result
548
 
549
+ # ---------------------------------------------------------
550
+ # CSV
551
+ # ---------------------------------------------------------
552
  if ext == "csv":
553
  try:
554
  df = pd.read_csv(StringIO(content.decode("utf-8")))
 
557
  classes="table table-striped table-sm",
558
  index=False,
559
  border=0,
560
+ justify="left",
561
  )
562
  except Exception as e:
563
  result["fileType"] = "csv-error"
564
  result["fileContent"] = f"Error reading CSV: {str(e)}"
565
  return result
566
 
567
+ # ---------------------------------------------------------
568
+ # Text Files
569
+ # ---------------------------------------------------------
570
+ if ext in ["txt", "log", "json", "md", "py", "js", "css", "html", "xml"]:
571
  result["fileType"] = "text"
572
  try:
573
  text_content = content.decode("utf-8", errors="replace")
574
  result["fileContent"] = text_content
575
+
576
  if len(text_content) < 10000:
577
+ result["filePreview"] = (
578
+ "<pre style='max-height: 500px; overflow: auto;"
579
+ " background: #f4f4f4; padding: 10px; border-radius: 5px;'>"
580
+ f"{text_content}</pre>"
581
+ )
582
  except:
583
  result["fileContent"] = "Unable to decode text file."
584
  return result
585
 
586
+ # ---------------------------------------------------------
587
+ # Images
588
+ # ---------------------------------------------------------
589
+ if ext in ["png", "jpg", "jpeg", "gif", "webp", "bmp", "ico"]:
590
  b64_data = base64.b64encode(content).decode("utf-8")
591
  mime_type = "jpeg" if ext == "jpg" else ext
592
+
593
  result["fileType"] = "image"
594
  result["filePreview"] = (
595
  f'<img src="data:image/{mime_type};base64,{b64_data}" '
596
+ 'style="max-width:100%; height:auto; border-radius: 8px; '
597
+ 'box-shadow: 0 2px 8px rgba(0,0,0,0.1);" />'
598
  )
599
  return result
600
 
601
+ # ---------------------------------------------------------
602
+ # SVG
603
+ # ---------------------------------------------------------
604
+ if ext == "svg":
605
  try:
606
  svg_text = content.decode("utf-8")
607
  if "<script" in svg_text.lower():
608
+ result["filePreview"] = "<p>SVG contains scripts. Preview blocked.</p>"
 
609
  else:
 
610
  result["filePreview"] = svg_text
611
+
612
  result["fileType"] = "svg"
613
+ except:
614
  result["filePreview"] = "<p>Invalid SVG content.</p>"
615
+ result["fileType"] = "svg"
616
  return result
617
 
618
+ # ---------------------------------------------------------
619
+ # PDF
620
+ # ---------------------------------------------------------
621
+ if ext == "pdf":
622
  b64_data = base64.b64encode(content).decode("utf-8")
623
+
624
  result["fileType"] = "pdf"
625
  result["filePreview"] = (
626
  f'<iframe src="data:application/pdf;base64,{b64_data}" '
 
628
  )
629
  return result
630
 
631
+ # ---------------------------------------------------------
632
+ # HTML Links
633
+ # ---------------------------------------------------------
634
+ if ext in ["html", "htm"]:
635
  result["fileType"] = "html"
636
  result["filePreview"] = (
637
+ f'<p><a href="/s3_viewer/api/html?file={urllib.parse.quote(file)}" '
638
+ 'target="_blank" style="color: #007bff; text-decoration: underline;">'
639
  f'Open {file} in new tab</a></p>'
640
  )
641
  return result
642
 
643
+ # ---------------------------------------------------------
644
+ # Fallback for unsupported types
645
+ # ---------------------------------------------------------
646
+ download_url = get_download_url(file, request)
647
+ result["fileType"] = "other"
648
+ result["fileContent"] = f"File type: .{ext}\nSize: {len(content):,} bytes"
649
+ result["filePreview"] = (
650
+ '<div style="padding: 20px; text-align: center; background: #f8f9fa; '
651
+ 'border-radius: 8px;">'
652
+ f'<p><strong>Preview not available</strong> for <code>.{ext}</code> files.</p>'
653
+ f'<a href="{download_url}" target="_blank" download '
654
+ 'style="display: inline-block; padding: 10px 20px; background: #007bff; '
655
+ 'color: white; text-decoration: none; border-radius: 5px; margin-top: 10px;">'
656
+ f'Download {file}</a></div>'
657
+ )
658
+ return result
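+ # Example request (illustrative host and key): returns the preview plus the stored graph JSON.
+ # curl "http://localhost:8000/s3_viewer/api/file?file=<user_id>/files/datasets/<name>.parquet&user_id=<user_id>"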
659
 
660
  # @s3_viewer_router.get("/api/html")
661
  # def view_html(file: str):