Spaces:
Runtime error
Runtime error
| from typing import Optional, List, Dict, Any | |
| from sqlalchemy import select, update, delete, func | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from app.database.models import * | |
| from app.database.base import get_session | |
| import json | |
| from app.utils.exceptions import DatabaseError, ValidationError | |
async def set_user(tg_id: int) -> Optional[User]:
    """Fetch the user with the given Telegram ID, creating one if none exists.

    When no row matches ``tg_id`` a new ``User`` is added to the session and
    "User added" is printed. This function issues no commit itself; whether
    the new row is persisted depends on how ``get_session`` finalises the
    session.

    Args:
        tg_id: Telegram user ID to look up.

    Returns:
        The existing user, or the newly created (session-added) user.

    Raises:
        DatabaseError: wrapping any exception raised inside the session.
    """
    async with get_session() as session:
        try:
            existing = await session.scalar(select(User).where(User.tg_id == tg_id))
            if existing:
                return existing
            created = User(tg_id=tg_id)
            session.add(created)
            print("User added")
            return created
        except Exception as exc:
            raise DatabaseError(f"Error setting user: {str(exc)}")
async def user_register(
    tg_id: int,
    name: str,
    login: str,
    contact: str,
    subscribe: bool
) -> None:
    """Store registration details on the user identified by ``tg_id``.

    ``subscribe`` is mapped to ``subscription_status``: truthy gives
    ``"active"``, falsy gives ``"inactive"``. If no user has that ``tg_id``,
    the UPDATE matches no rows and nothing changes.

    Raises:
        DatabaseError: wrapping any exception raised while executing the UPDATE.
    """
    async with get_session() as session:
        try:
            fields = {
                "name": name,
                "login": login,
                "contact": contact,
                "subscription_status": "active" if subscribe else "inactive",
            }
            stmt = update(User).where(User.tg_id == tg_id).values(**fields)
            await session.execute(stmt)
        except Exception as exc:
            raise DatabaseError(f"Error registering user: {str(exc)}")
async def check_login_unique(login: str) -> bool:
    """Return True when no user currently has the given ``login``.

    Unlike most functions in this module, database errors are not wrapped
    in DatabaseError here; they propagate unchanged.
    """
    async with get_session() as session:
        match = await session.scalar(select(User).where(User.login == login))
    return match is None
async def get_catalog() -> Optional[List[Service]]:
    """Return all active services.

    Returns:
        The ``Service`` rows with ``is_active`` true. These are ORM objects,
        not service names. Returns ``None`` when there are no active services.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            query = select(Service).where(Service.is_active == True)  # noqa: E712
            result = await session.execute(query)
            services = result.scalars().all()
            return services if services else None
        except Exception as e:
            raise DatabaseError(f"Error getting catalog: {str(e)}") from e
async def get_service_info(service_idx: str) -> Optional[Service]:
    """Look up one active service by its ID.

    Args:
        service_idx: value compared against ``Service.id``. It is annotated
            ``str``; presumably the caller passes an ID taken from UI data,
            so verify it matches the column type.

    Returns:
        The matching active ``Service``, or ``None`` if absent or inactive.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            stmt = (
                select(Service)
                .where(Service.id == service_idx)
                .where(Service.is_active == True)  # noqa: E712
            )
            found = await session.scalar(stmt)
            if not found:
                return None
            return found
        except Exception as exc:
            raise DatabaseError(f"Error getting service info: {str(exc)}")
async def add_service(name: str, desc: str, price: int, active: bool = True) -> None:
    """Add a new service to the catalog.

    Args:
        name: stored as ``service_name``.
        desc: stored as ``service_description``.
        price: stored as ``service_price``.
        active: stored as ``is_active``. Previously this was written
            ``active=bool``, which made the *type* ``bool`` the default value
            instead of annotating the parameter. It now defaults to ``True``.

    Raises:
        DatabaseError: wrapping any exception raised while building/adding the row.
    """
    async with get_session() as session:
        try:
            service = Service(
                service_name=name,
                service_description=desc,
                service_price=price,
                is_active=active
            )
            session.add(service)
        except Exception as e:
            raise DatabaseError(f"Error adding service: {str(e)}") from e
async def edit_service(serv_id: int, param: str, change: Any, active: bool) -> None:
    """Set one editable column of the service ``serv_id`` to ``change``.

    Args:
        serv_id: primary key of the service row.
        param: ``'name'``, ``'desc'`` or ``'price'``. These map to
            ``service_name``, ``service_description`` and ``service_price``.
        change: new value for that column.
        active: accepted for interface compatibility; this function does not use it.

    Raises:
        ValueError: if ``param`` is not a recognised key (checked before any DB access).
        DatabaseError: wrapping any exception raised by the UPDATE.
    """
    column_for = {
        'name': 'service_name',
        'desc': 'service_description',
        'price': 'service_price'
    }
    column = column_for.get(param)
    if column is None:
        raise ValueError(f"Invalid parameter: {param}")
    async with get_session() as session:
        try:
            stmt = update(Service).where(Service.id == serv_id).values({column: change})
            await session.execute(stmt)
        except Exception as exc:
            raise DatabaseError(f"Error editing service: {str(exc)}")
async def delete_service(serv_id: int) -> bool:
    """Remove a service, or deactivate it if any feedback references it.

    A service that has at least one ``Feedback`` row is kept, with
    ``is_active`` set to False. Otherwise the row is deleted.

    Returns:
        False if no service has ``serv_id``; True otherwise.

    Raises:
        DatabaseError: wrapping any exception raised inside the session.
    """
    async with get_session() as session:
        try:
            service = await session.scalar(select(Service).where(Service.id == serv_id))
            if not service:
                return False
            referenced = await session.scalar(
                select(Feedback).where(Feedback.service_id == service.id)
            )
            if not referenced:
                await session.delete(service)
                return True
            # Feedback points at this service, so soft-delete instead of removing.
            await session.execute(
                update(Service).where(Service.id == serv_id).values(is_active=False)
            )
            return True
        except Exception as exc:
            raise DatabaseError(f"Error deleting service: {str(exc)}")
async def get_leadmagnets() -> Optional[List[str]]:
    """Return the ``trigger`` values of all active lead magnets.

    Returns:
        The triggers, or ``None`` when no lead magnet is active.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            rows = await session.execute(
                select(LeadMagnet.trigger).where(LeadMagnet.is_active == True)  # noqa: E712
            )
            triggers = rows.scalars().all()
            if not triggers:
                return None
            return triggers
        except Exception as exc:
            raise DatabaseError(f"Error getting lead magnets: {str(exc)}")
async def get_leadmagnet_info(trigger: str) -> Optional[LeadMagnet]:
    """Return the active lead magnet whose ``trigger`` equals ``trigger``, or None.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            stmt = (
                select(LeadMagnet)
                .where(LeadMagnet.trigger == trigger)
                .where(LeadMagnet.is_active == True)  # noqa: E712
            )
            magnet = await session.scalar(stmt)
            if not magnet:
                return None
            return magnet
        except Exception as exc:
            raise DatabaseError(f"Error getting lead magnet info: {str(exc)}")
async def add_leadmagnet(trigger: str, content: str, active: bool) -> None:
    """Add a ``LeadMagnet`` row with the given trigger, content and active flag.

    Raises:
        DatabaseError: wrapping any exception raised while building/adding the row.
    """
    async with get_session() as session:
        try:
            session.add(
                LeadMagnet(trigger=trigger, content=content, is_active=active)
            )
        except Exception as exc:
            raise DatabaseError(f"Error adding lead magnet: {str(exc)}")
async def edit_leadmagnet(name: str, param: str, change: Any) -> None:
    """Update one field of the lead magnet whose trigger equals ``name``.

    Does nothing if no lead magnet has that trigger.

    Args:
        name: current trigger of the lead magnet to edit.
        param: ``'trigger'``, ``'content'`` or ``'status'``. These map to the
            ``trigger``, ``content`` and ``is_active`` columns.
        change: new value for the mapped column.

    Raises:
        ValueError: if ``param`` is not one of the supported keys. This
            matches ``edit_service``; previously it surfaced as a bare KeyError.
        DatabaseError: wrapping any exception raised while querying or updating.
    """
    column_map = {'trigger': 'trigger',
                  'content': 'content',
                  'status': 'is_active'}
    if param not in column_map:
        raise ValueError(f"Invalid parameter: {param}")
    async with get_session() as session:
        try:
            lead = await session.scalar(
                select(LeadMagnet).where(LeadMagnet.trigger == name)
            )
            if not lead:
                return
            update_query = (
                update(LeadMagnet)
                .where(LeadMagnet.trigger == name)
                .values({column_map[param]: change})
                .execution_options(synchronize_session="fetch")
            )
            await session.execute(update_query)
            await session.commit()
        except Exception as e:
            raise DatabaseError(f"Error editing lead magnet: {str(e)}") from e
async def delete_leadmagnet(name: str) -> None:
    """Issue a bulk DELETE for every lead magnet whose trigger equals ``name``.

    Raises:
        DatabaseError: wrapping any exception raised by the DELETE.
    """
    async with get_session() as session:
        try:
            await session.execute(
                delete(LeadMagnet).where(LeadMagnet.trigger == name)
            )
        except Exception as exc:
            raise DatabaseError(f"Error deleting lead magnet: {str(exc)}")
async def get_tests() -> Optional[List[Test]]:
    """Return all active tests.

    Returns:
        The ``Test`` rows with ``is_active`` true. These are ORM objects, not
        names. Returns ``None`` when no test is active.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            query = select(Test).where(Test.is_active == True)  # noqa: E712
            result = await session.execute(query)
            tests = result.scalars().all()
            return tests if tests else None
        except Exception as e:
            raise DatabaseError(f"Error getting tests: {str(e)}") from e
async def add_test_wo_points(
    name: str,
    test_type: str,
    desc: str,
    status: bool,
    completion_message: str
) -> None:
    """Add a ``Test`` row intended for the no-points flow.

    ``completion_message`` is what ``record_answer`` reports when a non-points
    test is finished.

    Raises:
        DatabaseError: wrapping any exception raised while building/adding the row.
    """
    async with get_session() as session:
        try:
            new_test = Test(
                test_name=name,
                test_type=test_type,
                test_description=desc,
                is_active=status,
                completion_message=completion_message
            )
            session.add(new_test)
        except Exception as exc:
            raise DatabaseError(f"Error adding test: {str(exc)}")
async def add_question_vars_wo_points(test_name: str, text: str) -> None:
    """Attach a question to the test named ``test_name`` (no-points variant).

    ``text`` must contain exactly one ``'***'`` separator:
    ``"<question>***<variants>"``. Both halves are stripped. ``question_points``
    is stored as the literal string ``"{}"``.

    Raises:
        DatabaseError: if the test is missing, the separator count is wrong,
            or on any DB error. The ValidationErrors raised inside are caught
            by the generic handler and re-wrapped, so callers only see
            DatabaseError.
    """
    async with get_session() as session:
        try:
            target = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not target:
                raise ValidationError(f"Test {test_name} not found")
            segments = text.split('***')
            if len(segments) != 2:
                raise ValidationError("Invalid question format")
            question_text, variants_text = segments
            session.add(
                TestQuestion(
                    test_id=target.id,
                    question_content=question_text.strip(),
                    question_variants=variants_text.strip(),
                    question_points="{}"  # empty JSON: questions without points
                )
            )
        except Exception as exc:
            raise DatabaseError(f"Error adding question: {str(exc)}")
async def add_test_result_w_points(test_name: str, text: str) -> None:
    """Add a point-range result to the test named ``test_name``.

    ``text`` is ``"<min>-<max>"`` on the first line, followed by the result
    text. Only the first newline splits the two parts, so the result text
    may itself span several lines. Previously any extra line break caused
    the input to be rejected.

    Raises:
        DatabaseError: if the test is missing, ``text`` has no newline, the
            range is not two ``-``-separated integers, or on any DB error.
            The ValidationErrors raised inside are re-wrapped by the generic
            handler.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not test:
                raise ValidationError(f"Test {test_name} not found")
            parts = text.split('\n', 1)
            if len(parts) != 2:
                raise ValidationError("Invalid result format")
            point_range = parts[0].strip()
            min_points, max_points = map(int, point_range.split('-'))
            result = TestResult(
                test_id=test.id,
                min_points=min_points,
                max_points=max_points,
                result_text=parts[1].strip()
            )
            session.add(result)
        except Exception as e:
            raise DatabaseError(f"Error adding test result: {str(e)}") from e
async def delete_test(t_id: int) -> None:
    """Delete the test with ID ``t_id`` via the ORM; do nothing if it is absent.

    Raises:
        DatabaseError: wrapping any exception raised inside the session.
    """
    async with get_session() as session:
        try:
            doomed = await session.scalar(select(Test).where(Test.id == t_id))
            if not doomed:
                return
            # Removal of related questions/results relies on an ORM cascade on
            # the Test model — NOTE(review): confirm the relationship config.
            await session.delete(doomed)
        except Exception as exc:
            raise DatabaseError(f"Error deleting test: {str(exc)}")
async def get_test(t_id: int) -> Optional[Dict[str, Any]]:
    """Load an active test together with its questions and results.

    Returns:
        ``None`` if the test is missing or inactive. Otherwise a dict with keys
        ``id`` (the given ``t_id``), ``test``, ``questions`` and ``results``.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test)
                .where(Test.id == t_id)
                .where(Test.is_active == True)  # noqa: E712
            )
            if not test:
                return None
            question_rows = await session.execute(
                select(TestQuestion).where(TestQuestion.test_id == test.id)
            )
            questions = question_rows.scalars().all()
            result_rows = await session.execute(
                select(TestResult).where(TestResult.test_id == test.id)
            )
            results = result_rows.scalars().all()
            return {"id": t_id, "test": test, "questions": questions, "results": results}
        except Exception as exc:
            raise DatabaseError(f"Error getting test: {str(exc)}")
async def change_test_status(t_id: int, status: "bool | str") -> None:
    """Set ``Test.is_active`` for the test ``t_id``.

    Args:
        t_id: ID of the test.
        status: either a real ``bool``, used as-is, or a UI label string, where
            only ``"Да"`` ("Yes") means active. Previously a ``bool`` ``True``
            was compared to ``"Да"`` and therefore deactivated the test.

    Raises:
        DatabaseError: wrapping any exception raised by the UPDATE.
    """
    if isinstance(status, bool):
        is_active = status
    else:
        is_active = status == "Да"
    async with get_session() as session:
        try:
            query = update(Test).where(
                Test.id == t_id
            ).values(is_active=is_active)
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error changing test status: {str(e)}") from e
async def add_feedback(
    user_id: int,
    service_name: str,
    rating: int,
    review: str
) -> None:
    """Add an ``is_new`` feedback row for the service named ``service_name``.

    Raises:
        DatabaseError: if no service has that name, or on any DB error. The
            inner ValidationError is re-wrapped by the generic handler.
    """
    async with get_session() as session:
        try:
            target = await session.scalar(
                select(Service).where(Service.service_name == service_name)
            )
            if not target:
                raise ValidationError(f"Service {service_name} not found")
            session.add(
                Feedback(
                    user_id=user_id,
                    service_id=target.id,
                    rating=rating,
                    review=review,
                    is_new=True
                )
            )
        except Exception as exc:
            raise DatabaseError(f"Error adding feedback: {str(exc)}")
async def get_new_feedback() -> Optional[List[Feedback]]:
    """Return all feedback rows still flagged ``is_new``, or ``None`` if there are none.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            rows = await session.execute(
                select(Feedback).where(Feedback.is_new == True)  # noqa: E712
            )
            unread = rows.scalars().all()
            if not unread:
                return None
            return unread
        except Exception as exc:
            raise DatabaseError(f"Error getting new feedback: {str(exc)}")
async def mark_feedback_as_read(feedback_id: int) -> None:
    """Clear the ``is_new`` flag on the feedback row ``feedback_id``.

    Raises:
        DatabaseError: wrapping any exception raised by the UPDATE.
    """
    async with get_session() as session:
        try:
            stmt = update(Feedback).where(Feedback.id == feedback_id).values(is_new=False)
            await session.execute(stmt)
        except Exception as exc:
            raise DatabaseError(f"Error marking feedback as read: {str(exc)}")
async def get_user_info(tg_id: int) -> Optional[User]:
    """Return the user whose ``tg_id`` matches, or ``None``.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            return await session.scalar(select(User).where(User.tg_id == tg_id))
        except Exception as exc:
            raise DatabaseError(f"Error getting user info: {str(exc)}")
async def start_test_attempt(user_id: int, test_id: str) -> Optional[Dict[str, Any]]:
    """Create a ``TestAttempt`` on an active test and return its first question.

    Args:
        user_id: Telegram ID of the user. It is matched against ``User.tg_id``
            and stored as ``TestAttempt.user_id``.
        test_id: value compared against ``Test.id``.

    Returns:
        ``None`` if the test is missing/inactive or the user is unknown.
        Otherwise a dict with:

        - ``attempt_id``
        - ``question``: the first ``TestQuestion`` by id, or None if the test
          has no questions
        - ``total_questions``: the count of this test's questions

    Raises:
        DatabaseError: wrapping any exception raised inside the session.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(
                    Test.id == test_id,
                    Test.is_active == True  # noqa: E712
                )
            )
            if not test:
                return None
            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if not user:
                return None
            attempt = TestAttempt(
                user_id=user_id,
                test_id=test.id
            )
            session.add(attempt)
            await session.flush()  # assigns attempt.id
            # Read everything we return *before* commit. With expire_on_commit
            # (the SQLAlchemy default), touching attempt.id after commit
            # triggers a lazy refresh, i.e. implicit IO, on an AsyncSession.
            attempt_id = attempt.id
            first_question = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .order_by(TestQuestion.id)
            )
            total_questions = await session.scalar(
                select(func.count()).select_from(TestQuestion)
                .where(TestQuestion.test_id == test.id)
            )
            await session.commit()
            # NOTE(review): first_question itself may be expired after commit
            # depending on the session factory's expire_on_commit — confirm.
            return {
                "attempt_id": attempt_id,
                "question": first_question,
                "total_questions": total_questions
            }
        except Exception as e:
            raise DatabaseError(f"Error starting test: {str(e)}") from e
async def record_answer(attempt_id: int, question_id: int, answer: str) -> Optional[Dict[str, Any]]:
    """Save an answer for an attempt and return the next question or the final result.

    Scoring: for tests of type ``"С баллами"`` (with points), each non-empty
    line of ``question.question_variants`` is parsed as ``"<text>...<points>"``.
    The part of ``answer`` before its first ``"..."`` is compared to
    ``<text>``, and the first matching line whose ``<points>`` is an integer
    gives the score. All other test types score 0 per answer.

    Returns:
        - ``{"next_question": TestQuestion}`` while a question with a higher
          id remains in the test.
        - Otherwise, for point tests,
          ``{"completed": True, "total_points": int, "result": str | None}``,
          where result comes from the TestResult whose range contains the total.
        - For other test types, ``{"completed": True, "result": completion_message}``.

    Raises:
        DatabaseError: if the attempt, question or test is not found, or on
            any other error. The session is rolled back first.
    """
    async with get_session() as session:
        try:
            attempt = await session.scalar(
                select(TestAttempt).where(TestAttempt.id == attempt_id)
            )
            if not attempt:
                raise DatabaseError("Test attempt not found")
            question = await session.scalar(
                select(TestQuestion).where(TestQuestion.id == question_id)
            )
            if not question:
                raise DatabaseError("Question not found")
            test = await session.scalar(
                select(Test).where(Test.id == question.test_id)
            )
            if not test:
                # Without this guard the failure surfaced as an AttributeError
                # on ``test.test_type``.
                raise DatabaseError("Test not found")

            is_points_test = test.test_type == "С баллами"

            points = 0
            if is_points_test:
                # The chosen text does not depend on the variant; compute it once.
                chosen_text = answer.split("...")[0].strip()
                for raw_line in question.question_variants.split('\n'):
                    line = raw_line.strip()
                    if not line:
                        continue
                    pieces = line.split('...')
                    if len(pieces) != 2:
                        continue
                    variant_text, points_str = pieces
                    if variant_text.strip() != chosen_text:
                        continue
                    try:
                        points = int(points_str.strip())
                    except ValueError:
                        # Malformed points value: keep scanning later variants.
                        continue
                    break

            session.add(TestAnswer(
                attempt_id=attempt_id,
                question_id=question_id,
                answer_given=answer,
                points_earned=points
            ))
            await session.flush()

            next_question = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .where(TestQuestion.id > question_id)
                .order_by(TestQuestion.id)
            )
            if next_question:
                await session.commit()
                # NOTE(review): with expire_on_commit=True the returned object's
                # attributes are expired here — confirm the session factory config.
                return {"next_question": next_question}

            # No further question: the attempt is complete. Total its score.
            answers = await session.scalars(
                select(TestAnswer)
                .where(TestAnswer.attempt_id == attempt_id)
            )
            total_score = sum(ans.points_earned for ans in answers.all())
            attempt.score = total_score

            if is_points_test:
                result = await session.scalar(
                    select(TestResult)
                    .where(TestResult.test_id == test.id)
                    .where(TestResult.min_points <= total_score)
                    .where(TestResult.max_points >= total_score)
                )
                result_text = result.result_text if result else None
                attempt.result = result_text
                result_dict = {
                    "completed": True,
                    "total_points": total_score,
                    "result": result_text
                }
            else:
                result_dict = {
                    "completed": True,
                    "result": test.completion_message
                }
                attempt.result = test.completion_message

            await session.commit()
            return result_dict
        except Exception as e:
            await session.rollback()
            raise DatabaseError(f"Error recording answer: {str(e)}") from e
async def check_user_registered(user_id: int) -> bool:
    """Return True if the user with Telegram ID ``user_id`` has a non-empty name.

    An unknown ``user_id`` counts as "not registered" and returns False.
    Previously it raised AttributeError on ``None.name``, which was then
    wrapped in DatabaseError.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User)
                .where(User.tg_id == user_id)
            )
            return user is not None and bool(user.name)
        except Exception as e:
            raise DatabaseError(f"Error checking user registration: {str(e)}") from e
async def get_user_test_results(user_login: str) -> List[Dict[str, Any]]:
    """Return every test attempt of the user whose login is ``user_login``.

    Attempts are joined to their ``Test`` and ordered by ``completed_at``,
    newest first. Each entry has the keys ``test_name``, ``completed_at``,
    ``score`` and ``result``.

    Returns:
        The list of attempt dicts (possibly empty). When no user has that
        login, returns the string ``"Пользователь не найден"`` ("User not
        found") instead — NOTE(review): this does not match the ``List``
        annotation.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(select(User).where(User.login == user_login))
            if not user:
                return "Пользователь не найден"
            rows = await session.execute(
                select(TestAttempt, Test)
                .join(Test)
                .where(TestAttempt.user_id == user.tg_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            # A SQLAlchemy Result object is always truthy, so no emptiness
            # check is needed; an empty result simply yields [].
            history = []
            for attempt, test in rows:
                history.append({
                    "test_name": test.test_name,
                    "completed_at": attempt.completed_at,
                    "score": attempt.score,
                    "result": attempt.result
                })
            return history
        except Exception as exc:
            raise DatabaseError(f"Error getting test results: {str(exc)}")
async def get_user_registration_info(user_id: int) -> str:
    """Build a Russian-language summary of the user's registration data.

    Missing name/login/contact values are shown as ``'Не указано'``
    ("Not specified"). The subscription shows ``'Активна'`` (active) only
    when ``subscription_status == 'active'``.

    Returns:
        The formatted summary, or a "not found" message if no user has that
        Telegram ID.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(select(User).where(User.tg_id == user_id))
            if not user:
                return "Информация о пользователе не найдена"
            subscription_label = 'Активна' if user.subscription_status == 'active' else 'Неактивна'
            lines = [
                "📋 Ваша регистрационная информация:\n",
                f"ID: {user.tg_id}\n",
                f"Имя: {user.name or 'Не указано'}\n",
                f"Логин: {user.login or 'Не указано'}\n",
                f"Контакт: {user.contact or 'Не указано'}\n",
                f"Статус подписки: {subscription_label}",
            ]
            return "".join(lines)
        except Exception as exc:
            raise DatabaseError(f"Error getting user info: {str(exc)}")
async def get_all_test_answers() -> List[Dict[str, Any]]:
    """Fetch every recorded test answer with its user, test and question.

    ``TestAttempt.user_id`` holds the user's Telegram ID. ``start_test_attempt``
    stores it, and ``get_user_test_results`` filters by ``user.tg_id``. The
    join therefore matches ``User.tg_id``, not ``User.id``.

    Returns:
        One dict per answer with ``answer_id``, ``user_name``, ``test_name``,
        ``question``, ``answer_given``, ``points_earned`` and ``completed_at``.
        ``completed_at`` is formatted ``"%d.%m.%Y %H:%M"``, or None when the
        attempt has no completion time. Rows are ordered by attempt
        completion, newest first.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            result = await session.execute(
                select(TestAnswer, TestAttempt, User, Test, TestQuestion)
                .join(TestAttempt, TestAttempt.id == TestAnswer.attempt_id)
                .join(User, User.tg_id == TestAttempt.user_id)
                .join(Test, Test.id == TestAttempt.test_id)
                .join(TestQuestion, TestQuestion.id == TestAnswer.question_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            answers = result.fetchall()
            return [
                {
                    "answer_id": answer.id,
                    "user_name": user.name,
                    "test_name": test.test_name,
                    "question": question.question_content,
                    "answer_given": answer.answer_given,
                    "points_earned": answer.points_earned,
                    "completed_at": (
                        attempt.completed_at.strftime("%d.%m.%Y %H:%M")
                        if attempt.completed_at is not None
                        else None
                    )
                }
                for answer, attempt, user, test, question in answers
            ]
        except Exception as e:
            raise DatabaseError(f"Error fetching test answers: {str(e)}") from e
async def own_login_check(user_id: int, login: str) -> bool:
    """Return True if the user with Telegram ID ``user_id`` has exactly this ``login``.

    An unknown user yields False.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            owner = await session.scalar(select(User).where(User.tg_id == user_id))
            return bool(owner) and owner.login == login
        except Exception as exc:
            raise DatabaseError(f"Error checking login: {str(exc)}")
async def update_user_data(user_id: int, param: str, change: Any) -> None:
    """Update one profile field of the user with Telegram ID ``user_id``.

    Does nothing if no user has that Telegram ID.

    Args:
        user_id: Telegram ID, matched against ``User.tg_id``.
        param: Russian UI label of the field to change:

            - ``'Имя'`` maps to ``name``
            - ``'Логин'`` maps to ``login``
            - ``'Контакт'`` maps to ``contact``
            - ``'Статус подписки на рассылку'`` maps to ``subscription_status``
        change: new value, stored as-is in the mapped column.

    Raises:
        ValueError: if ``param`` is not a supported label. This matches
            ``edit_service``; previously it surfaced as a bare KeyError.
        DatabaseError: wrapping any exception raised while querying or updating.
    """
    replace_dict = {'Имя': 'name',
                    'Логин': 'login',
                    'Контакт': 'contact',
                    'Статус подписки на рассылку': 'subscription_status'}
    if param not in replace_dict:
        raise ValueError(f"Invalid parameter: {param}")
    async with get_session() as session:
        try:
            user = await session.scalar(select(User).where(User.tg_id == user_id))
            if not user:
                return
            update_query = (
                update(User)
                .where(User.tg_id == user_id)
                .values({replace_dict[param]: change})
                .execution_options(synchronize_session="fetch")
            )
            await session.execute(update_query)
            await session.commit()
        except Exception as e:
            raise DatabaseError(f"Error updating user data: {str(e)}") from e
async def get_broadcast_users() -> List[int]:
    """Return the Telegram IDs of users whose ``subscription_status`` is ``'active'``.

    Raises:
        DatabaseError: wrapping any exception raised while querying.
    """
    async with get_session() as session:
        try:
            subscribed = select(User.tg_id).where(User.subscription_status == 'active')
            ids = await session.scalars(subscribed)
            return ids.all()
        except Exception as exc:
            raise DatabaseError(f"Error fetching broadcast users: {str(exc)}")