File size: 31,425 Bytes
c0d6e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313fe76
c0d6e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3d3637
 
 
 
c0d6e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
import os
import logging
import json
import uuid
from typing import Dict, List, Optional, Any, Union

from datetime import datetime

# FastAPI and related imports
from fastapi import (
    FastAPI,
    WebSocket,
    WebSocketDisconnect,
    HTTPException,
    Body,
    Query,
    Depends
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from dotenv import load_dotenv

# LangChain / RAG Pipeline Imports (placeholder imports—adjust for your project)
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.tools import tool
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
from bs4 import BeautifulSoup
import requests

from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends

# Supabase
from supabase import create_client, Client

###############################################################################
#                            ENV & LOGGING SETUP
###############################################################################
load_dotenv()  # pull variables from a local .env file into the process environment
# Secrets / connection settings; only OPENAI_API_KEY is validated below.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_ANON_KEY = os.getenv("SUPABASE_ANON_KEY")
SUPABASE_SERVICE_ROLE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)

# Fail fast at import time: the LLM and embedding clients created later
# cannot work without an OpenAI key.
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY in environment!")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY  # make the key visible to LangChain clients

###############################################################################
#                  SUPABASE CREDENTIALS & CLIENT INITIALIZATION
###############################################################################

# SUPABASE_URL / keys were already read from the environment in the ENV
# SETUP section above; reuse those values instead of re-reading them here.

# Client using the anonymous key: auth flows and user-scoped queries.
supabase_client: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
# Client using the service-role key: privileged server-side writes only.
# NOTE(review): never expose this key to the browser.
supabase_admin: Client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)

###############################################################################
#                      OPTIONAL: CREATE TABLES / SCHEMA
###############################################################################
def create_db_schema() -> None:
    """Log the SQL needed to create the application's tables.

    This is a one-time, admin-only helper: it does NOT execute anything
    against the database.  Copy the logged SQL into Supabase's SQL editor
    (or wire it to an RPC such as ``execute_sql``) and run it there.
    """
    schema_sql = """
    -- Enable UUID generation if not enabled
    CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

    CREATE TABLE IF NOT EXISTS public.users (
      id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
      created_at timestamp with time zone DEFAULT now(),
      email text UNIQUE NOT NULL,
      password_hash text,
      full_name text,
      last_login timestamp with time zone,
      role text DEFAULT 'user'
    );

    CREATE TABLE IF NOT EXISTS public.chats (
      chat_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
      user_id uuid REFERENCES public.users (id) ON DELETE CASCADE,
      created_at timestamp with time zone DEFAULT now(),
      title text,
      last_updated timestamp with time zone DEFAULT now()
    );

    CREATE TABLE IF NOT EXISTS public.chat_session (
      session_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
      chat_id uuid REFERENCES public.chats (chat_id) ON DELETE CASCADE,
      created_at timestamp with time zone DEFAULT now(),
      updated_at timestamp with time zone DEFAULT now(),
      content jsonb DEFAULT '{}'::jsonb
    );

    CREATE TABLE IF NOT EXISTS public.logs (
      log_id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
      session_id uuid REFERENCES public.chat_session (session_id) ON DELETE CASCADE,
      timestamp timestamp with time zone DEFAULT now(),
      event_type text,
      details jsonb DEFAULT '{}'::jsonb
    );

    CREATE TABLE IF NOT EXISTS public.ai_thought_table (
      id uuid DEFAULT uuid_generate_v4() PRIMARY KEY,
      created_at timestamp with time zone DEFAULT now(),
      session_id uuid REFERENCES public.chat_session (session_id) ON DELETE CASCADE,
      thought_process text,
      decision_making jsonb DEFAULT '{}'::jsonb
    );
    """

    logging.info("Schema creation SQL:\n%s", schema_sql)
    # Intentionally not executed here -- run the SQL manually, e.g. via:
    #   supabase_admin.rpc('execute_sql', {'q': schema_sql}).execute()
    # or paste it into the project's SQL editor.

###############################################################################
#                               FASTAPI APP
###############################################################################
# Single FastAPI application instance serving all HTTP and WebSocket routes.
app = FastAPI(
    title="RAG-GENAI-Women",
    version="1.0.0",
    description=(
        "A production-ready pipeline with session-based JSON storage, plus "
        "auth endpoints for SignUp, Login, and more. "
        "Supports multiple concurrent WebSocket connections (one per session)."
    )
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open -- restrict origins before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token security scheme; consumed by the /auth/me dependency.
security = HTTPBearer()
###############################################################################
#                          LLM & VECTOR STORE SETUP
###############################################################################
# Embedding model used to vectorize documents and queries for Chroma.
embeddings_model = OpenAIEmbeddings(model="text-embedding-3-large")
# Main answer-generation model.
llm = ChatOpenAI(model="gpt-4o")  # Example placeholder name
# Cheaper model reserved for the binary web-search decision.
llm_decision_maker = ChatOpenAI(model="gpt-4o-mini")

# Persistent local Chroma vector store; data lives under ./chroma_db.
vector_store = Chroma(
    persist_directory="./chroma_db",
    embedding_function=embeddings_model
)

def get_time_date() -> str:
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")

def get_country_from_ip() -> str:
    """Placeholder geo-lookup; always reports 'India'.

    In production this should resolve the caller's IP via a real lookup.
    """
    return "India"

###############################################################################
#                                 WEB SEARCH TOOL
###############################################################################
@tool
def web_search_tool(query: str) -> Dict[str, Any]:
    """
    Perform a web search and return a single dictionary:
       {"results": [...], "count": <int>}

    Fetches the top 3 Google results for *query*, downloads each page, and
    keeps the first 1000 characters of its text.  A per-URL fetch failure is
    recorded inline as content = "Error: ..."; a failure of the search
    itself is logged and yields an empty result list.
    """
    # Function-scope import defers the googlesearch dependency until the
    # tool is actually invoked.
    from googlesearch import search

    results = []
    try:
        for url in search(query, num_results=3):
            try:
                resp = requests.get(url, timeout=10)
                soup = BeautifulSoup(resp.text, "html.parser")
                snippet = soup.get_text()[:1000]  # cap snippet size per page
                results.append({"url": url, "content": snippet})
            except Exception as e:
                logging.exception("Error fetching content from: %s", url)
                results.append({"url": url, "content": f"Error: {str(e)}"})
    except Exception as e:
        logging.error("Error performing search: %s", e)

    return {"results": results, "count": len(results)}

###############################################################################
#                                RAG PIPELINE
###############################################################################
class State(TypedDict):
    """Shared state threaded through the RAG pipeline graph nodes."""
    question: str                        # full conversation text fed to the pipeline
    retrieved_context: List[Document]    # distilled documents from the vector store
    demographic_context: str             # "User from <country> at <timestamp>"
    web_search_needed: int               # 1 = run web search, 0 = skip
    web_search_results: List[dict]       # [{"url": ..., "summary": ...}, ...]
    final_answer: str                    # NOTE(review): consolidate() actually stores a dict here
    tone: str                            # "detailed" (default) or "casual"

def add_demographic_context(state: State):
    """Attach a coarse demographic string (country + timestamp) to the state."""
    demographic = f"User from {get_country_from_ip()} at {get_time_date()}"
    logging.info(f"[add_demographic_context] {demographic}")
    return {"demographic_context": demographic}

def retrieve(state: State):
    """Search the vector store for the question and distill the matches.

    Runs a similarity search, then asks the LLM to extract only the
    pertinent points, returning them wrapped in a single Document.
    """
    logging.info("[retrieve] Searching vector store...")
    query = state["question"]
    matches = vector_store.similarity_search(query)
    docs_text = "\n\n".join(d.page_content for d in matches)

    # Tone switches the system prompt; both variants keep the lawyer framing.
    if state.get("tone", "detailed") == "casual":
        system_message = "You are an assistant extracting key points in a conversational manner. Think like a lawyer and search for relevant details."
    else:
        system_message = "You are an assistant extracting key points. Focus on relevant details. Think like a lawyer and search for relevant details."

    llm_input = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": f"Query:\n{query}\n\nDocs:\n{docs_text}\nExtract relevant points.",
        },
    ]
    key_points = llm.invoke(llm_input).content.strip()

    return {
        "retrieved_context": [
            Document(page_content=key_points, metadata={"source": "filtered"})
        ]
    }

def decide_web_search(state: State):
    """Ask the lightweight LLM whether the question needs a web search.

    Returns ``{"web_search_needed": 1}`` or ``{"web_search_needed": 0}``.
    A malformed model reply is logged and treated as "no search" so that a
    single off-script response cannot crash the whole pipeline run
    (previously this raised ValueError).
    """
    logging.info("[decide_web_search]")
    retrieved_text = "\n\n".join(doc.page_content for doc in state["retrieved_context"])

    messages = [
        {
            "role": "system",
            "content": (
                "You are a decision-making assistant. "
                "Respond strictly with '1' if a web search is required, or '0' if not."
            )
        },
        {
            "role": "user",
            "content": f"Question:\n{state['question']}\n\nContext:\n{retrieved_text}"
        },
    ]

    response = llm_decision_maker.invoke(messages)
    decision = response.content.strip()

    logging.info(f"[decide_web_search] LLM decision raw: {decision}")

    if decision in ("0", "1"):
        return {"web_search_needed": int(decision)}

    # The model ignored the "strictly 0/1" instruction; fail soft.
    logging.error(f"Invalid decision response: {decision}")
    return {"web_search_needed": 0}

def perform_web_search(state: State):
    """Run the web-search branch when the decision node requested it.

    Each fetched page is summarized by the LLM; when no search was
    requested the result list is simply empty.
    """
    if state.get("web_search_needed", 0) != 1:
        logging.info("[perform_web_search] Skipping web search...")
        return {"web_search_results": []}

    logging.info("[perform_web_search] Searching the web...")
    search_query = f"{state['question']} ({state['demographic_context']})"
    raw = web_search_tool.invoke(search_query)  # dict: {"results": [...], "count": n}

    summaries = []
    for item in raw["results"]:
        summarize_prompt = [
            {"role": "system", "content": "Summarize the content with short citation."},
            {"role": "user", "content": f"{item['content']}\nURL: {item['url']}"},
        ]
        summary_reply = llm.invoke(summarize_prompt)
        summaries.append({
            "url": item["url"],
            "summary": summary_reply.content.strip()
        })
    return {"web_search_results": summaries}

def consolidate(state: State):
    """Merge retrieved context and web summaries into the final answer.

    Two LLM passes: one to write the full answer, one to condense it for
    chat display.  The result stored under "final_answer" is a dict with
    keys crunched_summary / full_answer / sources / source_type.
    """
    logging.info("[consolidate] Generating final answer...")
    retrieved_text = "\n\n".join(doc.page_content for doc in state["retrieved_context"])
    web_data = state.get("web_search_results", [])

    # Flatten web summaries into a plain-text citation block for the prompt.
    sources_text = "\n".join(
        f"URL: {r['url']}\nSummary: {r['summary']}" for r in web_data
    )
    tone = state.get("tone", "detailed")
    sys_msg = (
        "You are a precise assistant. Combine context and results into a final answer."
        if tone != "casual" else
        "You are a friendly assistant. Combine context and results in a final manner."
    )

    final_prompt = [
        {"role": "system", "content": sys_msg},
        {
            "role": "user",
            "content": (
                f"Question:\n{state['question']}\n\n"
                f"Retrieved:\n{retrieved_text}\n\n"
                f"Web:\n{sources_text}\n\n"
                "Give a comprehensive final answer."
            )
        }
    ]
    resp = llm.invoke(final_prompt)
    raw_ans = resp.content.strip()

    # Second pass: condensed version for chat display.
    summ_prompt = [
        {
            "role": "system",
            "content": "Provide a concise version of the answer, preserving key details."
        },
        {
            "role": "user",
            "content": raw_ans
        }
    ]
    s_resp = llm.invoke(summ_prompt)
    chat_ans = s_resp.content.strip()

    final = {
        "crunched_summary": chat_ans,
        "full_answer": raw_ans,
        "sources": web_data if web_data else None,
        # Label which knowledge sources actually contributed.
        "source_type": (
            "Web + Retrieved" if web_data and retrieved_text
            else "Web" if web_data
            else "Retrieved"
        )
    }
    return {"final_answer": final}

###############################################################################
#                         PIPELINE GRAPH BUILD
###############################################################################
# Linear pipeline: demographic tagging -> retrieval -> search decision ->
# optional web search -> final consolidation.
graph_builder = StateGraph(State).add_sequence([
    add_demographic_context,
    retrieve,
    decide_web_search,
    perform_web_search,
    consolidate
])
# NOTE(review): add_sequence already wires the intermediate nodes together,
# so the node-to-node edges below look redundant (the START and END edges
# are still needed) -- confirm against the langgraph version in use.
graph_builder.add_edge(START, "add_demographic_context")
graph_builder.add_edge("add_demographic_context", "retrieve")
graph_builder.add_edge("retrieve", "decide_web_search")
graph_builder.add_edge("decide_web_search", "perform_web_search")
graph_builder.add_edge("perform_web_search", "consolidate")
graph_builder.add_edge("consolidate", END)

pipeline_graph = graph_builder.compile()

###############################################################################
#                         SESSION-BASED JSON STORAGE
###############################################################################
# Directory holding one JSON transcript file per session.
SESSIONS_DIR = "sessions_data"
# Fix: actually create the directory (the makedirs call was commented out),
# otherwise the first session write fails with FileNotFoundError.
os.makedirs(SESSIONS_DIR, exist_ok=True)

def generate_session_id() -> str:
    """Mint a fresh random session identifier (UUID4 in string form)."""
    return f"{uuid.uuid4()}"

def get_session_file(session_id: str) -> str:
    """Map a session id to its JSON file path under SESSIONS_DIR."""
    filename = f"{session_id}.json"
    return os.path.join(SESSIONS_DIR, filename)

def load_session_from_json(session_id: str) -> dict:
    """Load an existing session's JSON, or create a fresh one on first use.

    Fix: the sessions directory is created on demand -- previously the
    initial write failed with FileNotFoundError when SESSIONS_DIR did not
    exist (its creation at import time is commented out).
    """
    path = get_session_file(session_id)
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # First message for this session: persist a skeleton transcript.
    data = {
        "session_id": session_id,
        "started_at": get_time_date(),
        "messages": []
    }
    os.makedirs(SESSIONS_DIR, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    return data

def save_session_to_json(session_data: dict):
    """Persist the full session dict to its JSON file (overwrites).

    Creates the parent directory if needed so a save never fails merely
    because the sessions directory is missing.
    """
    session_id = session_data["session_id"]
    path = get_session_file(session_id)
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(session_data, f, indent=2)

def append_message(session_id: str, role: str, content: str):
    """Append one timestamped message to the session's stored transcript."""
    session = load_session_from_json(session_id)
    entry = {
        "role": role,
        "content": content,
        "timestamp": get_time_date(),
    }
    session["messages"].append(entry)
    save_session_to_json(session)

###############################################################################
#                              AUTH & USER MODELS
###############################################################################
class SignupRequest(BaseModel):
    """Payload for POST /auth/signup."""
    email: EmailStr                   # validated email address
    password: str                     # plaintext; forwarded to Supabase Auth
    full_name: Optional[str] = None   # optional display name for the custom users table

class SignupResponse(BaseModel):
    """Response for POST /auth/signup; user_id is None on failure or pending confirmation."""
    user_id: Optional[str]
    message: str

class LoginRequest(BaseModel):
    """Payload for POST /auth/login."""
    email: EmailStr
    password: str

class LoginResponse(BaseModel):
    """Response for POST /auth/login; access_token is None when login fails."""
    access_token: Optional[str]       # Supabase session token for subsequent requests
    token_type: str = "bearer"
    user_id: Optional[str]
    message: str

class LogoutResponse(BaseModel):
    """Response for POST /auth/logout."""
    message: str

class Identity(BaseModel):
    """One linked auth-provider identity attached to a Supabase user."""
    provider: str                          # e.g. "email" or an OAuth provider name
    identity_id: str
    created_at: Union[datetime, str]       # str-coerced by the /auth/me handler when datetime
    last_sign_in_at: Union[datetime, str]

class UserProfile(BaseModel):
    """Response shape for GET /auth/me, mirroring the Supabase user object."""
    user_id: str
    email: str
    full_name: Optional[str]               # from the custom users table, if a record exists
    role: str
    created_at: datetime
    updated_at: Optional[datetime]
    last_sign_in_at: Optional[datetime]
    email_verified: bool                   # read from user_metadata
    phone_verified: bool                   # read from user_metadata
    is_anonymous: bool
    app_metadata: Dict[str, Union[str, List[str]]]
    user_metadata: Dict[str, Union[str, bool]]
    identities: List[Identity]             # linked auth providers (may be empty)


###############################################################################
#                              HTTP MODELS
###############################################################################
class AskRequest(BaseModel):
    """Payload for POST /ask: the opening message of a new session."""
    user_input: str
    tone: Optional[str] = "detailed"   # "detailed" or "casual"

class AskResponse(BaseModel):
    """Response for POST /ask: the fresh session id to use on the WebSocket."""
    session_id: str
    message: str

###############################################################################
#                             HTTP AUTH ENDPOINTS
###############################################################################
@app.get("/")
def read_root():
    """Root endpoint: a simple greeting proving the service is up."""
    greeting = "Hello from FastAPI on Hugging Face Spaces!"
    return {"message": greeting}

@app.post("/auth/signup", response_model=SignupResponse)
def signup(payload: SignupRequest):
    """
    Sign up a new user using Supabase Auth.
    Optionally store extra info (e.g., full_name) in your custom 'users' table.

    Supabase errors are reported in the response message rather than as HTTP
    errors; the custom-table insert is best-effort (failures only logged).
    """
    # 1) Use Supabase Auth to create the user
    try:
        result = supabase_client.auth.sign_up(
            {
                "email": payload.email,
                "password": payload.password
            }
        )
    except Exception as e:
        logging.exception("[signup] Error from Supabase Auth sign_up")
        return SignupResponse(user_id=None, message=f"Sign up failed: {str(e)}")

    if result.user is None:
        # Possibly means "Confirm email" is enabled, user needs to verify
        return SignupResponse(
            user_id=None,
            message="User created, but email confirmation required."
        )

    # 2) The user is created in supabase.auth. We can optionally store extra data
    user_id = result.user.id
    full_name = payload.full_name if payload.full_name else ""
    now = datetime.utcnow()  # NOTE(review): naive UTC; consider datetime.now(timezone.utc)

    # Attempt to store in our custom 'users' table
    # (best-effort: a failure here does not fail the signup).
    try:
        insert_res = supabase_admin.table("users").insert({
            "id": user_id,
            "email": payload.email,
            "password_hash": "N/A (Using Supabase Auth)",
            "full_name": full_name,
            "created_at": now.isoformat(),
            "last_login": None,
            "role": "user"
        }).execute()
        logging.info("[signup] Inserted custom user record: %s", insert_res.data)
    except Exception as e:
        logging.exception("[signup] Error inserting into 'users' table")

    return SignupResponse(user_id=user_id, message="Sign up successful.")

@app.post("/auth/login", response_model=LoginResponse)
def login(payload: LoginRequest):
    """
    Log in an existing user with Supabase Auth. 
    Return the access_token, which you can store on client side for usage, 
    or rely on same-site cookies if you have it configured.

    Failures are reported via the response message (access_token=None), not
    as HTTP error codes.
    """
    try:
        result = supabase_client.auth.sign_in_with_password(
            {
                "email": payload.email,
                "password": payload.password
            }
        )
        if result.user is None:
            return LoginResponse(
                access_token=None,
                user_id=None,
                message="Login failed: invalid credentials or user not confirmed."
            )

        user_id = result.user.id
        access_token = result.session.access_token if result.session else None

        # We can track "last_login" in our custom table:
        # (best-effort: a failure here does not fail the login)
        now = datetime.utcnow()
        try:
            supabase_admin.table("users").update({
                "last_login": now.isoformat()
            }).eq("id", user_id).execute()
        except Exception as e:
            logging.exception("[login] Error updating last_login in 'users' table")

        return LoginResponse(
            access_token=access_token,
            user_id=user_id,
            message="Login success."
        )
    except Exception as e:
        logging.exception("[login] Error from Supabase Auth sign_in_with_password")
        return LoginResponse(
            access_token=None,
            user_id=None,
            message=f"Login error: {str(e)}"
        )

@app.post("/auth/logout", response_model=LogoutResponse)
def logout():
    """Sign the current user out of Supabase.

    Revokes the refresh token on Supabase's side; a token-based client
    should also discard its access token locally.
    """
    try:
        supabase_client.auth.sign_out()
    except Exception as e:
        logging.exception("[logout] Error from Supabase Auth sign_out")
        raise HTTPException(status_code=500, detail="Logout failed.")
    return LogoutResponse(message="Logout successful.")


@app.get("/auth/me", response_model=UserProfile)
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """
    Retrieve info about the currently logged-in user.

    Validates the bearer token against Supabase, merges in the matching row
    from the custom 'users' table, and returns a UserProfile.  Any failure
    (including the 401 raised below) surfaces as a 500 with the error text,
    because the broad except re-wraps HTTPException too.
    """
    try:
        # Extract access token from Authorization header
        access_token = credentials.credentials

        # Retrieve user details using the access token
        user_response = supabase_client.auth.get_user(access_token)
        if not user_response or not user_response.user:
            raise HTTPException(status_code=401, detail="User not authenticated.")

        user = user_response.user

        # Optionally fetch additional data from your custom `users` table
        # NOTE(review): .single() presumably raises when no row matches -- verify.
        res = supabase_client.table("users").select("*").eq("id", user.id).single().execute()
        record = res.data

        # Construct the UserProfile response
        return UserProfile(
            user_id=user.id,
            email=user.email,
            full_name=record.get("full_name") if record else None,
            role=user.role,
            created_at=user.created_at,
            updated_at=user.updated_at,
            last_sign_in_at=user.last_sign_in_at,
            email_verified=user.user_metadata.get("email_verified", False),
            phone_verified=user.user_metadata.get("phone_verified", False),
            is_anonymous=user.is_anonymous,
            app_metadata=user.app_metadata,
            user_metadata=user.user_metadata,
            identities=[
                # Identity timestamps are stringified when they arrive as datetimes.
                Identity(
                    provider=identity.provider,
                    identity_id=identity.identity_id,
                    created_at=str(identity.created_at) if isinstance(identity.created_at, datetime) else identity.created_at,
                    last_sign_in_at=str(identity.last_sign_in_at) if isinstance(identity.last_sign_in_at, datetime) else identity.last_sign_in_at,
                )
                for identity in user.identities
            ] if user.identities else []
        )
    except Exception as e:
        logging.exception("[get_current_user] Error retrieving user info")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/auth/confirm")
def confirm_email(
    access_token: str = Query(...), 
    refresh_token: str = Query(...), 
    expires_in: int = Query(...), 
    token_type: str = Query(...)
):
    """Handle the confirmation link Supabase sends via email.

    Only access_token is actually checked; the remaining query parameters
    are accepted because the link carries them.  Note the 400 raised inside
    the try block is re-wrapped as a 500 by the broad except, matching the
    original behavior.
    """
    try:
        confirmed = supabase_client.auth.get_user(access_token)
        if not confirmed.user:
            raise HTTPException(status_code=400, detail="Invalid or expired confirmation link.")
        return {
            "status": "success",
            "message": "Email confirmed successfully.",
            "user": confirmed.user,
        }
    except Exception as e:
        logging.exception("[confirm_email] Error during confirmation")
        raise HTTPException(status_code=500, detail=str(e))


###############################################################################
#                             HTTP ENDPOINTS
###############################################################################
@app.get("/health")
def health_check():
    """Liveness probe for deployment infrastructure."""
    payload = {"status": "ok", "message": "Service is healthy."}
    return payload

@app.post("/ask", response_model=AskResponse)
def ask_endpoint(payload: AskRequest):
    """Create a new session and record its opening user message.

    The client then continues the conversation over the WebSocket endpoint
    using the returned session_id.
    """
    session_id = generate_session_id()
    append_message(session_id, "user", payload.user_input)
    return AskResponse(
        session_id=session_id,
        message="Session created. Connect via WS to continue.",
    )

@app.post("/reset")
def reset_session(session_id: str = Body(..., embed=True)):
    """Delete a session's JSON file, wiping its conversation history.

    Raises 404 when the session file does not exist.
    """
    path = get_session_file(session_id)
    if not os.path.exists(path):
        raise HTTPException(status_code=404, detail="Session not found.")
    os.remove(path)
    return {"status": "ok", "message": f"Session {session_id} reset."}

###############################################################################
#                          WEBSOCKET CONCURRENCY
###############################################################################
class ConnectionManager:
    """
    Manages EXACTLY ONE active WebSocket per session_id.
    If a new WebSocket for the same session_id arrives,
    it closes the old connection first.
    """
    def __init__(self):
        # session_id -> the single live WebSocket for that session
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, session_id: str, websocket: WebSocket):
        """Accept *websocket* as the session's connection, evicting any old one."""
        old_ws = self.active_connections.get(session_id)
        if old_ws is not None:
            logging.info(f"[WS] Closing old connection for session {session_id} to allow new one.")
            try:
                await old_ws.close(code=4000, reason="Replaced by a new connection")
            except Exception:
                # Fix: the stale socket may already be dead; a failed close
                # must not prevent the replacement connection from being accepted.
                logging.exception("[WS] Failed closing stale connection for session %s", session_id)

        logging.info(f"[WS] Accepting WebSocket for session: {session_id}")
        await websocket.accept()

        self.active_connections[session_id] = websocket
        logging.info(f"[WS] Session {session_id} connected. "
                     f"Total active sessions: {len(self.active_connections)}")

    def disconnect(self, session_id: str, websocket: WebSocket):
        """Drop the mapping, but only if *websocket* is still the registered one."""
        stored_ws = self.active_connections.get(session_id)
        if stored_ws is websocket:
            del self.active_connections[session_id]
            logging.info(f"[WS] Session {session_id} disconnected. "
                         f"Remaining active sessions: {len(self.active_connections)}")

    async def send_json(self, session_id: str, data: dict):
        """Send *data* to the session's socket; silently no-op if none is registered."""
        ws = self.active_connections.get(session_id)
        if ws is not None:
            await ws.send_json(data)

manager = ConnectionManager()

@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: Optional[str] = Query(None),
    tone: str = Query("detailed")
):
    """
    WebSocket endpoint.
    - The user can pass `session_id` and `tone` as query parameters, e.g.:
        ws://localhost:8000/ws?session_id=abc-123&tone=casual
    - Or omit `session_id` to generate one automatically.
    - Each message from client must be JSON with {"user_input": "..."}.

    For every message, the full conversation history is replayed through
    `pipeline_graph` and streamed back as `status` / `final_answer` /
    `error` JSON frames via `manager.send_json`.
    """
    if not session_id:
        session_id = generate_session_id()
        logging.info(f"[WS] No session_id provided. Created new: {session_id}")

    await manager.connect(session_id, websocket)

    try:
        while True:
            try:
                data = await websocket.receive_json()
                user_input = data.get("user_input", "")
                append_message(session_id, "user", user_input)

                # Rebuild the whole transcript so the pipeline sees the
                # full conversation, not just the latest message.
                session_data = load_session_from_json(session_id)
                conversation_text = ""
                for msg in session_data["messages"]:
                    role_name = msg["role"].capitalize()
                    conversation_text += f"{role_name}: {msg['content']}\n"

                chain_state = {
                    "question": conversation_text,
                    "tone": tone
                }

                await manager.send_json(session_id, {
                    "type": "status",
                    "message": "Starting pipeline..."
                })

                try:
                    async for step_result in pipeline_graph.astream(chain_state, stream_mode="values"):
                        if "demographic_context" in step_result:
                            await manager.send_json(session_id, {
                                "type": "status",
                                "message": f"Demographic: {step_result['demographic_context']}"
                            })
                        if "retrieved_context" in step_result:
                            docs = step_result["retrieved_context"]
                            # Guard against an empty retrieval list: indexing
                            # docs[0] unconditionally would raise IndexError.
                            if docs:
                                excerpt = docs[0].page_content[:60]
                                await manager.send_json(session_id, {
                                    "type": "status",
                                    "message": f"Retrieved context: {excerpt}..."
                                })
                        if "web_search_needed" in step_result:
                            await manager.send_json(session_id, {
                                "type": "status",
                                "message": f"Web search needed = {step_result['web_search_needed']}"
                            })
                        if "web_search_results" in step_result:
                            count = len(step_result["web_search_results"])
                            await manager.send_json(session_id, {
                                "type": "status",
                                "message": f"Web search returned {count} results."
                            })
                        if "final_answer" in step_result:
                            final_ans = step_result["final_answer"]
                            short_answer = final_ans["crunched_summary"]
                            # Persist only the short form in the transcript.
                            append_message(session_id, "assistant", short_answer)

                            await manager.send_json(session_id, {
                                "type": "final_answer",
                                "short_answer": short_answer,
                                "full_answer": final_ans["full_answer"],
                                "sources": final_ans["sources"],
                                "source_type": final_ans["source_type"]
                            })

                except Exception as e:
                    # Pipeline failure is reported to the client but does not
                    # tear down the connection; the user may retry.
                    logging.exception("[WS] Error during pipeline streaming.")
                    await manager.send_json(session_id, {
                        "type": "error",
                        "message": str(e)
                    })

            except WebSocketDisconnect:
                logging.info(f"[WS] Client disconnected for session {session_id}")
                break
            except Exception as e:
                logging.exception("[WS] Error reading JSON from WebSocket.")
                await manager.send_json(session_id, {
                    "type": "error",
                    "message": str(e)
                })
                # Not disconnecting immediately—client may continue with valid input
    finally:
        # Always unregister this socket, even if an unexpected exception
        # escapes the loop (e.g. send_json failing on a dead socket).
        # Previously such an escape left the session in the registry.
        # disconnect() is identity-checked, so a double call is harmless.
        manager.disconnect(session_id, websocket)

###############################################################################
#                         LOCAL DEV ENTRY POINT
###############################################################################
# if __name__ == "__main__":
#     import uvicorn
#     # Uncomment this block to run a local dev server with auto-reload.
#     uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)