Spaces:
Sleeping
Sleeping
| # app.py - Secure Version | |
| import os | |
| import uuid | |
| import shutil | |
| import tempfile | |
| import hashlib | |
| import hmac | |
| from datetime import datetime | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Depends, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from jose import jwt, JWTError | |
| from motor.motor_asyncio import AsyncIOMotorClient | |
| from rag import retrieve_context | |
| from login_signup import router as auth_router | |
| from agent import react_graph | |
| from deep_research import SummaryStateInput, graph as deep_research_graph | |
| from audiobot import handle_audio | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import HumanMessage | |
# Environment variables
# All settings are read once, at import time, from the process environment.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Key used to verify JWT signatures and to derive thread identifiers.  The
# placeholder fallback is kept so existing deployments keep starting, but it
# is a publicly known value and must not be relied on.
_INSECURE_DEFAULT_SECRET = "your-secret-key"
SECRET_KEY = os.getenv("SECRET_KEY", _INSECURE_DEFAULT_SECRET)
if SECRET_KEY == _INSECURE_DEFAULT_SECRET:
    # Anyone who knows the placeholder can sign tokens this app accepts, so
    # surface the misconfiguration instead of running silently with it.
    import warnings

    warnings.warn(
        "SECRET_KEY is not set; falling back to an insecure placeholder. "
        "Set the SECRET_KEY environment variable before deploying.",
        RuntimeWarning,
        stacklevel=1,
    )

ALGORITHM = "HS256"
MONGO_URI = os.getenv("MONGO_URI")
DB_NAME = os.getenv("DB_NAME", "myapp")
# MongoDB setup
# One async client for the whole process; collection handles are resolved once.
_client = AsyncIOMotorClient(MONGO_URI)
_db = _client[DB_NAME]
_users = _db["users"]  # looked up when validating bearer tokens
_threads = _db["user_threads"]  # thread ownership and last-access records

# Model and prompt used by the RAG question-answering handler.
llm_rag = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.5,
    google_api_key=GOOGLE_API_KEY,
)
prompt = PromptTemplate.from_template(
    "you are rag bot tasked to answer questions on aakash jammula "
    "(document contains my information/resume) and be engaging and only respond in text "
    "like a human. Use context only when they ask about me (aakash jammula).\n\n"
    "Context:\n{context}\n\nQuestion:\n{question}\nAnswer:"
)

# Directory in which uploaded audio files are staged.
BASE_TMP = tempfile.gettempdir()

# FastAPI application, open CORS policy, and the auth sub-router.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(auth_router, prefix="/auth", tags=["auth"])

# auto_error=False: handlers receive empty credentials for requests without
# a bearer token and treat the caller as a guest.
security = HTTPBearer(auto_error=False)
def generate_secure_thread_id(user_id: str, salt: Optional[str] = None) -> str:
    """Derive the thread ID for *user_id* for the current UTC day.

    The ID is ``"thread_"`` followed by the first 16 hex characters of an
    HMAC-SHA256 over ``"<user_id>:<UTC date>"``.  Because the date is part of
    the message, the same user receives a different ID on each UTC day.

    Args:
        user_id: Identifier of the authenticated user.
        salt: HMAC key.  Defaults to the module-level ``SECRET_KEY``.

    Returns:
        The derived thread ID, e.g. ``"thread_0123456789abcdef"``.
    """
    # Imported locally so this block does not rely on a file-level import.
    from datetime import timezone

    key = SECRET_KEY if salt is None else salt
    # datetime.utcnow() is deprecated (Python 3.12+); an aware "now" in UTC
    # yields the same calendar date.
    today = datetime.now(timezone.utc).date().isoformat()
    digest = hmac.new(
        key.encode(),
        f"{user_id}:{today}".encode(),
        hashlib.sha256,
    ).hexdigest()
    return f"thread_{digest[:16]}"
async def get_current_user_id(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> Optional[str]:
    """Resolve a bearer token to a user ID.

    Args:
        credentials: Parsed ``Authorization: Bearer`` header, or a falsy
            value when the request carried none.

    Returns:
        The token's ``sub`` claim when the token decodes with ``SECRET_KEY``,
        is not past its ``exp`` claim, and a document with that ``_id``
        exists in the users collection; otherwise ``None``.
    """
    # Imported locally so this block does not rely on a file-level import.
    from datetime import timezone

    if not credentials:
        return None

    try:
        payload = jwt.decode(
            credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]
        )
        user_id: Optional[str] = payload.get("sub")
        exp = payload.get("exp")

        # Compare against a true POSIX timestamp.  A naive utcnow() value is
        # interpreted as *local* time by .timestamp(), which shifts the
        # comparison by the server's UTC offset.
        now_ts = datetime.now(timezone.utc).timestamp()
        if exp and now_ts > exp:
            return None

        # Verify the user still exists.
        # NOTE(review): presumably ``_id`` is stored as the same string that
        # is placed in ``sub`` -- verify against the auth router.
        if user_id and await _users.find_one({"_id": user_id}):
            return user_id
    except JWTError:
        # An undecodable or invalid token is deliberately treated the same
        # as no token: the caller falls back to a guest session.
        return None
    return None
def get_guest_thread_id(request: Request) -> str:
    """Derive a guest thread ID from the caller's address and User-Agent.

    The ID is ``"guest_"`` followed by the first 16 hex characters of
    SHA-256 over ``"<SECRET_KEY>:<client ip>:<user agent>"``, so the same
    address/User-Agent pair always maps to the same ID.

    Args:
        request: Incoming request; its client address and ``user-agent``
            header are read.

    Returns:
        The derived guest thread ID.
    """
    # The framework may not provide a peer address (``request.client`` can be
    # None); fall back to a fixed token instead of raising AttributeError.
    peer = request.client
    client_ip = peer.host if peer is not None else "unknown"
    user_agent = request.headers.get("user-agent", "")

    # NOTE(review): callers sharing an address and User-Agent resolve to the
    # same guest thread -- confirm this is acceptable.
    session_data = f"{client_ip}:{user_agent}"
    digest = hashlib.sha256(f"{SECRET_KEY}:{session_data}".encode()).hexdigest()
    return f"guest_{digest[:16]}"
class UserSession:
    """Identity and conversation thread resolved for a single request."""

    def __init__(
        self,
        user_id: Optional[str],
        thread_id: str,
        is_authenticated: bool,
    ):
        # user_id is None for guest sessions; thread_id is always set.
        self.user_id, self.thread_id, self.is_authenticated = (
            user_id,
            thread_id,
            is_authenticated,
        )
async def get_user_session(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> UserSession:
    """Resolve the caller to a session and record it in the threads collection.

    A caller with a valid bearer token gets the thread derived from their
    user ID; any other caller gets a guest thread derived from the request.
    In both cases the matching thread document is upserted with the owner,
    the guest flag and the current time as ``last_accessed``.

    Args:
        request: Incoming request (used for the guest thread ID and address).
        credentials: Parsed bearer credentials, if any.

    Returns:
        The ``UserSession`` describing the caller.
    """
    user_id = await get_current_user_id(credentials)

    if user_id:
        # Authenticated user: record thread ownership.
        thread_id = generate_secure_thread_id(user_id)
        fields = {
            "user_id": user_id,
            "last_accessed": datetime.utcnow(),
            "is_guest": False,
        }
        session = UserSession(user_id, thread_id, True)
    else:
        # Guest user: no owner, and remember the originating address.
        thread_id = get_guest_thread_id(request)
        # ``request.client`` can be None when no peer address is available;
        # store None rather than raising AttributeError.
        peer = request.client
        fields = {
            "user_id": None,
            "last_accessed": datetime.utcnow(),
            "is_guest": True,
            "guest_ip": peer.host if peer is not None else None,
        }
        session = UserSession(None, thread_id, False)

    await _threads.update_one(
        {"thread_id": thread_id},
        {"$set": fields},
        upsert=True,
    )
    return session
async def verify_thread_access(thread_id: str, session: UserSession) -> bool:
    """Decide whether *session* may use the thread identified by *thread_id*.

    Unknown threads are always refused.  An authenticated session must match
    the thread's stored ``user_id``; a guest session is only allowed on a
    guest thread whose ID equals the session's own thread ID.
    """
    record = await _threads.find_one({"thread_id": thread_id})
    if not record:
        return False

    if not session.is_authenticated:
        # Guests: restricted to their own session-derived guest thread.
        return record.get("is_guest", False) and record.get("thread_id") == session.thread_id

    # Authenticated users: restricted to threads they own.
    return record.get("user_id") == session.user_id
def read_root():
    """Return a fixed greeting payload."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm how it is registered on the app.
    greeting = {"hello": "world"}
    return greeting
class ChatInput(BaseModel):
    """Request body accepted by the ``chat`` handler."""

    # Text of the user's message.
    message: str
    # Optional tool names.  If the first entry is "deep_research" the
    # deep-research graph handles the message; otherwise the names are
    # appended to the message as a "use only ..." hint.
    tools: Optional[List[str]] = None
    # Optional explicit thread ID; access is verified before it is used.
    thread_id: Optional[str] = None
async def audio_chat(
    request: Request,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Stage an uploaded audio file, run it through ``handle_audio`` and reply.

    Args:
        request: Incoming request, used to resolve the caller's session.
        background_tasks: Receives the task that deletes the staged file
            once the response has been sent.
        file: The uploaded audio file.
        credentials: Parsed bearer credentials, if any.

    Returns:
        ``{"response": <text from handle_audio>, "session_id": "active"}``.

    Raises:
        Whatever ``handle_audio`` or the file copy raises; the staged file is
        removed before the exception propagates.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm how it is registered on the app.

    # Called for its side effect of recording the caller's thread; the
    # returned session itself is not needed here.
    await get_user_session(request, credentials)

    # Keep only the extension of the client-supplied name.  The name is
    # untrusted and may be missing, so drop any path components first and
    # fall back to no suffix at all.
    _, extension = os.path.splitext(os.path.basename(file.filename or ""))

    staged = tempfile.NamedTemporaryFile(
        suffix=extension,
        dir=BASE_TMP,
        delete=False,
    )
    staged_path = staged.name
    try:
        with staged:
            shutil.copyfileobj(file.file, staged)
        text_response = await handle_audio(staged_path)
    except BaseException:
        # Background tasks only run after a successful response, so a
        # failure (including cancellation) must clean up here.
        if os.path.exists(staged_path):
            os.remove(staged_path)
        raise

    background_tasks.add_task(os.remove, staged_path)
    # Don't expose the actual thread ID
    return {"response": text_response, "session_id": "active"}
async def chat(
    input: ChatInput,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Answer one chat message using the agent graph or the deep-research graph.

    The conversation runs on the caller's own thread unless ``input.thread_id``
    names another thread the caller is allowed to use.

    Returns:
        A dict with the reply text, a fixed ``session_id`` marker and whether
        the caller is authenticated.

    Raises:
        HTTPException: 403 when ``input.thread_id`` is given but access to it
            is refused.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm how it is registered on the app.
    session = await get_user_session(request, credentials)

    # An explicitly requested thread is honoured only after an access check.
    requested = input.thread_id
    if requested and not await verify_thread_access(requested, session):
        raise HTTPException(403, "Access denied to specified thread")
    thread_id = requested if requested else session.thread_id

    # Drop empty tool names; the first remaining name selects the pipeline.
    selected = [name for name in (input.tools or []) if name]
    deep_research = bool(selected) and selected[0] == "deep_research"

    msg = input.message
    if selected and not deep_research:
        msg += f" (use only {', '.join(selected)})"
    print(f"User message from thread {thread_id}: {msg}")

    # NOTE(review): both invoke() calls below are synchronous inside an async
    # handler -- confirm that blocking here is acceptable.
    if deep_research:
        state = SummaryStateInput(research_topic=msg)
        response = deep_research_graph.invoke(state)["running_summary"]
    else:
        graph_input = {"messages": [HumanMessage(content=msg)]}
        graph_config = {"configurable": {"thread_id": thread_id}}
        resp = react_graph.invoke(graph_input, graph_config)
        response = resp["messages"][-1].content

    # Update thread access time
    await _threads.update_one(
        {"thread_id": thread_id},
        {"$set": {"last_accessed": datetime.utcnow()}},
    )

    # Return session info without exposing internal IDs
    return {
        "response": response,
        "session_id": "active",
        "is_authenticated": session.is_authenticated,
    }
class Query(BaseModel):
    """Request body accepted by the ``ask`` handler."""

    # Question text; an empty value is rejected by the handler.
    q: str
    # Passed to ``retrieve_context`` as ``k``.
    k: int = 5
async def ask(req: Query):
    """Answer a question with the RAG prompt and return the context used.

    Returns:
        ``{"context": <retrieved context>, "answer": <model reply text>}``.

    Raises:
        HTTPException: 400 when the question text is empty.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm how it is registered on the app.
    question = req.q
    if not question:
        raise HTTPException(400, "Missing question text")

    ctx = retrieve_context(question, k=req.k)
    rendered = prompt.invoke({"context": ctx, "question": question})
    reply = llm_rag.invoke(rendered)
    return {"context": ctx, "answer": reply.content}
async def get_user_info(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Report whether the caller is authenticated and, if so, their user ID.

    Returns:
        A dict with ``is_authenticated``, ``user_id`` (``None`` for guests)
        and a fixed ``session_id`` marker.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm how it is registered on the app.
    session = await get_user_session(request, credentials)

    authenticated = session.is_authenticated
    if authenticated:
        visible_user_id = session.user_id
    else:
        visible_user_id = None

    return {
        "is_authenticated": authenticated,
        "user_id": visible_user_id,
        "session_id": "active",
    }
# Cleanup old guest sessions (call this periodically)
async def cleanup_old_sessions():
    """Delete guest thread records not accessed within the last 24 hours.

    Returns:
        ``{"deleted_sessions": <number of records removed>}``.
    """
    # NOTE(review): intended to be admin-only, but no access check or route
    # decorator is visible in this view -- confirm how this is exposed.
    from datetime import timedelta

    stale_before = datetime.utcnow() - timedelta(hours=24)
    stale_guests = {
        "is_guest": True,
        "last_accessed": {"$lt": stale_before},
    }
    outcome = await _threads.delete_many(stale_guests)
    return {"deleted_sessions": outcome.deleted_count}
if __name__ == "__main__":
    # Run the app directly with uvicorn, listening on all interfaces.
    import uvicorn

    # PORT comes from the environment; 5000 is the fallback.
    listen_port = int(os.getenv("PORT", 5000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)