Karan6124 commited on
Commit
1ba28b2
·
1 Parent(s): 8ea5095

feat(api): filter cases by authenticated user and protect endpoints

Browse files
app/api/routes.py CHANGED
@@ -7,6 +7,8 @@ from app.services.database import get_db
7
  from app.services.context_manager import ContextManager
8
  from app.models.schemas import CaseContext, CaseStatus
9
  from app.workers.celery_worker import run_case_manager_task
 
 
10
 
11
  router = APIRouter(prefix="/api/cases", tags=["Cases"])
12
 
@@ -15,15 +17,20 @@ class CreateCaseRequest(BaseModel):
15
  constraints: Optional[List[str]] = Field(default_factory=list, example=["creator niche = tech"])
16
 
17
  @router.post("", response_model=CaseContext, status_code=status.HTTP_201_CREATED)
18
- def create_case(payload: CreateCaseRequest, db: Session = Depends(get_db)):
 
 
 
 
19
  """
20
- Initializes a new investigation case in the shared context database.
21
  """
22
  try:
23
  context = ContextManager.create_context(
24
  db=db,
25
  problem_statement=payload.problem_statement,
26
- constraints=payload.constraints
 
27
  )
28
  return context
29
  except Exception as e:
@@ -33,9 +40,13 @@ def create_case(payload: CreateCaseRequest, db: Session = Depends(get_db)):
33
  )
34
 
35
  @router.get("/{case_id}", response_model=CaseContext)
36
- def get_case(case_id: str, db: Session = Depends(get_db)):
 
 
 
 
37
  """
38
- Retrieves the complete shared context for a case.
39
  """
40
  context = ContextManager.get_context(db=db, case_id=case_id)
41
  if not context:
@@ -43,22 +54,41 @@ def get_case(case_id: str, db: Session = Depends(get_db)):
43
  status_code=status.HTTP_404_NOT_FOUND,
44
  detail=f"Case with ID {case_id} not found."
45
  )
 
 
 
 
 
 
 
 
 
46
  return context
47
 
48
  @router.post("/{case_id}/decompose", response_model=CaseContext, status_code=status.HTTP_202_ACCEPTED)
49
- def decompose_case(case_id: str, db: Session = Depends(get_db)):
 
 
 
 
50
  """
51
- Triggers the Case Manager Agent task asynchronously via Celery
52
- to break down the case problem statement into hypotheses.
53
  """
54
- # 1. Get current case context
55
- context = ContextManager.get_context(db=db, case_id=case_id)
56
- if not context:
57
  raise HTTPException(
58
  status_code=status.HTTP_404_NOT_FOUND,
59
  detail=f"Case with ID {case_id} not found."
60
  )
 
 
 
 
 
 
61
 
 
62
  if context.hypotheses:
63
  raise HTTPException(
64
  status_code=status.HTTP_400_BAD_REQUEST,
@@ -84,12 +114,16 @@ def decompose_case(case_id: str, db: Session = Depends(get_db)):
84
  )
85
 
86
  @router.get("", response_model=List[CaseContext])
87
- def get_cases(limit: int = 20, db: Session = Depends(get_db)):
 
 
 
 
88
  """
89
- Retrieves all investigation cases in the database.
90
  """
91
  try:
92
- return ContextManager.get_all_contexts(db=db, limit=limit)
93
  except Exception as e:
94
  raise HTTPException(
95
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
 
7
  from app.services.context_manager import ContextManager
8
  from app.models.schemas import CaseContext, CaseStatus
9
  from app.workers.celery_worker import run_case_manager_task
10
+ from app.api.auth import get_current_user
11
+ from app.models.database_models import UserModel
12
 
13
  router = APIRouter(prefix="/api/cases", tags=["Cases"])
14
 
 
17
  constraints: Optional[List[str]] = Field(default_factory=list, example=["creator niche = tech"])
18
 
19
  @router.post("", response_model=CaseContext, status_code=status.HTTP_201_CREATED)
20
+ def create_case(
21
+ payload: CreateCaseRequest,
22
+ db: Session = Depends(get_db),
23
+ current_user: UserModel = Depends(get_current_user)
24
+ ):
25
  """
26
+ Initializes a new investigation case in the shared context database and links it to the current user.
27
  """
28
  try:
29
  context = ContextManager.create_context(
30
  db=db,
31
  problem_statement=payload.problem_statement,
32
+ constraints=payload.constraints,
33
+ user_id=current_user.id
34
  )
35
  return context
36
  except Exception as e:
 
40
  )
41
 
42
  @router.get("/{case_id}", response_model=CaseContext)
43
+ def get_case(
44
+ case_id: str,
45
+ db: Session = Depends(get_db),
46
+ current_user: UserModel = Depends(get_current_user)
47
+ ):
48
  """
49
+ Retrieves the complete shared context for a case, validating user authorization.
50
  """
51
  context = ContextManager.get_context(db=db, case_id=case_id)
52
  if not context:
 
54
  status_code=status.HTTP_404_NOT_FOUND,
55
  detail=f"Case with ID {case_id} not found."
56
  )
57
+
58
+ from app.models.database_models import CaseContextModel
59
+ case_db = db.query(CaseContextModel).filter(CaseContextModel.case_id == case_id).first()
60
+ if case_db and case_db.user_id and case_db.user_id != current_user.id:
61
+ raise HTTPException(
62
+ status_code=status.HTTP_403_FORBIDDEN,
63
+ detail="Access denied to this case."
64
+ )
65
+
66
  return context
67
 
68
  @router.post("/{case_id}/decompose", response_model=CaseContext, status_code=status.HTTP_202_ACCEPTED)
69
+ def decompose_case(
70
+ case_id: str,
71
+ db: Session = Depends(get_db),
72
+ current_user: UserModel = Depends(get_current_user)
73
+ ):
74
  """
75
+ Triggers the Case Manager Agent task asynchronously via Celery, validating user authorization.
 
76
  """
77
+ from app.models.database_models import CaseContextModel
78
+ case_db = db.query(CaseContextModel).filter(CaseContextModel.case_id == case_id).first()
79
+ if not case_db:
80
  raise HTTPException(
81
  status_code=status.HTTP_404_NOT_FOUND,
82
  detail=f"Case with ID {case_id} not found."
83
  )
84
+
85
+ if case_db.user_id and case_db.user_id != current_user.id:
86
+ raise HTTPException(
87
+ status_code=status.HTTP_403_FORBIDDEN,
88
+ detail="Access denied to this case."
89
+ )
90
 
91
+ context = ContextManager.get_context(db=db, case_id=case_id)
92
  if context.hypotheses:
93
  raise HTTPException(
94
  status_code=status.HTTP_400_BAD_REQUEST,
 
114
  )
115
 
116
  @router.get("", response_model=List[CaseContext])
117
+ def get_cases(
118
+ limit: int = 20,
119
+ db: Session = Depends(get_db),
120
+ current_user: UserModel = Depends(get_current_user)
121
+ ):
122
  """
123
+ Retrieves all investigation cases in the database belonging to the current user.
124
  """
125
  try:
126
+ return ContextManager.get_all_contexts(db=db, limit=limit, user_id=current_user.id)
127
  except Exception as e:
128
  raise HTTPException(
129
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
app/services/context_manager.py CHANGED
@@ -7,7 +7,7 @@ from app.models.database_models import CaseContextModel, FactModel, HypothesisMo
7
 
8
  class ContextManager:
9
  @staticmethod
10
- def create_context(db: Session, problem_statement: str, constraints: List[str] = None) -> CaseContext:
11
  """
12
  Creates a new case context row in the database and returns the Pydantic schema representation.
13
  """
@@ -16,6 +16,7 @@ class ContextManager:
16
  # 1. Create the base case model
17
  case_db = CaseContextModel(
18
  case_id=case_id,
 
19
  problem_statement=problem_statement,
20
  status=CaseStatus.PENDING,
21
  metadata_json={},
@@ -212,11 +213,14 @@ class ContextManager:
212
  return hyp_db
213
 
214
  @staticmethod
215
- def get_all_contexts(db: Session, limit: int = 20) -> List[CaseContext]:
216
  """
217
  Retrieves a list of all case contexts in the database.
218
  """
219
- cases_db = db.query(CaseContextModel).order_by(CaseContextModel.updated_at.desc()).limit(limit).all()
 
 
 
220
  result = []
221
  for case_db in cases_db:
222
  ctx = ContextManager.get_context(db, case_db.case_id)
 
7
 
8
  class ContextManager:
9
  @staticmethod
10
+ def create_context(db: Session, problem_statement: str, constraints: List[str] = None, user_id: str = None) -> CaseContext:
11
  """
12
  Creates a new case context row in the database and returns the Pydantic schema representation.
13
  """
 
16
  # 1. Create the base case model
17
  case_db = CaseContextModel(
18
  case_id=case_id,
19
+ user_id=user_id,
20
  problem_statement=problem_statement,
21
  status=CaseStatus.PENDING,
22
  metadata_json={},
 
213
  return hyp_db
214
 
215
  @staticmethod
216
+ def get_all_contexts(db: Session, limit: int = 20, user_id: Optional[str] = None) -> List[CaseContext]:
217
  """
218
  Retrieves a list of all case contexts in the database.
219
  """
220
+ query = db.query(CaseContextModel)
221
+ if user_id:
222
+ query = query.filter(CaseContextModel.user_id == user_id)
223
+ cases_db = query.order_by(CaseContextModel.updated_at.desc()).limit(limit).all()
224
  result = []
225
  for case_db in cases_db:
226
  ctx = ContextManager.get_context(db, case_db.case_id)