# cours_nsi_term / main.py — uploaded by dav74 (commit 40c0e77)
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from typing import List, Dict, Optional, Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.postgres import PostgresSaver
from psycopg_pool import ConnectionPool
from fastapi import FastAPI, UploadFile, Form, File, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, JSON, ForeignKey, DateTime, Boolean, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
import logging
import uuid
from datetime import datetime, timedelta
import os
import jwt
from passlib.context import CryptContext
from dotenv import load_dotenv
load_dotenv()
# Logger setup — reuse uvicorn's error logger so messages appear in server output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('uvicorn.error')

# Configuration (all values read from the environment / .env)
DATABASE_URL = os.getenv("DATABASE_URL")  # Postgres DSN; app degrades gracefully if missing
if not DATABASE_URL:
    logger.error("CRITICAL: DATABASE_URL is not set in .env file")
JWT_SECRET = os.getenv("JWT_SECRET")  # signing key for access tokens
if not JWT_SECRET:
    logger.error("CRITICAL: JWT_SECRET is not set in .env file")
ALGORITHM = os.getenv("ALGORITHM", "HS256")  # JWT signing algorithm
# Token lifetime in minutes; default 30 days.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 60 * 24 * 30))
# Database Setup (SQLAlchemy)
Base = declarative_base()

class User(Base):
    """Application account: unique username plus a password hash."""
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    hashed_password = Column(String)

class UserProgress(Base):
    """Per-user, per-course revision state; links to a LangGraph thread."""
    __tablename__ = "user_course_progress"
    # Composite primary key: one row per (user, course) pair.
    user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
    course_id = Column(String, primary_key=True)
    # LangGraph conversation thread holding this course's chat history.
    thread_id = Column(String)
    # Touched on every access; the cleanup task uses it to find stale threads.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Lazily initialized in init_db(); None means the database is unavailable.
engine = None
SessionLocal = None

import time
import asyncio
from functools import wraps
import psycopg
def retry_on_db_error(max_retries=3, delay=1):
    """Decorator for async endpoints: retry on transient DB connection errors.

    SQLAlchemy often wraps the underlying psycopg error, so the decision to
    retry is made on the error text ("connection" / "closed" / "ssl") rather
    than on the exception type. Non-transient errors are re-raised unchanged.

    Args:
        max_retries: total number of attempts before giving up.
        delay: seconds to sleep between attempts.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # A single `except Exception` suffices: it also covers
                    # psycopg.OperationalError, which the original listed
                    # redundantly alongside Exception.
                    error_str = str(e).lower()
                    if not ("connection" in error_str or "closed" in error_str or "ssl" in error_str):
                        raise  # not a transient connection problem
                    logger.warning(f"DB Error in {func.__name__} (attempt {attempt+1}/{max_retries}): {e}. Retrying...")
                    last_exception = e
                    if attempt < max_retries - 1:
                        # Bug fix: don't sleep after the final failed attempt.
                        await asyncio.sleep(delay)
            raise last_exception
        return wrapper
    return decorator
def init_db():
    """Create the SQLAlchemy engine and session factory, retrying while the DB wakes up.

    Sets the module globals `engine` and `SessionLocal`. No-op when
    DATABASE_URL is unset. Retries with capped exponential backoff because
    the hosted Postgres may be cold-starting.
    """
    global engine, SessionLocal
    if not DATABASE_URL:
        return
    # Force use of Psycopg 3 driver for SQLAlchemy
    sqlalchemy_url = DATABASE_URL
    if sqlalchemy_url.startswith("postgresql://"):
        sqlalchemy_url = sqlalchemy_url.replace("postgresql://", "postgresql+psycopg://", 1)
    # Also add connect_timeout directly to the URL string
    if "?" in sqlalchemy_url:
        sqlalchemy_url += "&connect_timeout=60"
    else:
        sqlalchemy_url += "?connect_timeout=60"
    max_retries = 15  # Increased to 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"SQLAlchemy attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            engine = create_engine(
                sqlalchemy_url,
                pool_pre_ping=True,  # validate pooled connections before use
                pool_recycle=300,    # recycle connections older than 5 minutes
                # prepare_threshold=0 is critical for compatibility with Neon pooler/PgBouncer
                connect_args={
                    "connect_timeout": 60,
                    "prepare_threshold": 0
                }
            )
            SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
            # This is where the actual connection is usually established
            Base.metadata.create_all(bind=engine)
            logger.info("SQLAlchemy initialized successfully.")
            return
        except Exception as e:
            logger.error(f"SQLAlchemy init error (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                # Cap the delay at 60s
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("SQLAlchemy failed to initialize after all retries.")
# Auth Utilities
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# Token URL matches the POST /login endpoint below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")

def verify_password(plain_password, hashed_password):
    """Check a plaintext password against its stored hash."""
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password):
    """Hash a plaintext password for storage."""
    return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT containing *data* plus an "exp" claim.

    Falls back to a 15-minute lifetime when no (truthy) expires_delta is
    supplied. The input dict is copied, never mutated.
    """
    payload = data.copy()
    # `or` keeps the original truthiness semantics (timedelta(0) also falls back).
    lifetime = expires_delta or timedelta(minutes=15)
    payload["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(payload, JWT_SECRET, algorithm=ALGORITHM)
def get_db():
    """FastAPI dependency: yield a SQLAlchemy session, always closing it."""
    if not SessionLocal:
        # init_db() failed or DATABASE_URL was missing.
        raise HTTPException(status_code=500, detail="Database not initialized")
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """FastAPI dependency: decode the bearer JWT and load the matching User.

    Raises 401 when the token is invalid/expired, carries no "sub" claim,
    or the referenced user no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except jwt.PyJWTError:
        raise credentials_exception
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise credentials_exception
    return user
# Langgraph Setup — populated by init_langgraph(); None means chat is unavailable.
pool = None
checkpointer = None
app_chatbot = None
def init_langgraph():
    """Create the Postgres-backed checkpointer pool and compile the chat graph.

    Sets the module globals `pool`, `checkpointer` and `app_chatbot`.
    No-op when DATABASE_URL is unset. Retries with capped backoff like
    init_db(). Must run after `workflow` is defined (see module bottom).
    """
    global pool, checkpointer, app_chatbot
    if not DATABASE_URL:
        return
    max_retries = 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"Langgraph attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            connection_kwargs = {
                "autocommit": True,
                # prepare_threshold=0: required behind PgBouncer-style poolers
                "prepare_threshold": 0,
                "connect_timeout": 60
            }
            # ConnectionPool from psycopg_pool
            pool = ConnectionPool(
                conninfo=DATABASE_URL,
                max_size=20,
                kwargs=connection_kwargs,
                timeout=60.0,
                check=ConnectionPool.check_connection  # Verify connections before dispensing
            )
            checkpointer = PostgresSaver(pool)
            checkpointer.setup()  # creates the checkpoint tables if missing
            # Compile workflow
            app_chatbot = workflow.compile(checkpointer=checkpointer)
            logger.info("Langgraph initialization complete.")
            return
        except Exception as e:
            logger.error(f"Failed to initialize Langgraph (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("Langgraph failed to initialize after all retries.")
# App Setup
app = FastAPI()

# 1. CORS Middleware MUST be first
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — tighten origins if cookies
# are ever used; bearer-token auth as used here is unaffected.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 2. Logging Middleware to see ALL requests
@app.middleware("http")
async def log_requests(request, call_next):
    """Log every incoming request and its response status (or error)."""
    logger.info(f"Incoming request: {request.method} {request.url}")
    try:
        response = await call_next(request)
    except Exception as e:
        # Log and propagate so FastAPI's error handling still applies.
        logger.error(f"Request error: {e}")
        raise
    logger.info(f"Response status: {response.status_code}")
    return response
@app.get("/")
async def root():
    """Health check: reports whether the SQLAlchemy engine came up."""
    return {"status": "ok", "message": "Prof Perso API is running", "db_connected": engine is not None}
# LLM used by the tutor graph.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-lite",
    temperature=0.7,
    api_key=os.getenv("GOOGLE_API_KEY")
)
# Alternative OpenRouter configuration kept for reference; the bare string
# literal below is a no-op at runtime.
"""
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
llm = ChatOpenAI(
model="mistralai/mistral-nemo",
openai_api_key = OPENROUTER_API_KEY,
openai_api_base="https://openrouter.ai/api/v1"
)
#nousresearch/hermes-3-llama-3.1-405b:free
#meta-llama/llama-3.3-70b-instruct:free
#meta-llama/llama-3.1-405b-instruct:free
#google/gemini-2.0-flash-exp:free
#qwen/qwen3-next-80b-a3b-instruct:free
#mistralai/mistral-nemo
"""
# Load the official Terminale NSI programme (injected into the prompt).
loader_term = None
prog_term = []
if os.path.exists("programme_NSI_terminale.md"):
    try:
        loader_term = UnstructuredMarkdownLoader("programme_NSI_terminale.md")
        prog_term = loader_term.load()
        logger.info("Programme NSI terminale loaded.")
    except Exception as e:
        # Best-effort: the tutor still works with an empty programme.
        logger.error(f"Error loading programme_NSI_terminale.md: {e}")
# System prompt for the tutor. French user-facing content — part of runtime
# behavior, deliberately left untranslated. Placeholders {prog_term} and
# {course_content} are filled in the chatbot node at invoke time.
system_tutor = """
Rôle : Tu es "NSI-Tuteur", un professeur expert en Numérique et Sciences Informatiques (spécialité Première et Terminale). Ton objectif est d'aider l'élève à réviser de manière active et stimulante.
DOCUMENTS DE RÉFÉRENCE (Source de vérité) :
1. Programme Officiel NSI (Terminale) :
{prog_term}
2. Résumé du cours actuel et activités :
{course_content}
Directives de fonctionnement :
- ANALYSE DE CONTEXTE : Si un "Résumé du cours actuel" est fourni ci-dessus (dans le point 2), cela signifie que l'élève a déjà choisi un chapitre. IDENTIFIE ce chapitre et commence DIRECTEMENT par la Phase de Diagnostic sur ce thème spécifique. Ne lui demande pas quel chapitre il veut réviser s'il est déjà là !
- Si le "Résumé du cours actuel" est vide ou non pertinent, demande alors à l'élève quel chapitre il souhaite réviser parmi les programmes de Première ou Terminale fournis.
- Utilise PRINCIPALEMENT les documents ci-dessus pour tes explications et exercices (tu peux créer de nouveaux exercices en t'inspirant de ceux qui sont donnés).
- Tu ne dois jamais évoqué les documents fournis (résumé de cours, programme de terminale...) avec les élèves
- Phase de Diagnostic : Pose une question ouverte sur un concept clé du cours chargé pour évaluer le niveau de l'élève.
- Méthode Socratique : Ne donne jamais la réponse d'emblée. Guide par indices, questions ou analogies.
- Phase de Mise en Pratique : Propose des exercices courts (Python, SQL, etc.) basés sur les exemples du cours.
- Évaluation continue : Analyse les erreurs et félicite les progrès.
- Bilan : Résume les points forts et propose un plan d'action spécifique en fin de session.
Style et Ton :
- Professionnel, encourageant et concis.
- Utilise Markdown pour le code et les concepts clés.
- Adapte la complexité au niveau de l'élève.
Contraintes de sécurité :
- Refuse de répondre hors du programme NSI.
- Ne fais pas le travail à la place de l'élève.
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_tutor),
        # Conversation history is injected here at invoke time.
        ("placeholder", "{messages}"),
    ]
)
class GraphState(TypedDict):
    """LangGraph state: chat history plus the current course's markdown text."""
    # add_messages merges appended messages and honors RemoveMessage entries.
    messages: Annotated[list[AnyMessage], add_messages]
    course_content: str
def chatbot(state: GraphState):
    """LangGraph node: run the tutor prompt/LLM chain over the conversation.

    Reads the course summary from state and stamps the AI reply with a UTC
    timestamp so prune_history can age it out later.
    """
    # Bug fix: use .get — the "course_content" channel is only populated once
    # /select_course has run; state["course_content"] raised KeyError on a
    # thread that chats before selecting a course.
    course_content = state.get("course_content", "")
    chain = prompt | llm
    response = chain.invoke({
        "messages": state["messages"],
        "course_content": course_content,
        # Official programme text, if the markdown file loaded at startup.
        "prog_term": prog_term[0].page_content if prog_term else ""
    })
    # Add timestamp to AI message so history pruning can compare ages.
    if not response.additional_kwargs:
        response.additional_kwargs = {}
    response.additional_kwargs["timestamp"] = datetime.utcnow().isoformat()
    return {"messages": [response]}
def prune_history(state: GraphState):
    """LangGraph node: emit RemoveMessage markers for messages older than 3 days.

    Messages without a parseable "timestamp" in additional_kwargs are kept.
    Returning an empty list is a no-op under the add_messages reducer.
    """
    history = state.get("messages", [])
    if not history:
        return {"messages": []}
    cutoff = datetime.utcnow() - timedelta(days=3)
    stale = []
    for message in history:
        stamp = message.additional_kwargs.get("timestamp")
        if not stamp:
            continue
        try:
            if datetime.fromisoformat(stamp) < cutoff:
                stale.append(RemoveMessage(id=message.id))
        except Exception:
            # Unparseable timestamp: keep the message, same as the original.
            pass
    if stale:
        logger.info(f"Pruning {len(stale)} messages older than 3 days.")
    return {"messages": stale}
# Tutor graph: prune stale history first, then generate the reply.
workflow = StateGraph(GraphState)
workflow.add_node('prune_history', prune_history)
workflow.add_node('chatbot', chatbot)
workflow.add_edge(START, 'prune_history')
workflow.add_edge('prune_history', 'chatbot')
workflow.add_edge('chatbot', END)

# Initial calls — run at import time; init_langgraph() compiles `workflow`.
logger.info("Starting initializations...")
init_db()
init_langgraph()
async def perform_cleanup():
    """Single pass: delete progress rows and LangGraph checkpoints older than 3 days.

    Finds stale UserProgress rows, deletes them, then purges the matching
    LangGraph checkpoint tables (writes, blobs, checkpoints). Best-effort:
    all failures are logged, never raised.
    """
    if not SessionLocal:
        return
    db = None
    try:
        logger.info("Starting cleanup of old conversations...")
        db = SessionLocal()
        three_days_ago = datetime.utcnow() - timedelta(days=3)
        # 1. Find old threads in UserProgress
        old_progress = db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).all()
        thread_ids = [p.thread_id for p in old_progress]
        if not thread_ids:
            logger.info("No stale threads found.")
            return
        logger.info(f"Found {len(thread_ids)} stale threads to clean up.")
        # 2. Delete from UserProgress
        db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).delete()
        # 3. Cleanup LangGraph checkpoints.
        # Security/robustness fix: bind the id list as a parameter with
        # `= ANY(:ids)` instead of interpolating ids into the SQL string —
        # immune to injection and to quoting problems in thread ids.
        try:
            params = {"ids": thread_ids}
            # Delete writes first, then blobs, then the checkpoints themselves.
            db.execute(text("DELETE FROM checkpoint_writes WHERE thread_id = ANY(:ids)"), params)
            db.execute(text("DELETE FROM checkpoint_blobs WHERE thread_id = ANY(:ids)"), params)
            db.execute(text("DELETE FROM checkpoints WHERE thread_id = ANY(:ids)"), params)
        except Exception as e:
            logger.warning(f"Could not clear LangGraph storage for some threads: {e}")
        db.commit()
        logger.info(f"Successfully cleaned up {len(thread_ids)} threads and their checkpoints.")
    except Exception as e:
        logger.error(f"Error in cleanup task: {e}")
        if db is not None:
            db.rollback()
    finally:
        # Always release the session, even on the early-return paths.
        if db is not None:
            db.close()
async def cleanup_old_conversations():
    """Background task: prune stale conversations/checkpoints every 6 hours."""
    while True:
        await perform_cleanup()
        # Run every 6 hours
        await asyncio.sleep(6 * 3600)

@app.on_event("startup")
async def startup_event():
    # NOTE(review): the task handle is unreferenced — asyncio docs recommend
    # keeping a reference so the task cannot be garbage-collected mid-run.
    asyncio.create_task(cleanup_old_conversations())
# NOTE(review): duplicate of the root() route registered earlier in this file
# (identical body). Only one registration can match "/"; one of the two
# definitions should be removed.
@app.get("/")
async def root():
    """Health check: reports whether the SQLAlchemy engine came up."""
    return {"status": "ok", "message": "Prof Perso API is running", "db_connected": engine is not None}
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate a user and return a JWT bearer token.

    Returns 401 for a bad username/password, 500 only for unexpected errors
    (e.g. database failures).
    """
    try:
        user = db.query(User).filter(User.username == form_data.username).first()
        if not user or not verify_password(form_data.password, user.hashed_password):
            raise HTTPException(status_code=401, detail="Incorrect username or password")
        access_token = create_access_token(data={"sub": user.username})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Bug fix: HTTPException is an Exception subclass, so the generic
        # handler below used to swallow the intended 401 and re-emit it as
        # a 500. Let it propagate untouched.
        raise
    except Exception as e:
        logger.error(f"Login error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# App Endpoints
@app.get("/courses")
async def get_courses():
    """Return the static catalogue of Terminale NSI chapters.

    Each entry's id is the markdown filename ("c<n>c.md") served from the
    `cours` directory by /select_course.
    """
    titles = [
        "01-intro BD",
        "02-BD relationnelles",
        "03-langage SQL",
        "04-récursivité",
        "05-liste, piles et files",
        "06-les dictionnaires",
        "07-les arbres",
        "08-algo sur les arbre ",
        "09-les graphes",
        "10-algo sur les graphes",
        "11-les protocoles de routage",
        "12-sécurisation des communications",
        "13-calculabilité - décidabilité",
        "14-paradigmes de programmation",
        "15-méthode diviser pour régner",
        "16-programmation dynamique",
        "17-recherche textuelle",
        "18-système sur puce",
        "19-les processus",
    ]
    # Ids follow the "c<n>c.md" pattern, numbered from 1 in catalogue order.
    return [{"id": f"c{n}c.md", "title": title} for n, title in enumerate(titles, start=1)]
class SelectCourseRequest(BaseModel):
    """Body of POST /select_course."""
    # Course id = markdown filename from GET /courses (e.g. "c4c.md").
    id: str
    # Accepted for compatibility but not read by the endpoint: the thread id
    # is looked up / created server-side.
    thread_id: Optional[str] = None
@app.post("/select_course")
@retry_on_db_error()
async def select_course(
    request: SelectCourseRequest,
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Load a course's markdown, upsert the user's progress row, and push the
    course text into the LangGraph thread state.

    Returns the thread_id to chat on and whether the thread is brand new.
    Raises 404 when the course file does not exist.
    """
    COURSES_DIR = "cours"
    filepath = os.path.join(COURSES_DIR, request.id)
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Course not found")
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
        progress = db.query(UserProgress).filter(
            UserProgress.user_id == user.id,
            UserProgress.course_id == request.id
        ).first()
        is_new = False
        if not progress:
            # First time this user opens this course: create a fresh thread.
            is_new = True
            new_thread_id = str(uuid.uuid4())
            progress = UserProgress(
                user_id=user.id,
                course_id=request.id,
                thread_id=new_thread_id
            )
            db.add(progress)
        else:
            # Update timestamp to mark it as last accessed
            progress.updated_at = datetime.utcnow()
        db.commit()
        config = {"configurable": {"thread_id": progress.thread_id}}
        if app_chatbot:
            # Store the course text in the thread state for the chatbot node.
            app_chatbot.update_state(config, {"course_content": content})
            # Logic to remove the last AI message if we are resuming (not is_new)
            if not is_new:
                state = app_chatbot.get_state(config)
                if state.values and "messages" in state.values:
                    messages = state.values["messages"]
                    if messages and isinstance(messages[-1], AIMessage):
                        logger.info(f"Resuming course {request.id}: Removing last AI message {messages[-1].id} to replay/summarize correctly.")
                        app_chatbot.update_state(config, {"messages": RemoveMessage(id=messages[-1].id)})
        return {
            "status": "success",
            "message": f"Course {request.id} selected",
            "thread_id": progress.thread_id,
            "is_new": is_new
        }
    except Exception as e:
        logger.error(f"Error selecting course: {e}")
        # If it's a DB connection error, re-raise so the retry decorator catches it
        if "connection" in str(e).lower() or "ssl" in str(e).lower():
            raise e
        raise HTTPException(status_code=500, detail=str(e))
@app.post('/request')
@retry_on_db_error()
async def chatbot_request(
    id: Annotated[str, Form()],
    query: Annotated[str, Form()],
    user: User = Depends(get_current_user)
):
    """Send one user message to the tutor graph and return the LLM reply.

    Form fields: `id` is the LangGraph thread_id, `query` the user's message.
    """
    config = {"configurable": {"thread_id": id}}
    input_messages = []
    if query:
        input_messages.append(HumanMessage(
            content=query,
            # Timestamp lets prune_history age this message out after 3 days.
            additional_kwargs={"timestamp": datetime.utcnow().isoformat()}
        ))
    graph_input = {"messages": input_messages}
    if not app_chatbot:
        raise HTTPException(status_code=500, detail="Chatbot not initialized")
    # LangGraph invocation might raise psycopg errors
    # NOTE(review): .invoke() is synchronous and blocks the event loop for the
    # duration of the LLM call; consider `await app_chatbot.ainvoke(...)`.
    rep = app_chatbot.invoke(graph_input, config, stream_mode="values")
    return {"response": rep['messages'][-1].content}
@app.get("/user_progress")
async def get_user_progress(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the most recently touched course/thread for the logged-in user."""
    latest = (
        db.query(UserProgress)
        .filter(UserProgress.user_id == user.id)
        .order_by(UserProgress.updated_at.desc())
        .first()
    )
    if latest is None:
        # User has never opened a course.
        return {"active_course_id": None, "active_thread_id": None}
    return {
        "active_course_id": latest.course_id,
        "active_thread_id": latest.thread_id,
    }
@app.get("/history/{thread_id}")
async def get_history(thread_id: str, user: User = Depends(get_current_user)):
    """Return the last 3 days of a thread's conversation as role/content dicts.

    Messages with no parseable timestamp are kept (same as the server-side
    pruning policy); non-Human/AI messages are silently skipped.
    """
    config = {"configurable": {"thread_id": thread_id}}
    if not app_chatbot:
        return {"history": []}
    state = app_chatbot.get_state(config)
    cutoff = datetime.utcnow() - timedelta(days=3)

    def is_recent(message):
        stamp = message.additional_kwargs.get("timestamp")
        if not stamp:
            return True
        try:
            return datetime.fromisoformat(stamp) >= cutoff
        except Exception:
            return True

    history = []
    for msg in state.values.get("messages", []):
        if not is_recent(msg):
            continue
        if isinstance(msg, HumanMessage):
            history.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            history.append({"role": "assistant", "content": msg.content})
    return {"history": history}
if __name__ == "__main__":
    # Direct-run entry point for local development.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)