quentinL52 committed on
Commit ·
f61631b
1
Parent(s): 90d6a84
security update
Browse files- main.py +10 -2
- requirements.txt +2 -1
- src/services/graph_service.py +44 -0
main.py
CHANGED
|
@@ -3,6 +3,9 @@ import logging
|
|
| 3 |
from fastapi import FastAPI, Request, HTTPException
|
| 4 |
from fastapi.responses import JSONResponse
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
| 6 |
from pydantic import BaseModel
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
|
@@ -22,16 +25,20 @@ app = FastAPI(
|
|
| 22 |
redirect_slashes=True,
|
| 23 |
)
|
| 24 |
|
| 25 |
-
ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "
|
| 26 |
|
| 27 |
app.add_middleware(
|
| 28 |
CORSMiddleware,
|
| 29 |
-
allow_origins=
|
| 30 |
allow_credentials=True,
|
| 31 |
allow_methods=["*"],
|
| 32 |
allow_headers=["*"],
|
| 33 |
)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
class HealthCheck(BaseModel):
|
| 36 |
status: str = "ok"
|
| 37 |
|
|
@@ -40,6 +47,7 @@ async def health_check():
|
|
| 40 |
return HealthCheck()
|
| 41 |
|
| 42 |
@app.post("/simulate-interview/")
|
|
|
|
| 43 |
async def simulate_interview(request: Request):
|
| 44 |
"""
|
| 45 |
This endpoint receives the interview data, instantiates the graph processor
|
|
|
|
| 3 |
from fastapi import FastAPI, Request, HTTPException
|
| 4 |
from fastapi.responses import JSONResponse
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 7 |
+
from slowapi.util import get_remote_address
|
| 8 |
+
from slowapi.errors import RateLimitExceeded
|
| 9 |
from pydantic import BaseModel
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
|
|
|
|
| 25 |
redirect_slashes=True,
|
| 26 |
)
|
| 27 |
|
| 28 |
+
ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "http://localhost:3000,http://localhost:5173,http://localhost:8000").split(",")
|
| 29 |
|
| 30 |
app.add_middleware(
|
| 31 |
CORSMiddleware,
|
| 32 |
+
allow_origins=ALLOWED_ORIGINS,
|
| 33 |
allow_credentials=True,
|
| 34 |
allow_methods=["*"],
|
| 35 |
allow_headers=["*"],
|
| 36 |
)
|
| 37 |
|
| 38 |
+
limiter = Limiter(key_func=get_remote_address)
|
| 39 |
+
app.state.limiter = limiter
|
| 40 |
+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 41 |
+
|
| 42 |
class HealthCheck(BaseModel):
    """Response model for the service health-check endpoint.

    Serialized as ``{"status": "ok"}`` by default; the field is never
    set to another value anywhere in the visible code.
    """

    # Fixed default — the health endpoint returns this model with no arguments.
    status: str = "ok"
|
| 44 |
|
|
|
|
| 47 |
return HealthCheck()
|
| 48 |
|
| 49 |
@app.post("/simulate-interview/")
|
| 50 |
+
@limiter.limit("5/minute")
|
| 51 |
async def simulate_interview(request: Request):
|
| 52 |
"""
|
| 53 |
This endpoint receives the interview data, instantiates the graph processor
|
requirements.txt
CHANGED
|
@@ -18,4 +18,5 @@ textstat
|
|
| 18 |
chromadb
|
| 19 |
sentence-transformers
|
| 20 |
numpy
|
| 21 |
-
textblob
|
|
|
|
|
|
| 18 |
chromadb
|
| 19 |
sentence-transformers
|
| 20 |
numpy
|
| 21 |
+
textblob
|
| 22 |
+
slowapi
|
src/services/graph_service.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import TypedDict, Annotated, Dict, Any, List, Optional
|
| 6 |
|
|
|
|
| 7 |
from langchain_openai import ChatOpenAI
|
| 8 |
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
|
| 9 |
from langgraph.graph import StateGraph, END
|
|
@@ -71,6 +72,9 @@ class GraphInterviewProcessor:
|
|
| 71 |
self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0.7)
|
| 72 |
self.extractor = InterviewAgentExtractor(self.llm)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
self.graph = self._build_graph()
|
| 75 |
logging.info("RONI Graph initialisé.")
|
| 76 |
|
|
@@ -542,8 +546,48 @@ class GraphInterviewProcessor:
|
|
| 542 |
"cheat_metrics": cheat_metrics or {}
|
| 543 |
}
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
final_state = self.graph.invoke(initial_state)
|
| 546 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
if not final_state or not final_state['messages']:
|
| 548 |
return {"response": "Erreur système.", "status": "finished"}
|
| 549 |
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import TypedDict, Annotated, Dict, Any, List, Optional
|
| 6 |
|
| 7 |
+
import redis
|
| 8 |
from langchain_openai import ChatOpenAI
|
| 9 |
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
|
| 10 |
from langgraph.graph import StateGraph, END
|
|
|
|
| 72 |
self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0.7)
|
| 73 |
self.extractor = InterviewAgentExtractor(self.llm)
|
| 74 |
|
| 75 |
+
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
| 76 |
+
self.redis_client = redis.Redis.from_url(redis_url, decode_responses=True)
|
| 77 |
+
|
| 78 |
self.graph = self._build_graph()
|
| 79 |
logging.info("RONI Graph initialisé.")
|
| 80 |
|
|
|
|
| 546 |
"cheat_metrics": cheat_metrics or {}
|
| 547 |
}
|
| 548 |
|
| 549 |
+
# Load state from Redis
|
| 550 |
+
redis_key = f"interview_state:{self.user_id}"
|
| 551 |
+
saved_state_json = self.redis_client.get(redis_key)
|
| 552 |
+
if saved_state_json:
|
| 553 |
+
try:
|
| 554 |
+
saved_state = json.loads(saved_state_json)
|
| 555 |
+
initial_state["context"] = saved_state.get("context", {})
|
| 556 |
+
if saved_state.get("icebreaker_data"):
|
| 557 |
+
initial_state["icebreaker_data"] = IceBreakerOutput(**saved_state["icebreaker_data"])
|
| 558 |
+
if saved_state.get("technical_data"):
|
| 559 |
+
initial_state["technical_data"] = TechnicalOutput(**saved_state["technical_data"])
|
| 560 |
+
if saved_state.get("behavioral_data"):
|
| 561 |
+
initial_state["behavioral_data"] = BehavioralOutput(**saved_state["behavioral_data"])
|
| 562 |
+
if saved_state.get("situation_data"):
|
| 563 |
+
initial_state["situation_data"] = SituationOutput(**saved_state["situation_data"])
|
| 564 |
+
if saved_state.get("simulation_report"):
|
| 565 |
+
initial_state["simulation_report"] = SimulationReport(**saved_state["simulation_report"])
|
| 566 |
+
except Exception as e:
|
| 567 |
+
logger.error(f"Failed to load state from Redis: {e}")
|
| 568 |
+
|
| 569 |
final_state = self.graph.invoke(initial_state)
|
| 570 |
|
| 571 |
+
# Save updated state to Redis
|
| 572 |
+
try:
|
| 573 |
+
state_to_save = {
|
| 574 |
+
"context": final_state.get("context", {}),
|
| 575 |
+
}
|
| 576 |
+
if final_state.get("icebreaker_data"):
|
| 577 |
+
state_to_save["icebreaker_data"] = final_state["icebreaker_data"].dict()
|
| 578 |
+
if final_state.get("technical_data"):
|
| 579 |
+
state_to_save["technical_data"] = final_state["technical_data"].dict()
|
| 580 |
+
if final_state.get("behavioral_data"):
|
| 581 |
+
state_to_save["behavioral_data"] = final_state["behavioral_data"].dict()
|
| 582 |
+
if final_state.get("situation_data"):
|
| 583 |
+
state_to_save["situation_data"] = final_state["situation_data"].dict()
|
| 584 |
+
if final_state.get("simulation_report"):
|
| 585 |
+
state_to_save["simulation_report"] = final_state["simulation_report"].dict()
|
| 586 |
+
|
| 587 |
+
self.redis_client.setex(redis_key, 86400, json.dumps(state_to_save)) # Expires in 24 hours
|
| 588 |
+
except Exception as e:
|
| 589 |
+
logger.error(f"Failed to save state to Redis: {e}")
|
| 590 |
+
|
| 591 |
if not final_state or not final_state['messages']:
|
| 592 |
return {"response": "Erreur système.", "status": "finished"}
|
| 593 |
|