| import os |
| import uuid |
| import logging |
| import time |
| from datetime import datetime, timedelta |
| from typing import Dict, Optional, Tuple, List |
| from fastapi import UploadFile |
| from pathlib import Path |
|
|
| from app.config import settings |
|
|
| logger = logging.getLogger(__name__) |
|
|
| class Session: |
| """Object representing a user session""" |
| def __init__(self, session_id: str, image_path: str): |
| self.session_id = session_id |
| self.image_path = image_path |
| self.created_at = datetime.now() |
| self.last_accessed = datetime.now() |
| self.questions = [] |
| |
| def is_expired(self) -> bool: |
| """Check if the session has expired""" |
| expiry_time = self.last_accessed + timedelta(seconds=settings.MAX_SESSION_AGE) |
| return datetime.now() > expiry_time |
| |
| def update_access_time(self): |
| """Update the last accessed time""" |
| self.last_accessed = datetime.now() |
| |
| def add_question(self, question: str, answer: Dict): |
| """Add a question and its answer to the session history""" |
| self.questions.append({ |
| "question": question, |
| "answer": answer, |
| "timestamp": datetime.now().isoformat() |
| }) |
| self.update_access_time() |
|
|
| class SessionService: |
| """Service for managing user sessions""" |
| |
| def __init__(self): |
| """Initialize the session service""" |
| self.sessions: Dict[str, Session] = {} |
| self.ensure_upload_dir() |
| |
| |
| self._cleanup_sessions() |
| |
| def ensure_upload_dir(self): |
| """Ensure the upload directory exists""" |
| os.makedirs(settings.UPLOAD_DIR, exist_ok=True) |
| |
| def create_session(self, file: UploadFile) -> str: |
| """ |
| Create a new session for the user |
| |
| Args: |
| file (UploadFile): The uploaded image file |
| |
| Returns: |
| str: The session ID |
| """ |
| |
| session_id = str(uuid.uuid4()) |
| |
| |
| timestamp = int(time.time()) |
| file_extension = Path(file.filename).suffix |
| filename = f"{timestamp}_{session_id}{file_extension}" |
| |
| |
| file_path = os.path.join(settings.UPLOAD_DIR, filename) |
| with open(file_path, "wb") as f: |
| f.write(file.file.read()) |
| |
| |
| self.sessions[session_id] = Session(session_id, file_path) |
| |
| logger.info(f"Created new session {session_id} with image {file_path}") |
| return session_id |
| |
| def get_session(self, session_id: str) -> Optional[Session]: |
| """ |
| Get a session by ID |
| |
| Args: |
| session_id (str): The session ID |
| |
| Returns: |
| Optional[Session]: The session, or None if not found or expired |
| """ |
| session = self.sessions.get(session_id) |
| |
| if session is None: |
| return None |
| |
| if session.is_expired(): |
| self._remove_session(session_id) |
| return None |
| |
| session.update_access_time() |
| return session |
| |
| def complete_session(self, session_id: str) -> bool: |
| """ |
| Mark a session as complete and remove its resources |
| |
| Args: |
| session_id (str): The session ID |
| |
| Returns: |
| bool: True if successful, False otherwise |
| """ |
| session = self.sessions.get(session_id) |
| if not session: |
| logger.warning(f"Cannot complete nonexistent session: {session_id}") |
| return False |
| |
| logger.info(f"Completing session {session_id}") |
| |
| try: |
| |
| if session.image_path and os.path.exists(session.image_path): |
| os.remove(session.image_path) |
| logger.info(f"Removed image file for completed session {session.image_path}") |
| |
| |
| session.image_path = None |
| return True |
| return True |
| except Exception as e: |
| logger.error(f"Error removing image file during session completion: {e}") |
| return False |
| |
| def _remove_session(self, session_id: str): |
| """ |
| Remove a session and its associated file |
| |
| Args: |
| session_id (str): The session ID |
| """ |
| session = self.sessions.pop(session_id, None) |
| if session: |
| try: |
| |
| if session.image_path and os.path.exists(session.image_path): |
| os.remove(session.image_path) |
| logger.info(f"Removed session file {session.image_path}") |
| except Exception as e: |
| logger.error(f"Error removing session file: {e}") |
| |
| def _cleanup_sessions(self): |
| """Clean up expired sessions""" |
| expired_sessions = [ |
| session_id for session_id, session in self.sessions.items() |
| if session.is_expired() |
| ] |
| |
| for session_id in expired_sessions: |
| self._remove_session(session_id) |
| |
| if expired_sessions: |
| logger.info(f"Cleaned up {len(expired_sessions)} expired sessions") |