"""Prof Perso API.

FastAPI backend for an NSI (French CS) tutoring chatbot:
- SQLAlchemy (psycopg3) for users and per-course progress,
- LangGraph + PostgresSaver for per-thread conversation checkpoints,
- a Gemini model (via langchain) driving a Socratic tutor prompt,
- JWT bearer-token authentication,
- a background task pruning conversations older than 3 days.
"""
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from typing import List, Dict, Optional, Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.postgres import PostgresSaver
from psycopg_pool import ConnectionPool
from fastapi import FastAPI, UploadFile, Form, File, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, JSON, ForeignKey, DateTime, Boolean, text, bindparam
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
import logging
import uuid
from datetime import datetime, timedelta
import os
import jwt
from passlib.context import CryptContext
from dotenv import load_dotenv

load_dotenv()

# Logger setup — reuse uvicorn's error logger so messages show up in the server log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('uvicorn.error')

# Configuration — missing values are logged (not fatal) so the app can still
# start and report its degraded state via the "/" health endpoint.
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    logger.error("CRITICAL: DATABASE_URL is not set in .env file")
JWT_SECRET = os.getenv("JWT_SECRET")
if not JWT_SECRET:
    logger.error("CRITICAL: JWT_SECRET is not set in .env file")
ALGORITHM = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 60 * 24 * 30))

# Database Setup (SQLAlchemy)
Base = declarative_base()


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    hashed_password = Column(String)


class UserProgress(Base):
    """One row per (user, course): which LangGraph thread holds that conversation."""
    __tablename__ = "user_course_progress"
    user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
    course_id = Column(String, primary_key=True)
    thread_id = Column(String)
    # updated_at doubles as "last accessed" and drives the 3-day cleanup.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


engine = None
SessionLocal = None

import time
import asyncio
from functools import wraps
import psycopg


def retry_on_db_error(max_retries=3, delay=1):
    """Decorator for async endpoints: retry on transient DB connection errors.

    Only errors whose message mentions connection/closed/ssl are retried;
    anything else is re-raised immediately.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None
            for i in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                # Broad Exception on purpose: SQLAlchemy wraps the underlying
                # psycopg error (e.g. as sqlalchemy.exc.OperationalError), so
                # catching psycopg.OperationalError alone would miss it.
                except Exception as e:
                    error_str = str(e).lower()
                    if "connection" in error_str or "closed" in error_str or "ssl" in error_str:
                        logger.warning(f"DB Error in {func.__name__} (attempt {i+1}/{max_retries}): {e}. Retrying...")
                        last_exception = e
                        await asyncio.sleep(delay)
                    else:
                        raise e
            raise last_exception
        return wrapper
    return decorator


def init_db():
    """Initialize the SQLAlchemy engine/session factory, retrying while the DB wakes up."""
    global engine, SessionLocal
    if not DATABASE_URL:
        return
    # Force use of Psycopg 3 driver for SQLAlchemy
    sqlalchemy_url = DATABASE_URL
    if sqlalchemy_url.startswith("postgresql://"):
        sqlalchemy_url = sqlalchemy_url.replace("postgresql://", "postgresql+psycopg://", 1)
    # Also add connect_timeout directly to the URL string
    if "?" in sqlalchemy_url:
        sqlalchemy_url += "&connect_timeout=60"
    else:
        sqlalchemy_url += "?connect_timeout=60"
    max_retries = 15  # Increased to 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"SQLAlchemy attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            engine = create_engine(
                sqlalchemy_url,
                pool_pre_ping=True,
                pool_recycle=300,
                # prepare_threshold=0 is critical for compatibility with Neon pooler/PgBouncer
                connect_args={
                    "connect_timeout": 60,
                    "prepare_threshold": 0
                }
            )
            SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
            # This is where the actual connection is usually established
            Base.metadata.create_all(bind=engine)
            logger.info("SQLAlchemy initialized successfully.")
            return
        except Exception as e:
            logger.error(f"SQLAlchemy init error (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                # Cap the delay at 60s
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("SQLAlchemy failed to initialize after all retries.")


# Auth Utilities
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")


def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    return pwd_context.hash(password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying `data` plus an `exp` claim (default 15 min)."""
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, JWT_SECRET, algorithm=ALGORITHM)
    return encoded_jwt


def get_db():
    """FastAPI dependency: yield a session and always close it."""
    if not SessionLocal:
        raise HTTPException(status_code=500, detail="Database not initialized")
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer token to a User row; 401 on any validation failure."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except jwt.PyJWTError:
        raise credentials_exception
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise credentials_exception
    return user


# Langgraph Setup
pool = None
checkpointer = None
app_chatbot = None


def init_langgraph():
    """Create the psycopg pool + Postgres checkpointer and compile the graph, with retries."""
    global pool, checkpointer, app_chatbot
    if not DATABASE_URL:
        return
    max_retries = 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"Langgraph attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            connection_kwargs = {
                "autocommit": True,
                "prepare_threshold": 0,
                "connect_timeout": 60
            }
            # ConnectionPool from psycopg_pool
            pool = ConnectionPool(
                conninfo=DATABASE_URL,
                max_size=20,
                kwargs=connection_kwargs,
                timeout=60.0,
                check=ConnectionPool.check_connection  # Verify connections before dispensing
            )
            checkpointer = PostgresSaver(pool)
            checkpointer.setup()
            # Compile workflow
            app_chatbot = workflow.compile(checkpointer=checkpointer)
            logger.info("Langgraph initialization complete.")
            return
        except Exception as e:
            logger.error(f"Failed to initialize Langgraph (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("Langgraph failed to initialize after all retries.")


# App Setup
app = FastAPI()

# 1. CORS Middleware MUST be first
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# 2. Logging Middleware to see ALL requests
@app.middleware("http")
async def log_requests(request, call_next):
    logger.info(f"Incoming request: {request.method} {request.url}")
    try:
        response = await call_next(request)
        logger.info(f"Response status: {response.status_code}")
        return response
    except Exception as e:
        logger.error(f"Request error: {e}")
        raise


@app.get("/")
async def root():
    return {"status": "ok", "message": "Prof Perso API is running", "db_connected": engine is not None}


llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-lite",
    temperature=0.7,
    api_key=os.getenv("GOOGLE_API_KEY")
)

"""
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
llm = ChatOpenAI(
    model="mistralai/mistral-nemo",
    openai_api_key = OPENROUTER_API_KEY,
    openai_api_base="https://openrouter.ai/api/v1"
)
#nousresearch/hermes-3-llama-3.1-405b:free
#meta-llama/llama-3.3-70b-instruct:free
#meta-llama/llama-3.1-405b-instruct:free
#google/gemini-2.0-flash-exp:free
#qwen/qwen3-next-80b-a3b-instruct:free
#mistralai/mistral-nemo
"""

# Load the official Terminale NSI programme (used as ground truth in the system prompt).
loader_term = None
prog_term = []
if os.path.exists("programme_NSI_terminale.md"):
    try:
        loader_term = UnstructuredMarkdownLoader("programme_NSI_terminale.md")
        prog_term = loader_term.load()
        logger.info("Programme NSI terminale loaded.")
    except Exception as e:
        logger.error(f"Error loading programme_NSI_terminale.md: {e}")

system_tutor = """
Rôle : Tu es "NSI-Tuteur", un professeur expert en Numérique et Sciences Informatiques (spécialité Première et Terminale). Ton objectif est d'aider l'élève à réviser de manière active et stimulante.

DOCUMENTS DE RÉFÉRENCE (Source de vérité) :
1. Programme Officiel NSI (Terminale) : {prog_term}
2. Résumé du cours actuel et activités : {course_content}

Directives de fonctionnement :
- ANALYSE DE CONTEXTE : Si un "Résumé du cours actuel" est fourni ci-dessus (dans le point 2), cela signifie que l'élève a déjà choisi un chapitre. IDENTIFIE ce chapitre et commence DIRECTEMENT par la Phase de Diagnostic sur ce thème spécifique. Ne lui demande pas quel chapitre il veut réviser s'il est déjà là !
- Si le "Résumé du cours actuel" est vide ou non pertinent, demande alors à l'élève quel chapitre il souhaite réviser parmi les programmes de Première ou Terminale fournis.
- Utilise PRINCIPALEMENT les documents ci-dessus pour tes explications et exercices (tu peux créer de nouveaux exercices en t'inspirant de ceux qui sont donnés).
- Tu ne dois jamais évoqué les documents fournis (résumé de cours, programme de terminale...) avec les élèves
- Phase de Diagnostic : Pose une question ouverte sur un concept clé du cours chargé pour évaluer le niveau de l'élève.
- Méthode Socratique : Ne donne jamais la réponse d'emblée. Guide par indices, questions ou analogies.
- Phase de Mise en Pratique : Propose des exercices courts (Python, SQL, etc.) basés sur les exemples du cours.
- Évaluation continue : Analyse les erreurs et félicite les progrès.
- Bilan : Résume les points forts et propose un plan d'action spécifique en fin de session.

Style et Ton :
- Professionnel, encourageant et concis.
- Utilise Markdown pour le code et les concepts clés.
- Adapte la complexité au niveau de l'élève.

Contraintes de sécurité :
- Refuse de répondre hors du programme NSI.
- Ne fais pas le travail à la place de l'élève.
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_tutor),
        ("placeholder", "{messages}"),
    ]
)


class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    course_content: str


def chatbot(state: GraphState):
    """LangGraph node: run the tutor prompt against the LLM and timestamp the reply."""
    #course_content = state.get("course_content", "")
    course_content = state["course_content"]
    chain = prompt | llm
    response = chain.invoke({
        "messages": state["messages"],
        "course_content": course_content,
        "prog_term": prog_term[0].page_content if prog_term else ""
    })
    # Add timestamp to AI message
    if not response.additional_kwargs:
        response.additional_kwargs = {}
    response.additional_kwargs["timestamp"] = datetime.utcnow().isoformat()
    return {"messages": [response]}


def prune_history(state: GraphState):
    """Remove messages older than 3 days."""
    messages = state.get("messages", [])
    if not messages:
        return {"messages": []}
    three_days_ago = datetime.utcnow() - timedelta(days=3)
    to_remove = []
    for msg in messages:
        ts_str = msg.additional_kwargs.get("timestamp")
        if ts_str:
            try:
                ts = datetime.fromisoformat(ts_str)
                if ts < three_days_ago:
                    to_remove.append(RemoveMessage(id=msg.id))
            except Exception:
                # Unparseable timestamp: keep the message rather than drop it.
                pass
    if to_remove:
        logger.info(f"Pruning {len(to_remove)} messages older than 3 days.")
    return {"messages": to_remove}


workflow = StateGraph(GraphState)
workflow.add_node('prune_history', prune_history)
workflow.add_node('chatbot', chatbot)
workflow.add_edge(START, 'prune_history')
workflow.add_edge('prune_history', 'chatbot')
workflow.add_edge('chatbot', END)

# Initial calls
logger.info("Starting initializations...")
init_db()
init_langgraph()


async def perform_cleanup():
    """Perform a single pass of cleaning up old conversations and checkpoints."""
    try:
        logger.info("Starting cleanup of old conversations...")
        if SessionLocal:
            db = SessionLocal()
            three_days_ago = datetime.utcnow() - timedelta(days=3)
            # 1. Find old threads in UserProgress
            old_progress = db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).all()
            thread_ids = [p.thread_id for p in old_progress]
            if thread_ids:
                logger.info(f"Found {len(thread_ids)} stale threads to clean up.")
                # 2. Delete from UserProgress
                db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).delete()
                # 3. Cleanup LangGraph checkpoints. Use an expanding bindparam
                # instead of interpolating ids into the SQL string (avoids any
                # injection risk and quoting bugs).
                params = {"ids": thread_ids}
                try:
                    # Delete writes first
                    db.execute(
                        text("DELETE FROM checkpoint_writes WHERE thread_id IN :ids").bindparams(bindparam("ids", expanding=True)),
                        params,
                    )
                    # Delete blobs
                    db.execute(
                        text("DELETE FROM checkpoint_blobs WHERE thread_id IN :ids").bindparams(bindparam("ids", expanding=True)),
                        params,
                    )
                    # Delete checkpoints
                    db.execute(
                        text("DELETE FROM checkpoints WHERE thread_id IN :ids").bindparams(bindparam("ids", expanding=True)),
                        params,
                    )
                except Exception as e:
                    logger.warning(f"Could not clear LangGraph storage for some threads: {e}")
                db.commit()
                logger.info(f"Successfully cleaned up {len(thread_ids)} threads and their checkpoints.")
            else:
                logger.info("No stale threads found.")
            db.close()
    except Exception as e:
        logger.error(f"Error in cleanup task: {e}")
        if 'db' in locals():
            db.rollback()
            db.close()


async def cleanup_old_conversations():
    """Background task to remove conversations and checkpoints older than 3 days."""
    while True:
        await perform_cleanup()
        # Run every 6 hours
        await asyncio.sleep(6 * 3600)


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(cleanup_old_conversations())


# NOTE: a second identical @app.get("/") handler was removed here — the route
# is already defined once above, after the middleware setup.


@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Verify credentials and return a bearer token."""
    try:
        user = db.query(User).filter(User.username == form_data.username).first()
        if not user or not verify_password(form_data.password, user.hashed_password):
            raise HTTPException(status_code=401, detail="Incorrect username or password")
        access_token = create_access_token(data={"sub": user.username})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Re-raise as-is: without this, the 401 above would be swallowed by the
        # broad handler below and surfaced to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Login error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# App Endpoints
@app.get("/courses")
async def get_courses():
    courses = [
        {"id": "c1c.md", "title": "01-intro BD"},
        {"id": "c2c.md", "title": "02-BD relationnelles"},
        {"id": "c3c.md", "title": "03-langage SQL"},
        {"id": "c4c.md", "title": "04-récursivité"},
        {"id": "c5c.md", "title": "05-liste, piles et files"},
        {"id": "c6c.md", "title": "06-les dictionnaires"},
        {"id": "c7c.md", "title": "07-les arbres"},
        {"id": "c8c.md", "title": "08-algo sur les arbre "},
        {"id": "c9c.md", "title": "09-les graphes"},
        {"id": "c10c.md", "title": "10-algo sur les graphes"},
        {"id": "c11c.md", "title": "11-les protocoles de routage"},
        {"id": "c12c.md", "title": "12-sécurisation des communications"},
        {"id": "c13c.md", "title": "13-calculabilité - décidabilité"},
        {"id": "c14c.md", "title": "14-paradigmes de programmation"},
        {"id": "c15c.md", "title": "15-méthode diviser pour régner"},
        {"id": "c16c.md", "title": "16-programmation dynamique"},
        {"id": "c17c.md", "title": "17-recherche textuelle"},
        {"id": "c18c.md", "title": "18-système sur puce"},
        {"id": "c19c.md", "title": "19-les processus"}
    ]
    return courses


class SelectCourseRequest(BaseModel):
    id: str
    thread_id: Optional[str] = None


@app.post("/select_course")
@retry_on_db_error()
async def select_course(
    request: SelectCourseRequest,
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Load a course file, upsert the user's progress row, and prime the chatbot thread."""
    COURSES_DIR = "cours"
    # Security: request.id is client-controlled and joined into a filesystem
    # path — reject anything that is not a plain filename (path traversal).
    if request.id != os.path.basename(request.id) or request.id.startswith("."):
        raise HTTPException(status_code=404, detail="Course not found")
    filepath = os.path.join(COURSES_DIR, request.id)
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Course not found")
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
        progress = db.query(UserProgress).filter(
            UserProgress.user_id == user.id,
            UserProgress.course_id == request.id
        ).first()
        is_new = False
        if not progress:
            is_new = True
            new_thread_id = str(uuid.uuid4())
            progress = UserProgress(
                user_id=user.id,
                course_id=request.id,
                thread_id=new_thread_id
            )
            db.add(progress)
        else:
            # Update timestamp to mark it as last accessed
            progress.updated_at = datetime.utcnow()
        db.commit()
        config = {"configurable": {"thread_id": progress.thread_id}}
        if app_chatbot:
            app_chatbot.update_state(config, {"course_content": content})
            # Logic to remove the last AI message if we are resuming (not is_new)
            if not is_new:
                state = app_chatbot.get_state(config)
                if state.values and "messages" in state.values:
                    messages = state.values["messages"]
                    if messages and isinstance(messages[-1], AIMessage):
                        logger.info(f"Resuming course {request.id}: Removing last AI message {messages[-1].id} to replay/summarize correctly.")
                        app_chatbot.update_state(config, {"messages": RemoveMessage(id=messages[-1].id)})
        return {
            "status": "success",
            "message": f"Course {request.id} selected",
            "thread_id": progress.thread_id,
            "is_new": is_new
        }
    except Exception as e:
        logger.error(f"Error selecting course: {e}")
        # If it's a DB connection error, re-raise so the retry decorator catches it
        if "connection" in str(e).lower() or "ssl" in str(e).lower():
            raise e
        raise HTTPException(status_code=500, detail=str(e))


@app.post('/request')
@retry_on_db_error()
async def chatbot_request(
    id: Annotated[str, Form()],
    query: Annotated[str, Form()],
    user: User = Depends(get_current_user)
):
    """Send a user message into the graph for thread `id` and return the AI reply."""
    config = {"configurable": {"thread_id": id}}
    input_messages = []
    if query:
        input_messages.append(HumanMessage(
            content=query,
            additional_kwargs={"timestamp": datetime.utcnow().isoformat()}
        ))
    graph_input = {"messages": input_messages}
    if not app_chatbot:
        raise HTTPException(status_code=500, detail="Chatbot not initialized")
    # LangGraph invocation might raise psycopg errors
    rep = app_chatbot.invoke(graph_input, config, stream_mode="values")
    return {"response": rep['messages'][-1].content}


@app.get("/user_progress")
async def get_user_progress(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the most recently accessed course/thread for the current user."""
    progress = db.query(UserProgress).filter(UserProgress.user_id == user.id).order_by(UserProgress.updated_at.desc()).first()
    if not progress:
        return {"active_course_id": None, "active_thread_id": None}
    return {
        "active_course_id": progress.course_id,
        "active_thread_id": progress.thread_id
    }


@app.get("/history/{thread_id}")
async def get_history(thread_id: str, user: User = Depends(get_current_user)):
    """Return the last 3 days of a thread's messages as role/content pairs."""
    config = {"configurable": {"thread_id": thread_id}}
    if not app_chatbot:
        return {"history": []}
    state = app_chatbot.get_state(config)
    messages = state.values.get("messages", [])
    three_days_ago = datetime.utcnow() - timedelta(days=3)
    formatted_messages = []
    for msg in messages:
        # Check if message is older than 3 days
        ts_str = msg.additional_kwargs.get("timestamp")
        if ts_str:
            try:
                ts = datetime.fromisoformat(ts_str)
                if ts < three_days_ago:
                    continue
            except Exception:
                pass
        if isinstance(msg, HumanMessage):
            formatted_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            formatted_messages.append({"role": "assistant", "content": msg.content})
    return {"history": formatted_messages}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)