Spaces:
Running
Running
| import os | |
| import time | |
| import asyncio | |
| import importlib | |
| from fastapi import FastAPI, HTTPException, Depends, Body | |
| from typing import Optional, List | |
| from pydantic import ValidationError | |
| from app.models.registry import registry, MODEL_CONFIG | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from app.schemas.schemas import ( | |
| EnhancedDescriptionResponse, | |
| CompareRequest, | |
| CompareResponse, | |
| ModelResult, | |
| ModelInfo, | |
| InfillRequest, | |
| InfillResponse, | |
| InfillResult, | |
| GapFill, | |
| CompareInfillRequest, | |
| CompareInfillResponse, | |
| ModelInfillResult, | |
| ) | |
| from app.logic.infill_utils import ( | |
| detect_gaps, | |
| parse_infill_json, | |
| apply_fills, | |
| build_fills_dict, | |
| normalize_gaps_to_tagged, | |
| ) | |
| from app.auth.placeholder_auth import get_authenticated_user | |
app = FastAPI(
    title="Multi-Model Description Enhancer",
    description="AI-powered service for enhancing descriptions using multiple LLMs for A/B testing",
    version="3.0.0"
)

# CORS configuration.
# The FRONTEND_URL default equals the first hard-coded entry, so without
# deduplication "http://localhost:5173" would appear twice in the list.
# dict.fromkeys preserves insertion order while dropping duplicates.
_allowed_origins = list(dict.fromkeys([
    "http://localhost:5173",
    "http://localhost:5174",
    os.getenv("FRONTEND_URL", "http://localhost:5173"),
]))

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    # Only the verbs the API actually serves; headers stay unrestricted.
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
async def startup_event():
    """
    Startup hook: nothing is pre-loaded; models are loaded lazily on the
    first request to conserve memory.

    NOTE(review): no ``@app.on_event("startup")`` decorator is visible in
    this view of the file — confirm the hook is registered with the app.
    """
    print("Application started. Models will be loaded lazily on first request.")
    available = registry.get_available_model_names()
    print(f"Available models: {available}")
| # --- Helper function to load domain logic --- | |
def get_domain_config(domain: str):
    """
    Dynamically import ``app.domains.<domain>.config`` and return its
    ``domain_config`` attribute.

    Raises HTTPException(404) when the module cannot be imported or the
    attribute is missing.
    """
    module_path = f"app.domains.{domain}.config"
    try:
        return importlib.import_module(module_path).domain_config
    except (ImportError, AttributeError):
        raise HTTPException(status_code=404, detail=f"Domain '{domain}' not found or not configured correctly.")
| # --- API Endpoints --- | |
async def read_root():
    """Root endpoint — a friendly pointer to the interactive API docs."""
    welcome = "Welcome to the Multi-Model Description Enhancer API! Go to /docs for documentation."
    return {"message": welcome}
async def health_check():
    """Report API health: model count, loaded models, and the active local model."""
    return {
        "status": "ok",
        "available_models": len(registry.list_models()),
        "loaded_models": registry.get_loaded_models(),
        "active_local_model": registry.get_active_model(),
    }
async def list_models():
    """Return every registered model together with its current load status."""
    return registry.list_models()
async def load_model(model_name: str):
    """
    Explicitly load a model into memory.
    For local models: unloads any previously loaded local model first
    (handled by the registry).

    Raises 404 for an unknown model name, 500 when loading fails.
    """
    known = registry.get_available_model_names()
    if model_name not in known:
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        info = await registry.load_model(model_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
    return {"status": "loaded", "model": info}
async def unload_model(model_name: str):
    """
    Explicitly unload a model from memory to free resources.

    Raises 404 for an unknown model name, 500 when unloading fails.
    """
    known = registry.get_available_model_names()
    if model_name not in known:
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        return await registry.unload_model(model_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")
async def enhance_description(
    domain: str = Body(..., embed=True),
    data: dict = Body(..., embed=True),
    model: str = Body("bielik-1.5b", embed=True),
    user: Optional[dict] = Depends(get_authenticated_user)
):
    """
    Generate an enhanced description using a single model.
    - **domain**: The name of the domain (e.g., 'cars').
    - **data**: A dictionary with the data for the description.
    - **model**: Model to use (default: bielik-1.5b)
    """
    start_time = time.time()

    # Reject unknown model names before any expensive work.
    if model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Resolve the domain-specific schema and prompt builder (404 if missing).
    domain_config = get_domain_config(domain)
    schema_cls = domain_config["schema"]
    prompt_builder = domain_config["create_prompt"]

    # Validate the payload against the domain schema.
    try:
        validated = schema_cls(**data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}")

    messages = prompt_builder(validated)

    # Generate text; any failure surfaces as a 500 with the cause attached.
    try:
        llm = await registry.get_model(model)
        description = await llm.generate(
            chat_messages=messages,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
        )
    except Exception as e:
        print(f"Error during text generation with {model}: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    return EnhancedDescriptionResponse(
        description=description,
        model_used=MODEL_CONFIG[model]["id"],
        generation_time=round(time.time() - start_time, 2),
        user_email=user['email'] if user else "anonymous",
    )
async def compare_models(
    request: CompareRequest,
    user: Optional[dict] = Depends(get_authenticated_user)
):
    """
    Compare outputs from multiple models for the same input.
    Returns results from all specified models (or all available if not specified).
    """
    total_start = time.time()

    # Default to every available model when none are requested.
    available = registry.get_available_model_names()
    selected = request.models or available
    for name in selected:
        if name not in available:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")

    # Resolve domain schema / prompt builder and validate the input data.
    domain_config = get_domain_config(request.domain)
    try:
        validated = domain_config["schema"](**request.data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data: {e}")

    messages = domain_config["create_prompt"](validated)

    async def run_one(name: str) -> ModelResult:
        """Generate with one model, folding any failure into the result."""
        started = time.time()
        output, error = "", None
        try:
            llm = await registry.get_model(name)
            output = await llm.generate(
                chat_messages=messages,
                max_new_tokens=150,
                temperature=0.75,
                top_p=0.9,
            )
        except Exception as e:
            error = str(e)
        return ModelResult(
            model=name,
            output=output,
            time=round(time.time() - started, 2),
            type=MODEL_CONFIG[name]["type"],
            error=error,
        )

    # Sequential on purpose: loading several local models at once can
    # exhaust memory.
    results = [await run_one(name) for name in selected]

    return CompareResponse(
        domain=request.domain,
        results=results,
        total_time=round(time.time() - total_start, 2),
    )
async def get_user_info(user: dict = Depends(get_authenticated_user)):
    """Return the current authenticated user's id, email and display name (401 if anonymous)."""
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    return {
        "user_id": user['user_id'],
        "email": user['email'],
        # 'name' is optional in the auth payload.
        "name": user.get('name', 'Unknown'),
    }
| # --- Batch Infill Endpoints --- | |
async def batch_infill(
    request: InfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user)
):
    """
    Batch gap-filling for ads using a single model.
    Accepts items with [GAP:n] markers or ___ and returns filled text
    with per-gap choices and alternatives.
    NOTE: For texts > 6000 chars, consider chunking (not yet implemented).
    """
    total_start = time.time()

    # Reject unknown model names up front.
    if request.model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    # The domain must expose an infill prompt builder to support this endpoint.
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations"
        )
    infill_prompt_builder = domain_config["create_infill_prompt"]

    # Process items one by one; per-item failures are captured in the result.
    results: List[InfillResult] = []
    for item in request.items:
        outcome = await process_infill_item(
            item=item,
            model_name=request.model,
            options=request.options,
            create_infill_prompt=infill_prompt_builder,
        )
        results.append(outcome)
    failures = sum(1 for r in results if r.status == "error")

    return InfillResponse(
        model=request.model,
        results=results,
        total_time=round(time.time() - total_start, 2),
        processed_count=len(results),
        error_count=failures,
    )
async def compare_infill(
    request: CompareInfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user)
):
    """
    Multi-model batch gap-filling comparison for A/B testing.
    Runs the same batch of items through multiple models and returns
    per-model results for comparison.
    """
    total_start = time.time()

    # Default to every available model when none are requested.
    available = registry.get_available_model_names()
    selected = request.models or available
    for name in selected:
        if name not in available:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")

    # The domain must expose an infill prompt builder to support this endpoint.
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations"
        )
    infill_prompt_builder = domain_config["create_infill_prompt"]

    # One model at a time: loading several local models concurrently can
    # exhaust memory.
    per_model = []
    for name in selected:
        started = time.time()
        outcomes = [
            await process_infill_item(
                item=item,
                model_name=name,
                options=request.options,
                create_infill_prompt=infill_prompt_builder,
            )
            for item in request.items
        ]
        per_model.append(ModelInfillResult(
            model=name,
            type=MODEL_CONFIG[name]["type"],
            results=outcomes,
            time=round(time.time() - started, 2),
            error_count=sum(1 for o in outcomes if o.status == "error"),
        ))

    return CompareInfillResponse(
        domain=request.domain,
        models=per_model,
        total_time=round(time.time() - total_start, 2),
    )
async def process_infill_item(
    item,
    model_name: str,
    options,
    create_infill_prompt
) -> InfillResult:
    """
    Process a single infill item.
    Returns InfillResult with status, filled_text, and gaps.

    Parameters:
        item: request item; ``id`` and ``text_with_gaps`` are read here.
        model_name: registry key of the model to use (assumed already
            validated by the caller).
        options: generation options; ``max_new_tokens`` and ``temperature``
            are read here.
        create_infill_prompt: domain-specific callable building the chat
            messages from the normalized text and options.

    Never raises: every failure is reported via status="error" with the
    message in ``error``.
    """
    try:
        # Normalize gaps to [GAP:n] format
        normalized_text, gaps = normalize_gaps_to_tagged(item.text_with_gaps)
        if not gaps:
            # No gaps found, return original text
            # (the un-normalized input is returned verbatim).
            return InfillResult(
                id=item.id,
                status="ok",
                filled_text=item.text_with_gaps,
                gaps=[],
                error=None
            )
        # Build prompt
        chat_messages = create_infill_prompt(normalized_text, options)
        # Generate
        llm = await registry.get_model(model_name)
        raw_output = await llm.generate(
            chat_messages=chat_messages,
            max_new_tokens=options.max_new_tokens,
            temperature=options.temperature,
            top_p=0.9,
        )
        # Parse JSON from output
        parsed = parse_infill_json(raw_output)
        if not parsed:
            # JSON parsing failed — include a truncated sample of the raw
            # model output for debugging.
            return InfillResult(
                id=item.id,
                status="error",
                filled_text=None,
                gaps=[],
                error=f"Failed to parse JSON from model output: {raw_output[:200]}..."
            )
        # Extract gaps and build result; missing fields fall back to
        # neutral defaults rather than failing.
        gap_fills = []
        fills_dict = {}
        for gap_data in parsed.get("gaps", []):
            gap_fill = GapFill(
                index=gap_data.get("index", 0),
                marker=gap_data.get("marker", ""),
                choice=gap_data.get("choice", ""),
                alternatives=gap_data.get("alternatives", [])
            )
            gap_fills.append(gap_fill)
            fills_dict[gap_fill.index] = gap_fill.choice
        # Get filled text - prefer model's version, fallback to reconstruction
        # from the normalized text plus the per-gap choices.
        filled_text = parsed.get("filled_text")
        if not filled_text and fills_dict:
            filled_text = apply_fills(normalized_text, gaps, fills_dict)
        return InfillResult(
            id=item.id,
            status="ok",
            filled_text=filled_text,
            gaps=gap_fills,
            error=None
        )
    except Exception as e:
        # Catch-all so one bad item never aborts the whole batch.
        return InfillResult(
            id=item.id,
            status="error",
            filled_text=None,
            gaps=[],
            error=str(e)
        )