Spaces:
Sleeping
Sleeping
"""
FastAPI backend wrapper for HF Spaces.

Exposes the REST API endpoints while the Streamlit UI keeps running alongside.
"""
# Environment variables are loaded before anything else is imported, so that
# modules reading os.environ at import time already see the .env values.
from dotenv import load_dotenv

load_dotenv()

import asyncio
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from typing import Any, Dict, List, Optional

import pandas as pd
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Make the project-local ``src`` directory importable; the ``utils`` package
# imported below lives there.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

from utils.data_processor import DataProcessor
from utils.task_manager import get_task_manager
from utils.rate_limit_middleware import RateLimitMiddleware
from utils.mongodb_service import get_mongodb_service
from utils.redis_service import get_redis_service
from utils.task_queue import get_task_queue
from utils.ip_location_service import get_ip_location_service
from utils.admin_endpoints import router as admin_router
app = FastAPI(title="ABSA ML Backend API", version="1.0.0")

# Middleware order matters: Starlette's ``add_middleware`` inserts each new
# middleware at the outermost position, so the one added LAST runs FIRST.
# Rate limiting is registered first and CORS last so that CORS wraps every
# response, including 429s produced by the rate limiter. The other way round,
# a rate-limited browser request has no CORS headers and surfaces as an opaque
# CORS error instead of a 429. (CORS preflight requests are now answered by
# the CORS layer before they reach the rate limiter.)
app.add_middleware(RateLimitMiddleware, max_requests=100, window_seconds=60)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True
# accepts credentialed requests from any origin -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Admin endpoints are defined in a separate router.
app.include_router(admin_router)

# The DataProcessor is created lazily by get_processor(); the task manager and
# the thread pool that runs blocking processing work are created now.
processor = None
task_manager = get_task_manager()
executor = ThreadPoolExecutor(max_workers=int(os.getenv('MAX_WORKERS', '2')))

# Shared service clients.
mongodb_service = get_mongodb_service()
redis_service = get_redis_service()
ip_location_service = get_ip_location_service()

# Created together with the processor inside get_processor().
task_queue = None
def get_processor():
    """Return the shared DataProcessor, creating it on first use.

    On the first call this builds a DataProcessor, attaches the module-level
    task manager, creates the task queue for that processor and starts the
    queue worker. The module globals ``processor`` and ``task_queue`` are only
    assigned once every step has succeeded. A failure part-way through (for
    example in ``start_worker``) therefore leaves both unset, and the next call
    retries the full initialisation instead of returning a half-configured
    processor whose queue has no running worker.

    Not protected by a lock; in this file it is only called from coroutines,
    which do not interleave inside this synchronous function.

    Returns:
        The initialised DataProcessor instance.
    """
    global processor, task_queue
    if processor is None:
        new_processor = DataProcessor()
        new_processor.set_task_manager(task_manager)
        new_queue = get_task_queue(new_processor)
        new_queue.start_worker()
        # Publish only after everything above succeeded.
        task_queue = new_queue
        processor = new_processor
    return processor
def calculate_timeout(
    num_reviews: int,
    *,
    base_timeout: float = 300.0,
    per_review_time: float = 0.3,
    max_timeout: float = 900.0,
) -> float:
    """
    Calculate a dynamic processing timeout based on dataset size.

    The timeout grows linearly with the number of reviews on top of a fixed
    base allowance and is capped at ``max_timeout``. The defaults reproduce
    the original fixed policy.

    Args:
        num_reviews: Number of reviews to process.
        base_timeout: Fixed allowance in seconds (default 300 s = 5 minutes).
        per_review_time: Extra seconds allowed per review (default 0.3 s).
        max_timeout: Absolute upper bound in seconds (default 900 s = 15 minutes).

    Returns:
        Timeout in seconds:
        ``min(base_timeout + num_reviews * per_review_time, max_timeout)``.
    """
    calculated = base_timeout + num_reviews * per_review_time
    return min(calculated, max_timeout)
class ReviewData(BaseModel):
    """A single review row submitted by the client for processing."""
    id: int
    reviews_title: str
    review: str
    # Typed as str, so the model performs no date parsing or validation.
    date: str
    user_id: str
class ProcessRequest(BaseModel):
    """Request body for the review-processing endpoints."""
    # Review rows to analyse.
    data: List[ReviewData]
    # Free-form options; submit_job reads "device_id" from here.
    # NOTE(review): Optional means a client may send an explicit null, so
    # consumers should not assume this is always a dict.
    options: Optional[Dict[str, Any]] = {}
    # Defaults to "default" when omitted; an explicit null is still accepted.
    user_id: Optional[str] = "default"
class ProcessResponse(BaseModel):
    """Response envelope returned by process_reviews."""
    # "success", "timeout" or "cancelled" as produced by process_reviews.
    status: str
    # Serialized results; populated only on success.
    data: Optional[Dict[str, Any]] = None
    # Human-readable detail, used for the timeout/cancelled outcomes.
    message: Optional[str] = None
async def root():
    """Return a static banner showing that the API is running."""
    banner = dict(message="ABSA ML Backend API", status="running")
    return banner
async def log_session(request: dict):
    """
    Record session metadata (device, IP, user agent) via the IP-location service.

    De-duplication is delegated to ``ip_location_service.log_session_metadata``
    (described as gated by Redis); a falsy result is reported as
    "Already logged within TTL window".

    Expected payload:
    {
        "device_id": "string",
        "user_id": "string (optional)",
        "ip_address": "string",
        "user_agent": "string (optional)"
    }

    Raises:
        HTTPException(400): if ``device_id`` or ``ip_address`` is missing or empty.
    """
    device_id, user_id, ip_address, user_agent = (
        request.get(key) for key in ("device_id", "user_id", "ip_address", "user_agent")
    )
    if not (device_id and ip_address):
        raise HTTPException(status_code=400, detail="device_id and ip_address required")

    # NOTE(review): ip_address comes from the client-supplied body, not from
    # the connection, so it cannot be trusted on its own.
    was_logged = ip_location_service.log_session_metadata(
        device_id=device_id,
        ip_address=ip_address,
        user_id=user_id,
        user_agent=user_agent,
    )
    if was_logged:
        message = "Session metadata logged"
    else:
        message = "Already logged within TTL window"
    return {"status": "success", "logged": was_logged, "message": message}
async def log_event(request: dict):
    """
    Persist a telemetry event through the MongoDB service.

    Expected payload:
    {
        "event_type": "DASHBOARD_VIEW | ANALYSIS_REQUEST | etc.",
        "device_id": "string",
        "user_id": "string (optional)",
        "metadata": {} (optional)
    }

    Returns:
        ``{"status": "success" | "error", "logged": <result of log_event>}``.

    Raises:
        HTTPException(400): if ``event_type`` or ``device_id`` is missing or empty.
    """
    event_type, device_id, user_id, metadata = (
        request.get(key) for key in ("event_type", "device_id", "user_id", "metadata")
    )
    if not (event_type and device_id):
        raise HTTPException(status_code=400, detail="event_type and device_id required")

    stored = mongodb_service.log_event(
        event_type=event_type,
        device_id=device_id,
        user_id=user_id,
        metadata=metadata,
    )
    outcome = "success" if stored else "error"
    return {"status": outcome, "logged": stored}
async def health_check():
    """
    Report availability of the ML components and the backing services.

    Calls get_processor(), so the first health check may trigger the
    (potentially expensive) processor initialisation. Any exception is
    reported in the response body as ``{"status": "error", "message": ...}``
    instead of being raised.
    """
    try:
        proc = get_processor()
        translation_ok = hasattr(proc.translator, 'model')
        absa_ok = hasattr(proc.absa_processor, 'aspect_extractor')
        mongo_ok = bool(mongodb_service._client)
        redis_ok = redis_service.is_connected()
        return {
            "status": "healthy",
            "translation_service": "available" if translation_ok else "unavailable",
            "absa_service": "available" if absa_ok else "unavailable",
            "mongodb": "connected" if mongo_ok else "disconnected",
            "redis": "connected" if redis_ok else "disconnected",
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def submit_job(request: ProcessRequest):
    """
    Submit an ABSA job to the asynchronous task queue.

    The queued payload is ``{"csv_data": [row dicts], "options": options}``.
    ``device_id`` is read from ``request.options`` and defaults to "unknown".

    Returns:
        dict with ``status="queued"``, the new ``job_id`` and a hint message.

    Raises:
        HTTPException(500): if the job could not be built or queued.
    """
    import logging
    logger = logging.getLogger(__name__)

    try:
        # ``options`` is Optional on ProcessRequest, so a client may send null;
        # calling .get on None would turn a valid request into a 500.
        options = request.options or {}
        device_id = options.get("device_id", "unknown")
        user_id = request.user_id

        # Pydantic v2 provides model_dump(); fall back to .dict() on v1.
        data_list = [
            item.model_dump() if hasattr(item, 'model_dump') else item.dict()
            for item in request.data
        ]

        # get_processor() lazily creates the processor and the task queue.
        get_processor()

        job_id = task_queue.submit_job(
            data={"csv_data": data_list, "options": options},
            device_id=device_id,
            user_id=user_id,
        )
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception("Failed to submit job: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "queued",
        "job_id": job_id,
        "message": "Job submitted successfully. Use /job-status/{job_id} to check progress."
    }
async def get_job_status(job_id: str):
    """
    Look up a queued job and, once it is finished, attach its result.

    Raises:
        HTTPException(404): if the queue does not know ``job_id``.
    """
    # Makes sure the processor and task_queue exist before querying.
    get_processor()

    job_state = task_queue.get_job_status(job_id)
    if job_state is None:
        raise HTTPException(status_code=404, detail="Job not found")

    payload = {"job_id": job_id, "status": job_state}
    if job_state != "DONE":
        return payload

    # Only a truthy result is included in the response.
    job_result = task_queue.get_job_result(job_id)
    if job_result:
        payload["result"] = job_result
    return payload
async def process_reviews(request: ProcessRequest):
    """
    Process reviews in a worker thread with cancellation support and a timeout.

    Besides the global middleware limit, this endpoint is rate limited to 10
    requests per minute per ``user_id`` via Redis. The blocking
    ``proc.process_uploaded_data(df, task_id)`` call runs on the module-level
    thread pool and is bounded by ``calculate_timeout(len(df))`` seconds.

    Returns:
        ProcessResponse with status "success" (serialized results in ``data``),
        "timeout" or "cancelled".

    Raises:
        HTTPException(429): per-user rate limit exceeded.
        HTTPException(400): the processor returned a result with an ``error`` key.
        HTTPException(500): any other failure; detail carries ``error`` and
            ``task_id`` (the traceback is logged server-side only).
    """
    import logging
    logger = logging.getLogger(__name__)

    # Normalise once so rate limiting and task tracking use the same identity
    # (user_id is Optional, so an explicit null must not reach create_task).
    user_id = request.user_id or "default"
    is_allowed, current_count = redis_service.check_rate_limit(
        identifier=user_id,
        max_requests=10,
        window_seconds=60,
    )
    if not is_allowed:
        mongodb_service.log_event(
            event_type="RATE_LIMIT_HIT",
            device_id="unknown",
            user_id=user_id,
            metadata={"endpoint": "/process-reviews", "limit": 10},
        )
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Maximum 10 AI processing requests per minute. Current: {current_count}/10. Please wait."
        )

    task_id = task_manager.create_task(user_id=user_id)

    try:
        # Pydantic v2 provides model_dump(); fall back to .dict() on v1.
        data_list = [
            item.model_dump() if hasattr(item, 'model_dump') else item.dict()
            for item in request.data
        ]
        df = pd.DataFrame(data_list)

        timeout = calculate_timeout(len(df))
        task_manager.update_task(task_id, status='processing', message=f'Processing {len(df)} reviews')

        proc = get_processor()
        loop = asyncio.get_running_loop()
        try:
            results = await asyncio.wait_for(
                loop.run_in_executor(executor, proc.process_uploaded_data, df, task_id),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            # wait_for only stops *waiting*: the executor thread keeps running
            # and keeps occupying a pool slot. Request cooperative cancellation
            # (the same mechanism the cancel-task endpoint uses) so the worker
            # can stop. NOTE(review): assumes process_uploaded_data checks the
            # task's cancellation state via task_id -- confirm.
            task_manager.cancel_task(task_id)
            task_manager.complete_task(task_id, success=False, message=f'Processing timeout ({timeout}s exceeded)')
            task_manager.cleanup_task(task_id)
            return ProcessResponse(
                status="timeout",
                message=f"Processing exceeded {timeout:.0f} second limit. Try with fewer reviews or wait and retry."
            )

        is_dict_result = isinstance(results, dict)

        # Cancelled by the user while processing.
        if is_dict_result and results.get('status') == 'cancelled':
            task_manager.mark_cancelled(task_id)
            task_manager.cleanup_task(task_id)
            return ProcessResponse(
                status="cancelled",
                message=results.get('message', 'Task was cancelled by user')
            )

        # Processor-reported error.
        if is_dict_result and 'error' in results:
            task_manager.complete_task(task_id, success=False, message=str(results['error']))
            raise HTTPException(status_code=400, detail=results['error'])

        task_manager.complete_task(task_id, success=True, message='Processing completed successfully')

        serialized_results = serialize_for_api(results)
        serialized_results['task_id'] = task_id
        serialized_results['timeout_used'] = timeout
        return ProcessResponse(status="success", data=serialized_results)

    except HTTPException:
        raise
    except Exception as e:
        task_manager.complete_task(task_id, success=False, message=str(e))
        task_manager.cleanup_task(task_id)
        # Full traceback goes to the server log only; returning it to clients
        # would expose internal file paths and code.
        logger.exception("Processing error for task %s: %s", task_id, e)
        raise HTTPException(status_code=500, detail={"error": str(e), "task_id": task_id}) from e
async def cancel_task(task_id: str):
    """
    Ask the task manager to cancel a running task.

    Never raises for unknown IDs; the outcome is reported through ``status``
    ("success" when cancellation was requested, otherwise "error").
    """
    requested = task_manager.cancel_task(task_id)
    status, message = (
        ("success", f"Task {task_id} cancellation requested")
        if requested
        else ("error", "Task not found or already completed")
    )
    return {"status": status, "message": message, "task_id": task_id}
async def get_task_status(task_id: str):
    """
    Return the tracked status of one task.

    Raises:
        HTTPException(404): if the task manager has no (truthy) entry for it.
    """
    task_info = task_manager.get_task_status(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")
    return {"status": "success", "task": task_info}
async def cancel_user_tasks(user_id: str):
    """Request cancellation of every task owned by ``user_id`` and report how many."""
    cancelled = task_manager.cancel_user_tasks(user_id)
    summary = f"Cancelled {cancelled} tasks for user {user_id}"
    return {"status": "success", "message": summary, "cancelled_count": cancelled}
async def get_user_tasks(user_id: str):
    """List every task the task manager tracks for ``user_id``."""
    user_tasks = task_manager.get_user_tasks(user_id)
    return dict(
        status="success",
        user_id=user_id,
        task_count=len(user_tasks),
        tasks=user_tasks,
    )
async def get_task_stats():
    """Return the task manager's aggregate statistics."""
    return {"status": "success", "stats": task_manager.get_stats()}
async def cleanup_old_tasks(max_age_hours: int = 1):
    """
    Remove tracked tasks older than ``max_age_hours`` (default 1 hour).

    The age is converted to seconds before being passed to the task manager.
    """
    seconds_per_hour = 3600
    task_manager.cleanup_old_tasks(max_age_hours * seconds_per_hour)
    return {
        "status": "success",
        "message": f"Cleaned up tasks older than {max_age_hours} hour(s)",
    }
def serialize_for_api(results: Dict) -> Dict:
    """
    Convert a processing-results dict into JSON-friendly values.

    Conversion rules, applied to each top-level value:
      * key ``aspect_network`` holding a graph-like object (has ``nodes``)
        -> ``networkx.node_link_data(value)``; without ``nodes`` it is
        passed through unchanged.
      * ``pandas.DataFrame`` -> list of row dicts (``to_dict('records')``).
      * ``pandas.Series`` -> ``{index: value}`` mapping (``to_dict()``).
      * any other object exposing ``to_dict`` -> ``value.to_dict('records')``.
      * everything else is passed through unchanged.

    Only the top level is converted; nested values are left as-is.

    Args:
        results: Mapping of result names to values.

    Returns:
        A new dict with the same keys; ``results`` itself is not modified.
    """
    serialized = {}
    for key, value in results.items():
        if key == 'aspect_network':
            if hasattr(value, 'nodes'):
                # Imported lazily so networkx is only required when a graph
                # is actually present.
                import networkx as nx
                serialized[key] = nx.node_link_data(value)
            else:
                serialized[key] = value
        elif isinstance(value, pd.DataFrame):
            serialized[key] = value.to_dict('records')
        elif isinstance(value, pd.Series):
            # Series.to_dict() takes no orient; passing 'records' (as the
            # generic branch below does) raises TypeError.
            serialized[key] = value.to_dict()
        elif hasattr(value, 'to_dict'):
            serialized[key] = value.to_dict('records')
        else:
            serialized[key] = value
    return serialized
def run_streamlit():
    """
    Run the Streamlit UI as a blocking subprocess (intended as a Thread target).

    Uses the first existing file among ``frontend_light.py``,
    ``app_enhanced.py`` and ``app.py`` (relative to the current working
    directory) and launches it with ``streamlit run`` on 0.0.0.0:8502.
    If none exists, it logs that the service runs API-only and returns.

    A missing ``streamlit`` executable or a non-zero exit status is logged
    instead of escaping the thread as an unhandled exception or going unnoticed.
    """
    import logging
    logger = logging.getLogger(__name__)

    candidates = ("frontend_light.py", "app_enhanced.py", "app.py")
    streamlit_app = next((f for f in candidates if os.path.exists(f)), None)
    if streamlit_app is None:
        logger.info("No Streamlit app found. Running FastAPI only (API-only mode)")
        return

    logger.info(f"Starting Streamlit UI with {streamlit_app}")
    try:
        completed = subprocess.run([
            "streamlit", "run", streamlit_app,
            "--server.port=8502",
            "--server.address=0.0.0.0",
        ])
    except FileNotFoundError:
        logger.error("Streamlit executable not found on PATH; UI not started")
        return

    if completed.returncode != 0:
        logger.warning("Streamlit exited with code %s", completed.returncode)
if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Streamlit is optional: start it in a daemon thread only if one of the
    # known UI entry files exists in the working directory.
    streamlit_available = any(
        os.path.exists(f) for f in ["frontend_light.py", "app_enhanced.py", "app.py"]
    )
    if streamlit_available:
        logger.info("Starting Streamlit UI in background...")
        streamlit_thread = Thread(target=run_streamlit, daemon=True)
        streamlit_thread.start()
    else:
        logger.info("Running in API-only mode (no Streamlit UI)")

    # Port defaults to 7860 as before; the PORT environment variable overrides it.
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on http://0.0.0.0:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)