# NOTE: the original capture began with a "Spaces: Running" status banner
# (Hugging Face Spaces UI chrome) — not Python source; kept here as a comment.
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.document_loaders import UnstructuredMarkdownLoader | |
| from typing import List, Dict, Optional, Annotated | |
| from typing_extensions import TypedDict | |
| from langgraph.graph.message import AnyMessage, add_messages | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage | |
| from langgraph.graph import END, StateGraph, START | |
| from langgraph.checkpoint.postgres import PostgresSaver | |
| from psycopg_pool import ConnectionPool | |
| from fastapi import FastAPI, UploadFile, Form, File, HTTPException, Depends, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from pydantic import BaseModel | |
| from sqlalchemy import create_engine, Column, Integer, String, JSON, ForeignKey, DateTime, Boolean, text | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker, Session | |
| import logging | |
| import uuid | |
| from datetime import datetime, timedelta | |
| import os | |
| import jwt | |
| from passlib.context import CryptContext | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Logger setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('uvicorn.error')

# Configuration — all values come from the environment (.env via load_dotenv()).
DATABASE_URL = os.getenv("DATABASE_URL")
JWT_SECRET = os.getenv("JWT_SECRET")
for _var, _val in (("DATABASE_URL", DATABASE_URL), ("JWT_SECRET", JWT_SECRET)):
    if not _val:
        # The app still boots without these, but DB access / auth will be dead.
        logger.error(f"CRITICAL: {_var} is not set in .env file")
ALGORITHM = os.getenv("ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 60 * 24 * 30))
# Database Setup (SQLAlchemy)
Base = declarative_base()

class User(Base):
    """Application account used for the OAuth2 password login flow."""
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    # Login identifier; uniqueness is enforced at the DB level.
    username = Column(String, unique=True, index=True)
    # pbkdf2_sha256 hash produced by get_password_hash(); never the plaintext.
    hashed_password = Column(String)
class UserProgress(Base):
    """Per-(user, course) revision state, keyed by the composite primary key.

    Each row maps a user's course to the LangGraph thread that stores the
    conversation; updated_at drives the 3-day cleanup in perform_cleanup().
    """
    __tablename__ = "user_course_progress"
    user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
    # Course file name, e.g. "c1c.md" (see get_courses()).
    course_id = Column(String, primary_key=True)
    # LangGraph checkpointer thread id (a UUID string minted in select_course()).
    thread_id = Column(String)
    # Touched on every select_course() call; stale rows are purged after 3 days.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# Lazily-initialized SQLAlchemy handles; populated by init_db() at startup and
# left as None when DATABASE_URL is missing or all connection attempts fail.
engine = None
SessionLocal = None

import time
import asyncio
from functools import wraps
import psycopg
def retry_on_db_error(max_retries=3, delay=1):
    """Decorator factory: retry an async callable on transient DB errors.

    Exceptions whose message mentions "connection", "closed" or "ssl" are
    treated as transient and retried up to ``max_retries`` times, sleeping
    ``delay`` seconds between attempts; any other exception is re-raised
    immediately. After the final failed attempt the last transient error
    is re-raised.
    """
    def decorator(func):
        @wraps(func)  # preserve __name__ etc. (wraps was imported but unused)
        async def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Broad catch on purpose: SQLAlchemy wraps the underlying
                    # psycopg error, so the concrete type is unpredictable
                    # (the original `(psycopg.OperationalError, Exception)`
                    # tuple was redundant — Exception already covers it).
                    error_str = str(e).lower()
                    if not ("connection" in error_str or "closed" in error_str or "ssl" in error_str):
                        raise
                    logger.warning(f"DB Error in {func.__name__} (attempt {attempt+1}/{max_retries}): {e}. Retrying...")
                    last_exception = e
                    # No point sleeping after the final attempt.
                    if attempt < max_retries - 1:
                        await asyncio.sleep(delay)
            raise last_exception
        return wrapper
    return decorator
def init_db():
    """Initialize the SQLAlchemy engine/session factory and create tables.

    Retries with exponentially growing delay (capped at 60s) because the
    hosted Postgres (Neon) may need time to wake from scale-to-zero. On
    total failure it only logs — the app still starts, with engine and
    SessionLocal left as None (endpoints then fail via get_db()).
    """
    global engine, SessionLocal
    if not DATABASE_URL:
        return
    # Force use of Psycopg 3 driver for SQLAlchemy
    sqlalchemy_url = DATABASE_URL
    if sqlalchemy_url.startswith("postgresql://"):
        sqlalchemy_url = sqlalchemy_url.replace("postgresql://", "postgresql+psycopg://", 1)
    # Also add connect_timeout directly to the URL string
    if "?" in sqlalchemy_url:
        sqlalchemy_url += "&connect_timeout=60"
    else:
        sqlalchemy_url += "?connect_timeout=60"
    max_retries = 15  # Increased to 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"SQLAlchemy attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            engine = create_engine(
                sqlalchemy_url,
                pool_pre_ping=True,
                pool_recycle=300,
                # prepare_threshold=0 is critical for compatibility with Neon pooler/PgBouncer
                connect_args={
                    "connect_timeout": 60,
                    "prepare_threshold": 0
                }
            )
            SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
            # This is where the actual connection is usually established
            Base.metadata.create_all(bind=engine)
            logger.info("SQLAlchemy initialized successfully.")
            return
        except Exception as e:
            logger.error(f"SQLAlchemy init error (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                # Cap the delay at 60s
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("SQLAlchemy failed to initialize after all retries.")
# Auth Utilities
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# tokenUrl must match the login route's path.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored pbkdf2_sha256 hash."""
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the configured pbkdf2_sha256 scheme."""
    return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying *data* plus an ``exp`` expiry claim.

    When no ``expires_delta`` is supplied, the token lives 15 minutes.
    """
    payload = dict(data)
    lifetime = expires_delta if expires_delta else timedelta(minutes=15)
    payload["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(payload, JWT_SECRET, algorithm=ALGORITHM)
def get_db():
    """FastAPI dependency: yield a SQLAlchemy session, closed after the request.

    Raises:
        HTTPException(500): when init_db() never succeeded (SessionLocal is None).
    """
    if not SessionLocal:
        raise HTTPException(status_code=500, detail="Database not initialized")
    db = SessionLocal()
    try:
        yield db
    finally:
        # Always return the connection to the pool, even if the handler raised.
        db.close()
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """FastAPI dependency: resolve the bearer JWT to its User row.

    Raises:
        HTTPException(401): on a missing/invalid/expired token, a token
            without a "sub" claim, or an unknown username.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # jwt.decode also validates the "exp" claim and the signature.
        payload = jwt.decode(token, JWT_SECRET, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except jwt.PyJWTError:
        raise credentials_exception
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        raise credentials_exception
    return user
# Langgraph Setup
# Lazily-initialized handles; populated by init_langgraph() at startup and
# left as None on failure (endpoints must check app_chatbot before use).
pool = None          # psycopg_pool.ConnectionPool shared by the checkpointer
checkpointer = None  # PostgresSaver persisting graph state per thread_id
app_chatbot = None   # compiled StateGraph (the runnable chatbot)
def init_langgraph():
    """Create the psycopg pool and Postgres checkpointer, then compile the
    chat workflow into the module-level app_chatbot.

    Uses the same retry/backoff strategy as init_db(); on total failure it
    only logs, leaving app_chatbot as None.
    """
    global pool, checkpointer, app_chatbot
    if not DATABASE_URL:
        return
    max_retries = 15
    retry_delay = 10
    for i in range(max_retries):
        try:
            logger.info(f"Langgraph attempt {i+1}/{max_retries} to connect to DB (timeout 60s, prepare_threshold=0)...")
            connection_kwargs = {
                "autocommit": True,
                # prepare_threshold=0: required behind Neon pooler/PgBouncer.
                "prepare_threshold": 0,
                "connect_timeout": 60
            }
            # ConnectionPool from psycopg_pool
            pool = ConnectionPool(
                conninfo=DATABASE_URL,
                max_size=20,
                kwargs=connection_kwargs,
                timeout=60.0,
                check=ConnectionPool.check_connection # Verify connections before dispensing
            )
            checkpointer = PostgresSaver(pool)
            # Creates the checkpoints/checkpoint_blobs/checkpoint_writes tables.
            checkpointer.setup()
            # Compile workflow
            app_chatbot = workflow.compile(checkpointer=checkpointer)
            logger.info("Langgraph initialization complete.")
            return
        except Exception as e:
            logger.error(f"Failed to initialize Langgraph (attempt {i+1}): {e}")
            if i < max_retries - 1:
                logger.info(f"Retrying in {retry_delay}s... (DB might be waking up)")
                time.sleep(retry_delay)
                retry_delay = min(retry_delay * 1.5, 60)
            else:
                logger.error("Langgraph failed to initialize after all retries.")
# App Setup
app = FastAPI()
# 1. CORS Middleware MUST be first
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm the intended
# origin whitelist for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 2. Logging Middleware to see ALL requests
# NOTE(review): registration restored — the function was defined but never
# attached to the app, so it logged nothing. Confirm against the original file.
@app.middleware("http")
async def log_requests(request, call_next):
    """Log every request's method/URL and the response status (or error)."""
    logger.info(f"Incoming request: {request.method} {request.url}")
    try:
        response = await call_next(request)
        logger.info(f"Response status: {response.status_code}")
        return response
    except Exception as e:
        logger.error(f"Request error: {e}")
        raise
# NOTE(review): this handler has no visible route decorator and is duplicated
# verbatim near the bottom of the file — confirm which definition is meant to
# be registered on "/"; this earlier one is shadowed by the later one.
async def root():
    """Health check: API liveness plus whether the DB engine initialized."""
    return {"status": "ok", "message": "Prof Perso API is running", "db_connected": engine is not None}
# Primary LLM: Gemini via langchain-google-genai (see the inert string below
# for the previously used OpenRouter alternatives).
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-lite",
    temperature=0.7,
    api_key=os.getenv("GOOGLE_API_KEY")
)
| """ | |
| OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") | |
| llm = ChatOpenAI( | |
| model="mistralai/mistral-nemo", | |
| openai_api_key = OPENROUTER_API_KEY, | |
| openai_api_base="https://openrouter.ai/api/v1" | |
| ) | |
| #nousresearch/hermes-3-llama-3.1-405b:free | |
| #meta-llama/llama-3.3-70b-instruct:free | |
| #meta-llama/llama-3.1-405b-instruct:free | |
| #google/gemini-2.0-flash-exp:free | |
| #qwen/qwen3-next-80b-a3b-instruct:free | |
| #mistralai/mistral-nemo | |
| """ | |
# Load the official Terminale NSI programme (markdown) once at import time;
# its text is injected into the tutor system prompt as {prog_term}.
loader_term = None
prog_term = []  # list of loaded Documents; empty when the file is absent
if os.path.exists("programme_NSI_terminale.md"):
    try:
        loader_term = UnstructuredMarkdownLoader("programme_NSI_terminale.md")
        prog_term = loader_term.load()
        logger.info("Programme NSI terminale loaded.")
    except Exception as e:
        # Non-fatal: the tutor runs with an empty programme reference.
        logger.error(f"Error loading programme_NSI_terminale.md: {e}")
| system_tutor = """ | |
| Rôle : Tu es "NSI-Tuteur", un professeur expert en Numérique et Sciences Informatiques (spécialité Première et Terminale). Ton objectif est d'aider l'élève à réviser de manière active et stimulante. | |
| DOCUMENTS DE RÉFÉRENCE (Source de vérité) : | |
| 1. Programme Officiel NSI (Terminale) : | |
| {prog_term} | |
| 2. Résumé du cours actuel et activités : | |
| {course_content} | |
| Directives de fonctionnement : | |
| - ANALYSE DE CONTEXTE : Si un "Résumé du cours actuel" est fourni ci-dessus (dans le point 2), cela signifie que l'élève a déjà choisi un chapitre. IDENTIFIE ce chapitre et commence DIRECTEMENT par la Phase de Diagnostic sur ce thème spécifique. Ne lui demande pas quel chapitre il veut réviser s'il est déjà là ! | |
| - Si le "Résumé du cours actuel" est vide ou non pertinent, demande alors à l'élève quel chapitre il souhaite réviser parmi les programmes de Première ou Terminale fournis. | |
| - Utilise PRINCIPALEMENT les documents ci-dessus pour tes explications et exercices (tu peux créer de nouveaux exercices en t'inspirant de ceux qui sont donnés). | |
| - Tu ne dois jamais évoqué les documents fournis (résumé de cours, programme de terminale...) avec les élèves | |
| - Phase de Diagnostic : Pose une question ouverte sur un concept clé du cours chargé pour évaluer le niveau de l'élève. | |
| - Méthode Socratique : Ne donne jamais la réponse d'emblée. Guide par indices, questions ou analogies. | |
| - Phase de Mise en Pratique : Propose des exercices courts (Python, SQL, etc.) basés sur les exemples du cours. | |
| - Évaluation continue : Analyse les erreurs et félicite les progrès. | |
| - Bilan : Résume les points forts et propose un plan d'action spécifique en fin de session. | |
| Style et Ton : | |
| - Professionnel, encourageant et concis. | |
| - Utilise Markdown pour le code et les concepts clés. | |
| - Adapte la complexité au niveau de l'élève. | |
| Contraintes de sécurité : | |
| - Refuse de répondre hors du programme NSI. | |
| - Ne fais pas le travail à la place de l'élève. | |
| """ | |
# Prompt: fixed system instructions + the running message history placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_tutor),
        ("placeholder", "{messages}"),
    ]
)
class GraphState(TypedDict):
    """LangGraph state: message history plus the active course summary."""
    # add_messages merges updates into the existing list (and honors
    # RemoveMessage) instead of replacing it wholesale.
    messages: Annotated[list[AnyMessage], add_messages]
    # Markdown content of the currently selected course chapter.
    course_content: str
def chatbot(state: GraphState):
    """LangGraph node: run the tutor prompt/LLM chain over the conversation.

    Injects the course summary and the Terminale programme into the system
    prompt, then stamps the AI reply with a UTC timestamp so prune_history()
    can age it out later.
    """
    # Fix: use .get() instead of state["course_content"] — the key can be
    # absent before select_course() has seeded the thread state, and the
    # previously commented-out line shows this was the intent.
    course_content = state.get("course_content", "")
    chain = prompt | llm
    response = chain.invoke({
        "messages": state["messages"],
        "course_content": course_content,
        "prog_term": prog_term[0].page_content if prog_term else ""
    })
    # Add timestamp to AI message
    if not response.additional_kwargs:
        response.additional_kwargs = {}
    response.additional_kwargs["timestamp"] = datetime.utcnow().isoformat()
    return {"messages": [response]}
| def prune_history(state: GraphState): | |
| """Remove messages older than 3 days.""" | |
| messages = state.get("messages", []) | |
| if not messages: | |
| return {"messages": []} | |
| three_days_ago = datetime.utcnow() - timedelta(days=3) | |
| to_remove = [] | |
| for msg in messages: | |
| ts_str = msg.additional_kwargs.get("timestamp") | |
| if ts_str: | |
| try: | |
| ts = datetime.fromisoformat(ts_str) | |
| if ts < three_days_ago: | |
| to_remove.append(RemoveMessage(id=msg.id)) | |
| except Exception: | |
| pass | |
| if to_remove: | |
| logger.info(f"Pruning {len(to_remove)} messages older than 3 days.") | |
| return {"messages": to_remove} | |
# Graph wiring: START -> prune_history -> chatbot -> END
workflow = StateGraph(GraphState)
workflow.add_node('prune_history', prune_history)
workflow.add_node('chatbot', chatbot)
workflow.add_edge(START, 'prune_history')
workflow.add_edge('prune_history', 'chatbot')
workflow.add_edge('chatbot', END)
# Initial calls
# Both initializers retry internally and only log (never raise) on failure,
# so importing this module cannot crash on a cold database.
logger.info("Starting initializations...")
init_db()
init_langgraph()
async def perform_cleanup():
    """Perform a single pass of cleaning up old conversations and checkpoints.

    Deletes UserProgress rows untouched for 3+ days and the matching
    LangGraph checkpoint rows. Best-effort: failures are logged, the
    session is rolled back, and the next scheduled pass retries.
    """
    if not SessionLocal:
        return
    db = None
    try:
        logger.info("Starting cleanup of old conversations...")
        db = SessionLocal()
        three_days_ago = datetime.utcnow() - timedelta(days=3)
        # 1. Find old threads in UserProgress
        old_progress = db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).all()
        thread_ids = [p.thread_id for p in old_progress]
        if thread_ids:
            logger.info(f"Found {len(thread_ids)} stale threads to clean up.")
            # 2. Delete from UserProgress
            db.query(UserProgress).filter(UserProgress.updated_at < three_days_ago).delete()
            # 3. Cleanup LangGraph checkpoints. Fix: bound parameters with
            # Postgres ANY(array) instead of string-built "IN ('…')" SQL —
            # the ids are server-minted UUIDs, but parameterizing removes any
            # quoting/injection risk. Order matters: writes, then blobs,
            # then checkpoints.
            try:
                for table in ("checkpoint_writes", "checkpoint_blobs", "checkpoints"):
                    db.execute(
                        text(f"DELETE FROM {table} WHERE thread_id = ANY(:ids)"),
                        {"ids": thread_ids},
                    )
            except Exception as e:
                logger.warning(f"Could not clear LangGraph storage for some threads: {e}")
            db.commit()
            logger.info(f"Successfully cleaned up {len(thread_ids)} threads and their checkpoints.")
        else:
            logger.info("No stale threads found.")
    except Exception as e:
        logger.error(f"Error in cleanup task: {e}")
        if db is not None:
            db.rollback()
    finally:
        # Fix: close in finally so the connection is returned even on error
        # (the original only closed on the success path / inside the handler).
        if db is not None:
            db.close()
async def cleanup_old_conversations():
    """Background task to remove conversations and checkpoints older than 3 days."""
    # Infinite loop; lives for the whole process as an asyncio task created
    # at startup. perform_cleanup() swallows its own errors, so the loop
    # cannot die from a transient DB failure.
    while True:
        await perform_cleanup()
        # Run every 6 hours
        await asyncio.sleep(6 * 3600)
# NOTE(review): registration restored — without @app.on_event("startup") the
# cleanup task was never scheduled. Confirm against the original file.
@app.on_event("startup")
async def startup_event():
    """Launch the periodic conversation-cleanup background task."""
    asyncio.create_task(cleanup_old_conversations())
# NOTE(review): route registration restored (the handler was never bound);
# an identical, also-unbound definition exists earlier in the file.
@app.get("/")
async def root():
    """Health check: API liveness plus whether the DB engine initialized."""
    return {"status": "ok", "message": "Prof Perso API is running", "db_connected": engine is not None}
# NOTE(review): decorator restored; the path matches
# OAuth2PasswordBearer(tokenUrl="login") declared above.
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Verify username/password and issue a bearer JWT.

    Raises:
        HTTPException(401): on bad credentials.
        HTTPException(500): on unexpected errors (DB down, etc.).
    """
    try:
        user = db.query(User).filter(User.username == form_data.username).first()
        if not user or not verify_password(form_data.password, user.hashed_password):
            raise HTTPException(status_code=401, detail="Incorrect username or password")
        access_token = create_access_token(data={"sub": user.username})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Bug fix: the broad handler below used to swallow the 401 above and
        # re-surface it as a 500 with a leaked message — let it propagate.
        raise
    except Exception as e:
        logger.error(f"Login error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# App Endpoints
# NOTE(review): no route decorators are visible for the endpoint handlers in
# this capture — confirm the @app.get/@app.post registrations were not lost.
# Static catalogue of Terminale NSI chapters: (course file id, display title).
_COURSE_CATALOGUE = [
    ("c1c.md", "01-intro BD"),
    ("c2c.md", "02-BD relationnelles"),
    ("c3c.md", "03-langage SQL"),
    ("c4c.md", "04-récursivité"),
    ("c5c.md", "05-liste, piles et files"),
    ("c6c.md", "06-les dictionnaires"),
    ("c7c.md", "07-les arbres"),
    ("c8c.md", "08-algo sur les arbre "),
    ("c9c.md", "09-les graphes"),
    ("c10c.md", "10-algo sur les graphes"),
    ("c11c.md", "11-les protocoles de routage"),
    ("c12c.md", "12-sécurisation des communications"),
    ("c13c.md", "13-calculabilité - décidabilité"),
    ("c14c.md", "14-paradigmes de programmation"),
    ("c15c.md", "15-méthode diviser pour régner"),
    ("c16c.md", "16-programmation dynamique"),
    ("c17c.md", "17-recherche textuelle"),
    ("c18c.md", "18-système sur puce"),
    ("c19c.md", "19-les processus"),
]

async def get_courses():
    """Return the static list of available course chapters as id/title dicts."""
    return [{"id": cid, "title": title} for cid, title in _COURSE_CATALOGUE]
class SelectCourseRequest(BaseModel):
    """Request body for course selection."""
    # Course file name, e.g. "c1c.md"; must exist under the "cours" directory.
    id: str
    # Optional existing thread id; currently not read by select_course().
    thread_id: Optional[str] = None
# NOTE(review): no @app.post(...) decorator is visible for this handler in
# the captured source — confirm its route registration was not lost.
async def select_course(
    request: SelectCourseRequest,
    user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Select a course for the user: load its markdown, create or resume the
    per-(user, course) LangGraph thread, and seed the thread state with the
    course content.
    """
    COURSES_DIR = "cours"
    # Security fix: request.id is client-controlled; reject ids containing
    # directory components so "../..."-style values cannot escape COURSES_DIR.
    if os.path.basename(request.id) != request.id:
        raise HTTPException(status_code=404, detail="Course not found")
    filepath = os.path.join(COURSES_DIR, request.id)
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Course not found")
    try:
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
        progress = db.query(UserProgress).filter(
            UserProgress.user_id == user.id,
            UserProgress.course_id == request.id
        ).first()
        is_new = False
        if not progress:
            is_new = True
            new_thread_id = str(uuid.uuid4())
            progress = UserProgress(
                user_id=user.id,
                course_id=request.id,
                thread_id=new_thread_id
            )
            db.add(progress)
        else:
            # Update timestamp to mark it as last accessed
            progress.updated_at = datetime.utcnow()
        db.commit()
        config = {"configurable": {"thread_id": progress.thread_id}}
        if app_chatbot:
            app_chatbot.update_state(config, {"course_content": content})
            # When resuming (not is_new), drop the trailing AI message so the
            # tutor re-summarizes instead of continuing from a stale reply.
            if not is_new:
                state = app_chatbot.get_state(config)
                if state.values and "messages" in state.values:
                    messages = state.values["messages"]
                    if messages and isinstance(messages[-1], AIMessage):
                        logger.info(f"Resuming course {request.id}: Removing last AI message {messages[-1].id} to replay/summarize correctly.")
                        app_chatbot.update_state(config, {"messages": RemoveMessage(id=messages[-1].id)})
        return {
            "status": "success",
            "message": f"Course {request.id} selected",
            "thread_id": progress.thread_id,
            "is_new": is_new
        }
    except HTTPException:
        # Don't let HTTP errors fall into the generic 500 handler below.
        raise
    except Exception as e:
        logger.error(f"Error selecting course: {e}")
        # If it's a DB connection error, re-raise so the retry decorator catches it
        if "connection" in str(e).lower() or "ssl" in str(e).lower():
            raise e
        raise HTTPException(status_code=500, detail=str(e))
async def chatbot_request(
    id: Annotated[str, Form()],
    query: Annotated[str, Form()],
    user: User = Depends(get_current_user)
):
    """Send one user message to the tutor graph thread and return its reply.

    ``id`` is the LangGraph thread id; an empty ``query`` just replays the
    graph on the existing history.
    """
    if not app_chatbot:
        raise HTTPException(status_code=500, detail="Chatbot not initialized")
    config = {"configurable": {"thread_id": id}}
    inbound = []
    if query:
        # Stamp the message so prune_history() can age it out after 3 days.
        inbound.append(HumanMessage(
            content=query,
            additional_kwargs={"timestamp": datetime.utcnow().isoformat()}
        ))
    # LangGraph invocation might raise psycopg errors
    result = app_chatbot.invoke({"messages": inbound}, config, stream_mode="values")
    return {"response": result['messages'][-1].content}
async def get_user_progress(user: User = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return the user's most recently touched course/thread pair, or Nones."""
    latest = (
        db.query(UserProgress)
        .filter(UserProgress.user_id == user.id)
        .order_by(UserProgress.updated_at.desc())
        .first()
    )
    if latest is None:
        return {"active_course_id": None, "active_thread_id": None}
    return {
        "active_course_id": latest.course_id,
        "active_thread_id": latest.thread_id
    }
async def get_history(thread_id: str, user: User = Depends(get_current_user)):
    """Return the last 3 days of a thread's conversation as role/content dicts."""
    if not app_chatbot:
        return {"history": []}
    state = app_chatbot.get_state({"configurable": {"thread_id": thread_id}})
    cutoff = datetime.utcnow() - timedelta(days=3)

    def _is_recent(msg):
        # Unstamped or unparseable timestamps keep the message, matching
        # prune_history()'s behavior.
        ts_str = msg.additional_kwargs.get("timestamp")
        if not ts_str:
            return True
        try:
            return datetime.fromisoformat(ts_str) >= cutoff
        except Exception:
            return True

    history = []
    for msg in state.values.get("messages", []):
        if not _is_recent(msg):
            continue
        if isinstance(msg, HumanMessage):
            history.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            history.append({"role": "assistant", "content": msg.content})
    return {"history": history}
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) | |