import asyncio
import os
import traceback
from typing import Annotated, List

from fastapi import APIRouter, Depends, HTTPException

from databases.firebase_db import get_firebase_user_from_token
from databases.supabase_db import (
    create_user_session,
    save_competitor_analysis,
    save_pain_point_analysis,
    update_user_session,
)
from models.competitor_analysis_model import CompetitorAnalysisModel
from models.pain_point_model import PainPointAnalysisModel
from models.reddit_models import RedditPostDataModel
from models.session_model import InputInfoModel
from reddit.load_env import reddit_clients
from reddit.reddit_competitor_analysis import getCompetitorAnalysisData
from reddit.reddit_functions import getRedditData_with_timeout
from reddit.reddit_gemini import getKeywords
from reddit.reddit_pain_point_analysis import pain_point_analysis
from reddit.reddit_utils import reddit_services_names
from reddit.scraping import fetch_submission_comments, getPostComments
from utils import time_execution

router = APIRouter(tags=['Reddit'])
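# NOTE: no route decorators appear in this snippet, so the handlers below are not
# actually mounted on `router`. If they are meant to be HTTP endpoints, they would
# have to be registered explicitly; the paths here are illustrative assumptions,
# not taken from the source:
#
#   router.add_api_route("/analyze", analyze, methods=["POST"])
#   router.add_api_route("/reddit-posts", getRedditPostsData, methods=["POST"])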
async def analyze(user_db: Annotated[dict, Depends(get_firebase_user_from_token)], request: InputInfoModel):
    '''
    Expected request body, e.g.:
    {
        "query": "",
        "field_inputs": {
            "Reddit": ["Pain point analysis", "Competitor analysis"],
            "Twitter": ["Competitor analysis"]
        }
    }
    '''
    if not request.query:
        raise HTTPException(status_code=400, detail="User query must not be empty")
    if not request.field_inputs:
        raise HTTPException(status_code=400, detail="field_inputs must not be empty")
    print("user_db", user_db)
    print("request", request)
    user_session = create_user_session(user_id=user_db['id'], input_info=request)
    print("user_session", user_session)
    await analyzeData(inputData=request, user_session=user_session)
    response_data = {
        'user_query': request.query,
        'input_info': request.field_inputs,
        'status': 'success',
        'message': 'Analysis completed successfully.'
    }
    return response_data
async def getCommentsData(file_name: str):
    if not file_name:
        raise HTTPException(status_code=400, detail="file_name must not be empty")
    comments = await getPostComments(file_name=file_name)
    return comments
async def getRedditPostsData(request: RedditPostDataModel):
    """Requires user_query and search_keywords as arguments.
    Steps involved in this API:
    1. get posts data from Reddit
    2. filter the top 18 posts
    3. get comments data
    4. get sentiment data
    """
    try:
        # Extract user_query and search_keywords from the request body
        user_query = request.user_query
        search_keywords = request.search_keywords
        if not user_query:
            raise HTTPException(status_code=400, detail="User query must not be empty")
        if not search_keywords:
            raise HTTPException(status_code=400, detail="Search keywords must not be empty")
        print("user_query", user_query, "search_keywords", search_keywords)
        result = await getRedditData_with_timeout(user_query=user_query, search_keywords=search_keywords)
        print('getRedditPostsData: ', result)
        return result
    except HTTPException:
        # propagate validation errors instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run getRedditPostsData: {e}")
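# Hypothetical invocation, assuming RedditPostDataModel exposes exactly the two
# fields read above (user_query, search_keywords):
#
#   result = await getRedditPostsData(RedditPostDataModel(
#       user_query="note taking app",
#       search_keywords=["best note taking app", "note app alternatives"],
#   ))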
async def getCommentDataTest(postUrl: str):
    try:
        if not postUrl:
            raise HTTPException(status_code=400, detail="postUrl must not be empty")
        print("postUrl", postUrl)
        result = await fetch_submission_comments(url=postUrl, reddit=reddit_clients[0], is_for_competitor_analysis=False)
        return result if result is not None else {'status': 'fail'}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run getCommentDataTest: {e}")
# delete file api
async def deleteFile(fileName: str):
    try:
        if not fileName:
            raise HTTPException(status_code=400, detail="File name must not be empty")
        print("fileName", fileName)
        if os.path.exists(fileName):
            os.remove(fileName)
            return {'status': 'success', 'message': 'File deleted successfully.'}
        # report a missing file explicitly instead of returning None
        raise HTTPException(status_code=404, detail="File not found")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run deleteFile: {e}")
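# NOTE: deleteFile removes whatever path it is handed. If it is ever exposed as a
# route, the path should be validated against the scraper's output directory
# before calling os.remove.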
# pain point analysis api which takes user_query and fileName as arguments
def getPainPointAnalysis(user_query: str, fileName: str):
    try:
        if not user_query:
            raise HTTPException(status_code=400, detail="User query must not be empty")
        if not fileName:
            raise HTTPException(status_code=400, detail="fileName must not be empty")
        print("user_query", user_query, "fileName", fileName)
        result = pain_point_analysis(user_query=user_query, fileName=fileName)
        # pain_point_analysis returns a tuple; index 0 holds the analysis payload
        # and index 2 the elapsed time (see its use in getServices below)
        return {
            'result': result[0],
            'e_time': result[2]
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run getPainPointAnalysis: {e}")
# competitor analysis api which takes user_query and fileName as arguments
async def getCompetitorAnalysis(user_query: str, fileName: str, isSolo: bool = True):
    try:
        if not user_query:
            raise HTTPException(status_code=400, detail="User query must not be empty")
        if not fileName:
            raise HTTPException(status_code=400, detail="fileName must not be empty")
        print("user_query", user_query, "isSolo", isSolo, "fileName", fileName)
        result = await getCompetitorAnalysisData(user_query=user_query, fileName=fileName)
        return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run getCompetitorAnalysis: {e}")
# main method for user end
async def analyzeData(inputData: InputInfoModel, user_session: dict):
    try:
        # Pipeline: 1) derive search keywords from the raw query, 2) scrape Reddit
        # data to a local file, 3) run the requested analyses, 4) persist the
        # combined results on the user session
        keywords = getKeywords(user_query=inputData.query)
        reddit_data_result = await getRedditData_with_timeout(user_query=keywords['query'], search_keywords=keywords['top_3_combinations'])
        services_result, session_info_result = await getServices(
            session_id=user_session['id'],
            field_inputs=inputData.field_inputs,
            user_query=keywords['query'],
            fileName=reddit_data_result['fileName'],
            uniqueFileId=reddit_data_result['fileUniqueId']
        )
        process_info = {
            'keywords': keywords,
            'reddit_data': reddit_data_result,
            'services_result': services_result
        }
        update_user_session(user_session=user_session, session_info=session_info_result, process_info=process_info)
    except Exception as e:
        print("Failed to run analyzeData ", e)
        raise HTTPException(status_code=500, detail=f"Failed to run analyzeData: {e}")
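# For reference, analyzeData relies on these shapes (inferred from the key
# lookups above, not from the callees' definitions):
#   getKeywords(...)                -> {'query': str, 'top_3_combinations': list, ...}
#   getRedditData_with_timeout(...) -> {'fileName': str, 'fileUniqueId': str, ...}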
async def getServices(session_id: int, field_inputs: dict, user_query=None, fileName=None, uniqueFileId=None):
    final_result = {}
    session_info_result = {}
    if "Reddit" in field_inputs:
        analysis_list = field_inputs['Reddit']
        session_info_result['Reddit'] = []
        final_result['Reddit'] = []

        async def run_pain_point_analysis():
            # pain_point_analysis is synchronous; run it in a worker thread so it
            # does not block the event loop while gathered with other tasks
            pain_point_analysis_result = await asyncio.to_thread(
                pain_point_analysis,
                user_query=user_query, fileName=fileName, uniqueFileId=uniqueFileId
            )
            print('pain_point_analysis_result', pain_point_analysis_result)
            final_result['Reddit'].append({
                'pain_point_analysis': pain_point_analysis_result[2]
            })
            print('pain_point_analysis_result[0]', pain_point_analysis_result[0])
            # persist the result unless the payload carries a "details" key
            # (treated here as an error marker)
            if "details" not in pain_point_analysis_result[0]:
                p_session = save_pain_point_analysis(data=PainPointAnalysisModel(
                    result=pain_point_analysis_result[0],
                    platform="Reddit",
                    query=user_query,
                    session_id=session_id
                ))
                session_info_result['Reddit'].append({'Pain point analysis': p_session['id']})

        async def run_competitor_analysis():
            competitor_analysis_result = await getCompetitorAnalysisData(
                user_query=user_query, fileName=fileName
            )
            print("competitor_analysis_result", competitor_analysis_result)
            final_result['Reddit'].append({
                'competitor_analysis': {
                    "competitors_data": len(competitor_analysis_result['competitors_data']),
                    'e_time': competitor_analysis_result['e_time']
                }
            })
            c_session = save_competitor_analysis(data=CompetitorAnalysisModel(
                result=competitor_analysis_result['competitors_data'] if isinstance(competitor_analysis_result['competitors_data'], list) else [competitor_analysis_result['competitors_data']],
                platform="Reddit",
                query=user_query,
                session_id=session_id,
                all_competitors=competitor_analysis_result['all_competitor_data']
            ))
            session_info_result['Reddit'].append({'Competitor analysis': c_session['id']})

        # Queue only the analyses the caller asked for
        tasks = []
        if reddit_services_names[0] in analysis_list:
            tasks.append(run_pain_point_analysis())
        if reddit_services_names[1] in analysis_list:
            tasks.append(run_competitor_analysis())
        # Run the selected analyses concurrently
        await asyncio.gather(*tasks)
    # delete the scraped data file once all analyses are done
    if fileName and os.path.exists(fileName):
        os.remove(fileName)
    return final_result, session_info_result
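# Shape of the values returned by getServices (illustrative, based on the code above):
#   final_result        -> {'Reddit': [{'pain_point_analysis': ...},
#                                      {'competitor_analysis': {'competitors_data': int, 'e_time': ...}}]}
#   session_info_result -> {'Reddit': [{'Pain point analysis': <row id>},
#                                      {'Competitor analysis': <row id>}]}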