quentinL52 commited on
Commit
f61631b
·
1 Parent(s): 90d6a84

security update

Browse files
Files changed (3) hide show
  1. main.py +10 -2
  2. requirements.txt +2 -1
  3. src/services/graph_service.py +44 -0
main.py CHANGED
@@ -3,6 +3,9 @@ import logging
3
  from fastapi import FastAPI, Request, HTTPException
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
6
  from pydantic import BaseModel
7
  from dotenv import load_dotenv
8
 
@@ -22,16 +25,20 @@ app = FastAPI(
22
  redirect_slashes=True,
23
  )
24
 
25
- ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "*").split(",") if os.getenv("CORS_ORIGINS") != "*" else ["*"]
26
 
27
  app.add_middleware(
28
  CORSMiddleware,
29
- allow_origins=["*"],
30
  allow_credentials=True,
31
  allow_methods=["*"],
32
  allow_headers=["*"],
33
  )
34
 
 
 
 
 
35
  class HealthCheck(BaseModel):
36
  status: str = "ok"
37
 
@@ -40,6 +47,7 @@ async def health_check():
40
  return HealthCheck()
41
 
42
  @app.post("/simulate-interview/")
 
43
  async def simulate_interview(request: Request):
44
  """
45
  This endpoint receives the interview data, instantiates the graph processor
 
3
  from fastapi import FastAPI, Request, HTTPException
4
  from fastapi.responses import JSONResponse
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from slowapi import Limiter, _rate_limit_exceeded_handler
7
+ from slowapi.util import get_remote_address
8
+ from slowapi.errors import RateLimitExceeded
9
  from pydantic import BaseModel
10
  from dotenv import load_dotenv
11
 
 
25
  redirect_slashes=True,
26
  )
27
 
28
+ ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://localhost:5173,http://localhost:8000").split(",")
29
 
30
  app.add_middleware(
31
  CORSMiddleware,
32
+ allow_origins=ALLOWED_ORIGINS,
33
  allow_credentials=True,
34
  allow_methods=["*"],
35
  allow_headers=["*"],
36
  )
37
 
38
+ limiter = Limiter(key_func=get_remote_address)
39
+ app.state.limiter = limiter
40
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
41
+
42
  class HealthCheck(BaseModel):
43
  status: str = "ok"
44
 
 
47
  return HealthCheck()
48
 
49
  @app.post("/simulate-interview/")
50
+ @limiter.limit("5/minute")
51
  async def simulate_interview(request: Request):
52
  """
53
  This endpoint receives the interview data, instantiates the graph processor
requirements.txt CHANGED
@@ -18,4 +18,5 @@ textstat
18
  chromadb
19
  sentence-transformers
20
  numpy
21
- textblob
 
 
18
  chromadb
19
  sentence-transformers
20
  numpy
21
+ textblob
22
+ slowapi
src/services/graph_service.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  from pathlib import Path
5
  from typing import TypedDict, Annotated, Dict, Any, List, Optional
6
 
 
7
  from langchain_openai import ChatOpenAI
8
  from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
9
  from langgraph.graph import StateGraph, END
@@ -71,6 +72,9 @@ class GraphInterviewProcessor:
71
  self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0.7)
72
  self.extractor = InterviewAgentExtractor(self.llm)
73
 
 
 
 
74
  self.graph = self._build_graph()
75
  logging.info("RONI Graph initialisé.")
76
 
@@ -542,8 +546,48 @@ class GraphInterviewProcessor:
542
  "cheat_metrics": cheat_metrics or {}
543
  }
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  final_state = self.graph.invoke(initial_state)
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  if not final_state or not final_state['messages']:
548
  return {"response": "Erreur système.", "status": "finished"}
549
 
 
4
  from pathlib import Path
5
  from typing import TypedDict, Annotated, Dict, Any, List, Optional
6
 
7
+ import redis
8
  from langchain_openai import ChatOpenAI
9
  from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
10
  from langgraph.graph import StateGraph, END
 
72
  self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0.7)
73
  self.extractor = InterviewAgentExtractor(self.llm)
74
 
75
+ redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
76
+ self.redis_client = redis.Redis.from_url(redis_url, decode_responses=True)
77
+
78
  self.graph = self._build_graph()
79
  logging.info("RONI Graph initialisé.")
80
 
 
546
  "cheat_metrics": cheat_metrics or {}
547
  }
548
 
549
+ # Load state from Redis
550
+ redis_key = f"interview_state:{self.user_id}"
551
+ saved_state_json = self.redis_client.get(redis_key)
552
+ if saved_state_json:
553
+ try:
554
+ saved_state = json.loads(saved_state_json)
555
+ initial_state["context"] = saved_state.get("context", {})
556
+ if saved_state.get("icebreaker_data"):
557
+ initial_state["icebreaker_data"] = IceBreakerOutput(**saved_state["icebreaker_data"])
558
+ if saved_state.get("technical_data"):
559
+ initial_state["technical_data"] = TechnicalOutput(**saved_state["technical_data"])
560
+ if saved_state.get("behavioral_data"):
561
+ initial_state["behavioral_data"] = BehavioralOutput(**saved_state["behavioral_data"])
562
+ if saved_state.get("situation_data"):
563
+ initial_state["situation_data"] = SituationOutput(**saved_state["situation_data"])
564
+ if saved_state.get("simulation_report"):
565
+ initial_state["simulation_report"] = SimulationReport(**saved_state["simulation_report"])
566
+ except Exception as e:
567
+ logger.error(f"Failed to load state from Redis: {e}")
568
+
569
  final_state = self.graph.invoke(initial_state)
570
 
571
+ # Save updated state to Redis
572
+ try:
573
+ state_to_save = {
574
+ "context": final_state.get("context", {}),
575
+ }
576
+ if final_state.get("icebreaker_data"):
577
+ state_to_save["icebreaker_data"] = final_state["icebreaker_data"].dict()
578
+ if final_state.get("technical_data"):
579
+ state_to_save["technical_data"] = final_state["technical_data"].dict()
580
+ if final_state.get("behavioral_data"):
581
+ state_to_save["behavioral_data"] = final_state["behavioral_data"].dict()
582
+ if final_state.get("situation_data"):
583
+ state_to_save["situation_data"] = final_state["situation_data"].dict()
584
+ if final_state.get("simulation_report"):
585
+ state_to_save["simulation_report"] = final_state["simulation_report"].dict()
586
+
587
+ self.redis_client.setex(redis_key, 86400, json.dumps(state_to_save)) # Expires in 24 hours
588
+ except Exception as e:
589
+ logger.error(f"Failed to save state to Redis: {e}")
590
+
591
  if not final_state or not final_state['messages']:
592
  return {"response": "Erreur système.", "status": "finished"}
593