Spaces:
Runtime error
Runtime error
| """FastAPI application for character attribute extraction with async processing.""" | |
| import asyncio | |
| import uuid | |
| from typing import List, Optional, Dict, Any | |
| from pathlib import Path | |
| import logging | |
| from datetime import datetime | |
| import json | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks, Depends | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import io | |
| from character_pipeline import create_pipeline | |
| from .pipeline.base import CharacterAttributes | |
| from .pipeline.input_loader import DatasetItem | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # FastAPI app initialization | |
| app = FastAPI( | |
| title="Character Attribute Extraction API", | |
| description="Production-ready API for extracting character attributes from images at scale", | |
| version="1.0.0" | |
| ) | |
| # CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Global pipeline instance | |
| pipeline = None | |
| # Job storage (in production, use Redis or database) | |
| jobs = {} | |
| class JobStatus(BaseModel): | |
| job_id: str | |
| status: str # pending, processing, completed, failed | |
| progress: float | |
| result: Optional[Dict[str, Any]] = None | |
| error: Optional[str] = None | |
| created_at: datetime | |
| updated_at: datetime | |
| class BatchProcessRequest(BaseModel): | |
| image_urls: Optional[List[str]] = None | |
| dataset_path: Optional[str] = None | |
| batch_size: int = 32 | |
| use_hf_datasets: bool = True | |
| num_workers: int = 4 | |
| class SingleProcessRequest(BaseModel): | |
| image_url: Optional[str] = None | |
| tags: Optional[str] = None | |
| async def startup_event(): | |
| """Initialize the pipeline on startup.""" | |
| global pipeline | |
| logger.info("Initializing character extraction pipeline...") | |
| pipeline = create_pipeline() | |
| logger.info("Pipeline initialized successfully") | |
| async def root(): | |
| """Root endpoint with API information.""" | |
| return { | |
| "message": "Character Attribute Extraction API", | |
| "version": "1.0.0", | |
| "endpoints": { | |
| "/extract": "Extract attributes from single image", | |
| "/batch": "Start batch processing job", | |
| "/jobs/{job_id}": "Get job status", | |
| "/health": "Health check" | |
| } | |
| } | |
| async def health_check(): | |
| """Health check endpoint.""" | |
| return { | |
| "status": "healthy", | |
| "pipeline_ready": pipeline is not None, | |
| "timestamp": datetime.now().isoformat() | |
| } | |
| async def extract_single(file: UploadFile = File(...)): | |
| """Extract character attributes from a single uploaded image.""" | |
| if not pipeline: | |
| raise HTTPException(status_code=503, detail="Pipeline not initialized") | |
| try: | |
| # Read and validate image | |
| contents = await file.read() | |
| image = Image.open(io.BytesIO(contents)).convert('RGB') | |
| # Create temporary file path | |
| temp_path = f"/tmp/{uuid.uuid4()}.jpg" | |
| image.save(temp_path) | |
| # Extract attributes | |
| attributes = pipeline.extract_from_image(temp_path) | |
| # Clean up | |
| Path(temp_path).unlink(missing_ok=True) | |
| # Extract quality information from metadata | |
| quality_info = {} | |
| if hasattr(attributes, 'metadata') and attributes.metadata: | |
| quality_data = attributes.metadata.get('quality_info', {}) | |
| quality_info = { | |
| "is_good_quality": quality_data.get('is_good_quality', True), | |
| "quality_score": quality_data.get('quality_score', 1.0), | |
| "edge_cases": quality_data.get('edge_cases', []), | |
| "recommendation": quality_data.get('recommendation', 'process') | |
| } | |
| else: | |
| quality_info = { | |
| "is_good_quality": True, | |
| "quality_score": 1.0, | |
| "edge_cases": [], | |
| "recommendation": "process" | |
| } | |
| # Convert to dictionary | |
| result = { | |
| "success": True, | |
| "attributes": { | |
| "age": getattr(attributes, 'age', None), | |
| "gender": getattr(attributes, 'gender', None), | |
| "ethnicity": getattr(attributes, 'ethnicity', None), | |
| "hair_style": getattr(attributes, 'hair_style', None), | |
| "hair_color": getattr(attributes, 'hair_color', None), | |
| "hair_length": getattr(attributes, 'hair_length', None), | |
| "eye_color": getattr(attributes, 'eye_color', None), | |
| "body_type": getattr(attributes, 'body_type', None), | |
| "dress": getattr(attributes, 'dress', None) | |
| }, | |
| "confidence": getattr(attributes, 'confidence_score', 0.0), | |
| "processing_time": 0.0, | |
| "quality_info": quality_info | |
| } | |
| return JSONResponse(content=result) | |
| except Exception as e: | |
| logger.error(f"Error processing image: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def start_batch_processing(request: BatchProcessRequest, background_tasks: BackgroundTasks): | |
| """Start a batch processing job.""" | |
| if not pipeline: | |
| raise HTTPException(status_code=503, detail="Pipeline not initialized") | |
| # Generate job ID | |
| job_id = str(uuid.uuid4()) | |
| # Create job status | |
| job_status = JobStatus( | |
| job_id=job_id, | |
| status="pending", | |
| progress=0.0, | |
| created_at=datetime.now(), | |
| updated_at=datetime.now() | |
| ) | |
| jobs[job_id] = job_status | |
| # Start background processing | |
| background_tasks.add_task( | |
| process_batch_async, | |
| job_id, | |
| request | |
| ) | |
| return { | |
| "job_id": job_id, | |
| "status": "pending", | |
| "message": "Batch processing job started" | |
| } | |
| async def get_job_status(job_id: str): | |
| """Get the status of a batch processing job.""" | |
| if job_id not in jobs: | |
| raise HTTPException(status_code=404, detail="Job not found") | |
| job = jobs[job_id] | |
| return { | |
| "job_id": job.job_id, | |
| "status": job.status, | |
| "progress": job.progress, | |
| "result": job.result, | |
| "error": job.error, | |
| "created_at": job.created_at.isoformat(), | |
| "updated_at": job.updated_at.isoformat() | |
| } | |
| async def cancel_job(job_id: str): | |
| """Cancel a batch processing job.""" | |
| if job_id not in jobs: | |
| raise HTTPException(status_code=404, detail="Job not found") | |
| job = jobs[job_id] | |
| if job.status in ["completed", "failed"]: | |
| raise HTTPException(status_code=400, detail="Cannot cancel completed or failed job") | |
| job.status = "cancelled" | |
| job.updated_at = datetime.now() | |
| return {"message": "Job cancelled successfully"} | |
| async def download_results(job_id: str): | |
| """Download batch processing results as JSON file.""" | |
| if job_id not in jobs: | |
| raise HTTPException(status_code=404, detail="Job not found") | |
| job = jobs[job_id] | |
| if job.status != "completed": | |
| raise HTTPException(status_code=400, detail="Job not completed") | |
| if not job.result: | |
| raise HTTPException(status_code=404, detail="No results available") | |
| # Create temporary file | |
| temp_file = f"/tmp/results_{job_id}.json" | |
| with open(temp_file, 'w') as f: | |
| json.dump(job.result, f, indent=2) | |
| return FileResponse( | |
| temp_file, | |
| media_type="application/json", | |
| filename=f"character_attributes_{job_id}.json" | |
| ) | |
| async def process_batch_async(job_id: str, request: BatchProcessRequest): | |
| """Async function to process batch requests.""" | |
| job = jobs[job_id] | |
| try: | |
| job.status = "processing" | |
| job.updated_at = datetime.now() | |
| # Simulate batch processing | |
| if request.dataset_path: | |
| # Process from dataset path | |
| items = pipeline.input_loader.discover_dataset_items() | |
| else: | |
| # Process from URLs (would need implementation) | |
| items = [] | |
| total_items = len(items) | |
| results = [] | |
| if request.use_hf_datasets and total_items > 0: | |
| # Use HuggingFace datasets for efficient processing | |
| def process_batch_hf(batch): | |
| batch_results = [] | |
| for i, item_id in enumerate(batch['item_id']): | |
| # Simulate processing | |
| result = { | |
| 'item_id': item_id, | |
| 'attributes': { | |
| 'age': 'young_adult', | |
| 'gender': 'female', | |
| 'hair_color': 'brown' | |
| }, | |
| 'confidence': 0.85 | |
| } | |
| batch_results.append(result) | |
| # Update progress | |
| current_progress = (len(results) + i + 1) / total_items * 100 | |
| job.progress = min(current_progress, 100.0) | |
| job.updated_at = datetime.now() | |
| return {'results': batch_results} | |
| # Process using HuggingFace datasets.map() | |
| processed_dataset = pipeline.input_loader.process_with_hf_map( | |
| process_batch_hf, | |
| items=items[:100], # Limit for demo | |
| batch_size=request.batch_size, | |
| num_proc=request.num_workers | |
| ) | |
| if processed_dataset: | |
| for item in processed_dataset: | |
| results.extend(item['results']) | |
| else: | |
| # Use PyTorch DataLoader for batch processing | |
| dataloader = pipeline.input_loader.create_dataloader( | |
| items=items[:100], # Limit for demo | |
| batch_size=request.batch_size, | |
| shuffle=False | |
| ) | |
| for batch_idx, batch in enumerate(dataloader): | |
| batch_results = [] | |
| for i, item_id in enumerate(batch['item_ids']): | |
| # Simulate processing | |
| result = { | |
| 'item_id': item_id, | |
| 'attributes': { | |
| 'age': 'young_adult', | |
| 'gender': 'male', | |
| 'hair_color': 'black' | |
| }, | |
| 'confidence': 0.80 | |
| } | |
| batch_results.append(result) | |
| results.extend(batch_results) | |
| # Update progress | |
| job.progress = min((batch_idx + 1) / len(dataloader) * 100, 100.0) | |
| job.updated_at = datetime.now() | |
| # Simulate async processing | |
| await asyncio.sleep(0.1) | |
| # Job completed successfully | |
| job.status = "completed" | |
| job.progress = 100.0 | |
| job.result = { | |
| "total_processed": len(results), | |
| "results": results, | |
| "summary": { | |
| "success_rate": 100.0, | |
| "average_confidence": sum(r['confidence'] for r in results) / len(results) if results else 0 | |
| } | |
| } | |
| job.updated_at = datetime.now() | |
| except Exception as e: | |
| logger.error(f"Batch processing failed for job {job_id}: {e}") | |
| job.status = "failed" | |
| job.error = str(e) | |
| job.updated_at = datetime.now() | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |