import os
import time
import asyncio
import importlib
import subprocess
import sys

from fastapi import FastAPI, HTTPException, Depends, Body
from typing import Optional, List
from pydantic import ValidationError
# llama-cpp-python is installed at runtime with CUDA support
try:
    import llama_cpp
except ImportError:
    print("[STARTUP] Installing llama-cpp-python with CUDA...")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--prefer-binary",
         "--index-url", "https://abetlen.github.io/llama-cpp-python/whl/cu121",
         "llama-cpp-python[server]>=0.3.16"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print("[STARTUP] CUDA wheel failed, trying CPU fallback...")
        print(f"[STARTUP] Error details: {result.stderr[:500]}")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet", "llama-cpp-python>=0.3.16"],
            check=False,
        )
    else:
        print("[STARTUP] llama-cpp-python with CUDA installed")
from app.models.registry import registry, MODEL_CONFIG
from fastapi.middleware.cors import CORSMiddleware
from app.schemas.schemas import (
    EnhancedDescriptionResponse,
    CompareRequest,
    CompareResponse,
    ModelResult,
    ModelInfo,
    InfillRequest,
    InfillResponse,
    InfillResult,
    GapFill,
    CompareInfillRequest,
    CompareInfillResponse,
    ModelInfillResult,
)
from app.logic.infill_utils import (
    detect_gaps,
    parse_infill_response,
    apply_fills,
    build_fills_dict,
    normalize_gaps_to_tagged,
)
from app.auth.placeholder_auth import get_authenticated_user

app = FastAPI(
    title="Multi-Model Description Enhancer",
    description="AI-powered service for enhancing descriptions using multiple LLMs for A/B testing",
    version="3.0.0",
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        os.getenv("FRONTEND_URL", "http://localhost:5173"),
    ],
    allow_credentials=True,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)

@app.on_event("startup")
async def startup_event():
    """
    Startup event - models are loaded lazily on first request.
    No models are pre-loaded to conserve memory.
    """
    print("Application started. Models will be loaded lazily on first request.")
    print(f"Available models: {registry.get_available_model_names()}")
    try:
        import torch
        gpu_available = torch.cuda.is_available()
        gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
        print(f"GPU available: {gpu_available}, Device: {gpu_name}")
    except ImportError:
        print("PyTorch not available for GPU check")
    except Exception as e:
        print(f"GPU check failed: {e}")

# --- Helper function to load domain logic ---
def get_domain_config(domain: str):
    try:
        module = importlib.import_module(f"app.domains.{domain}.config")
        return module.domain_config
    except (ImportError, AttributeError):
        raise HTTPException(status_code=404, detail=f"Domain '{domain}' not found or not configured correctly.")
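
# A domain module is expected to look roughly like this (hypothetical 'cars'
# example; the keys below are the ones this file reads from domain_config):
#
#   # app/domains/cars/config.py
#   domain_config = {
#       "schema": CarSchema,                    # pydantic model for input validation
#       "create_prompt": create_car_prompt,     # (validated_data) -> chat messages
#       "create_infill_prompt": create_infill,  # (text, options, attributes=...) -> chat messages
#   }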

# --- API Endpoints ---
# (Route paths on the decorators below are assumed; adjust to match the frontend.)

@app.get("/")
async def read_root():
    return {"message": "Welcome to the Multi-Model Description Enhancer API! Go to /docs for documentation."}

@app.get("/health")
async def health_check():
    """Check API health and model status."""
    models = registry.list_models()
    loaded_models = registry.get_loaded_models()
    active_model = registry.get_active_model()
    gpu_available = False
    gpu_name = "N/A"
    try:
        import torch
        gpu_available = torch.cuda.is_available()
        gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
    except Exception:
        # GPU status is best-effort; report the defaults if torch is unavailable.
        pass
    return {
        "status": "ok",
        "available_models": len(models),
        "loaded_models": loaded_models,
        "active_local_model": active_model,
        "gpu_available": gpu_available,
        "gpu_device": gpu_name,
    }
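
# Illustrative /health response (values are examples only):
#   {"status": "ok", "available_models": 3, "loaded_models": ["bielik-1.5b"],
#    "active_local_model": "bielik-1.5b", "gpu_available": true, "gpu_device": "Tesla T4"}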

@app.get("/models")
async def list_models():
    """List all available models with their load status."""
    return registry.list_models()

@app.post("/models/{model_name}/load")
async def load_model(model_name: str):
    """
    Explicitly load a model into memory.
    For local models: unloads any previously loaded local model first.
    """
    if model_name not in registry.get_available_model_names():
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        info = await registry.load_model(model_name)
        return {"status": "loaded", "model": info}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

@app.post("/models/{model_name}/unload")
async def unload_model(model_name: str):
    """
    Explicitly unload a model from memory to free resources.
    """
    if model_name not in registry.get_available_model_names():
        raise HTTPException(status_code=404, detail=f"Unknown model: {model_name}")
    try:
        result = await registry.unload_model(model_name)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {str(e)}")

@app.post("/enhance", response_model=EnhancedDescriptionResponse)
async def enhance_description(
    domain: str = Body(..., embed=True),
    data: dict = Body(..., embed=True),
    model: str = Body("bielik-1.5b", embed=True),
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Generate an enhanced description using a single model.

    - **domain**: The name of the domain (e.g., 'cars').
    - **data**: A dictionary with the data for the description.
    - **model**: Model to use (default: bielik-1.5b).
    """
    start_time = time.time()

    # Validate model
    if model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load Domain Configuration
    domain_config = get_domain_config(domain)
    DomainSchema = domain_config["schema"]
    create_prompt = domain_config["create_prompt"]

    # Validate Input Data
    try:
        validated_data = DomainSchema(**data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}")

    # Prompt Construction
    chat_messages = create_prompt(validated_data)

    # Text Generation
    try:
        llm = await registry.get_model(model)
        generated_description = await llm.generate(
            chat_messages=chat_messages,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
        )
    except Exception as e:
        print(f"Error during text generation with {model}: {e}")
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}")

    generation_time = time.time() - start_time
    user_email = user['email'] if user else "anonymous"

    return EnhancedDescriptionResponse(
        description=generated_description,
        model_used=MODEL_CONFIG[model]["id"],
        generation_time=round(generation_time, 2),
        user_email=user_email,
    )
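
# Example request body for the endpoint above (embed=True wraps each field;
# the data keys are domain-specific and hypothetical here):
#   {"domain": "cars", "data": {"make": "Toyota", "model": "Corolla"}, "model": "bielik-1.5b"}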

@app.post("/compare", response_model=CompareResponse)
async def compare_models(
    request: CompareRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Compare outputs from multiple models for the same input.
    Returns results from all specified models (or all available if not specified).
    """
    total_start = time.time()

    # Get models to compare
    available_models = registry.get_available_model_names()
    models_to_use = request.models if request.models else available_models

    # Validate requested models
    for model in models_to_use:
        if model not in available_models:
            raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load Domain Configuration
    domain_config = get_domain_config(request.domain)
    DomainSchema = domain_config["schema"]
    create_prompt = domain_config["create_prompt"]

    # Validate Input Data
    try:
        validated_data = DomainSchema(**request.data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data: {e}")

    # Prompt Construction
    chat_messages = create_prompt(validated_data)

    async def generate_with_model(model_name: str) -> ModelResult:
        start_time = time.time()
        try:
            llm = await registry.get_model(model_name)
            output = await llm.generate(
                chat_messages=chat_messages,
                max_new_tokens=150,
                temperature=0.75,
                top_p=0.9,
            )
            return ModelResult(
                model=model_name,
                output=output,
                time=round(time.time() - start_time, 2),
                type=MODEL_CONFIG[model_name]["type"],
                error=None,
            )
        except Exception as e:
            return ModelResult(
                model=model_name,
                output="",
                time=round(time.time() - start_time, 2),
                type=MODEL_CONFIG[model_name]["type"],
                error=str(e),
            )

    # Run all models (sequentially to avoid memory issues)
    results = []
    for model_name in models_to_use:
        result = await generate_with_model(model_name)
        results.append(result)

    return CompareResponse(
        domain=request.domain,
        results=results,
        total_time=round(time.time() - total_start, 2),
    )
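
# Example CompareRequest body ("models" is optional and defaults to all
# available models; the data payload is elided here):
#   {"domain": "cars", "data": {...}, "models": ["bielik-1.5b"]}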

@app.get("/me")
async def get_user_info(user: dict = Depends(get_authenticated_user)):
    """Get current authenticated user information."""
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    return {
        "user_id": user['user_id'],
        "email": user['email'],
        "name": user.get('name', 'Unknown'),
    }

# --- Batch Infill Endpoints ---

@app.post("/infill", response_model=InfillResponse)
async def batch_infill(
    request: InfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Batch gap-filling for ads using a single model.

    Accepts items with [GAP:n] markers or ___ and returns filled text
    with per-gap choices and alternatives.

    NOTE: For texts > 6000 chars, consider chunking (not yet implemented).
    """
    print(f"DEBUG: Hit batch_infill endpoint with model={request.model}", flush=True)
    total_start = time.time()

    # Validate model
    if request.model not in registry.get_available_model_names():
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    # Load domain config for infill prompt
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations",
        )
    create_infill_prompt = domain_config["create_infill_prompt"]

    # Process each item
    results = []
    error_count = 0
    for item in request.items:
        result = await process_infill_item(
            item=item,
            model_name=request.model,
            options=request.options,
            create_infill_prompt=create_infill_prompt,
        )
        results.append(result)
        if result.status == "error":
            error_count += 1

    return InfillResponse(
        model=request.model,
        results=results,
        total_time=round(time.time() - total_start, 2),
        processed_count=len(results),
        error_count=error_count,
    )
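
# Gap markers: an item's text may use either form (hypothetical example text):
#   "Sprzedam ___ w dobrym stanie"  ->  "Sprzedam [GAP:1] w dobrym stanie"
# normalize_gaps_to_tagged() performs this conversion before prompting.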

@app.post("/compare-infill", response_model=CompareInfillResponse)
async def compare_infill(
    request: CompareInfillRequest,
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Multi-model batch gap-filling comparison for A/B testing.

    Runs the same batch of items through multiple models and returns
    per-model results for comparison.
    """
    total_start = time.time()

    # Get models to compare
    available_models = registry.get_available_model_names()
    models_to_use = request.models if request.models else available_models

    # Validate requested models
    for model in models_to_use:
        if model not in available_models:
            raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    # Load domain config
    domain_config = get_domain_config(request.domain)
    if "create_infill_prompt" not in domain_config:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{request.domain}' does not support infill operations",
        )
    create_infill_prompt = domain_config["create_infill_prompt"]

    # Process with each model (sequentially for memory safety)
    model_results = []
    for model_name in models_to_use:
        model_start = time.time()
        results = []
        error_count = 0
        for item in request.items:
            result = await process_infill_item(
                item=item,
                model_name=model_name,
                options=request.options,
                create_infill_prompt=create_infill_prompt,
            )
            results.append(result)
            if result.status == "error":
                error_count += 1
        model_results.append(ModelInfillResult(
            model=model_name,
            type=MODEL_CONFIG[model_name]["type"],
            results=results,
            time=round(time.time() - model_start, 2),
            error_count=error_count,
        ))

    return CompareInfillResponse(
        domain=request.domain,
        models=model_results,
        total_time=round(time.time() - total_start, 2),
    )

async def process_infill_item(
    item,
    model_name: str,
    options,
    create_infill_prompt,
) -> InfillResult:
    """
    Process a single infill item.
    Returns InfillResult with status, filled_text, and gaps.
    """
    try:
        # Normalize gaps to [GAP:n] format
        normalized_text, gaps = normalize_gaps_to_tagged(item.text_with_gaps)
        if not gaps:
            # No gaps found, return original text
            return InfillResult(
                id=item.id,
                status="ok",
                filled_text=item.text_with_gaps,
                gaps=[],
                error=None,
            )

        # Build prompt
        if item.custom_messages:
            chat_messages = item.custom_messages
            use_grammar = False  # Custom messages = plain text output expected
        else:
            chat_messages = create_infill_prompt(normalized_text, options, attributes=item.attributes)
            use_grammar = True  # Standard prompt = use grammar for structured JSON

        # Generate with optional GBNF grammar constraint
        llm = await registry.get_model(model_name)
        grammar_str = None
        if use_grammar and hasattr(llm, 'llm') and llm.llm is not None:
            # Use model's default grammar (loaded from answers.gbnf) if available
            if hasattr(llm, 'default_grammar') and llm.default_grammar:
                grammar_str = llm.default_grammar
                print("DEBUG: Using model's default GBNF grammar", flush=True)
            else:
                # Fall back to dynamic grammar generation
                try:
                    from app.logic.grammar_utils import get_infill_grammar
                    grammar_str = get_infill_grammar(len(gaps))
                    print(f"DEBUG: Using dynamic GBNF grammar for {len(gaps)} gaps", flush=True)
                except ImportError:
                    pass

        raw_output = await llm.generate(
            chat_messages=chat_messages,
            max_new_tokens=options.max_new_tokens,
            temperature=0.3 if use_grammar else options.temperature,  # Lower temp with grammar
            top_p=0.9,
            grammar=grammar_str,
        )

        # If custom_messages were provided, the output is plain text (not JSON);
        # return it directly as a single gap fill.
        if item.custom_messages:
            # Clean up the raw output - strip whitespace, quotes, etc.
            choice = raw_output.strip().strip('"\'.,').strip()
            return InfillResult(
                id=item.id,
                status="ok",
                filled_text=choice,  # The filled text is just the choice itself
                gaps=[GapFill(index=1, marker="[GAP:1]", choice=choice, alternatives=[])],
                error=None,
            )
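
        # Expected model JSON shape (inferred from the parsing below):
        #   {"filled_text": "...",
        #    "gaps": [{"index": 1, "marker": "[GAP:1]", "choice": "...",
        #              "alternatives": ["...", "..."]}]}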
        # Parse JSON from output (standard prompt format)
        parsed = parse_infill_response(raw_output)
        if not parsed:
            # JSON parsing failed
            return InfillResult(
                id=item.id,
                status="error",
                filled_text=None,
                gaps=[],
                error=f"Failed to parse JSON from model output: {raw_output[:200]}...",
            )

        # Extract gaps and build result
        gap_fills = []
        fills_dict = {}
        for gap_data in parsed.get("gaps", []):
            gap_fill = GapFill(
                index=gap_data.get("index", 0),
                marker=gap_data.get("marker", ""),
                choice=gap_data.get("choice", ""),
                alternatives=gap_data.get("alternatives", []),
            )
            gap_fills.append(gap_fill)
            fills_dict[gap_fill.index] = gap_fill.choice

        # Get filled text - prefer the model's version, fall back to reconstruction
        filled_text = parsed.get("filled_text")
        if not filled_text and fills_dict:
            filled_text = apply_fills(normalized_text, gaps, fills_dict)

        return InfillResult(
            id=item.id,
            status="ok",
            filled_text=filled_text,
            gaps=gap_fills,
            error=None,
        )
    except Exception as e:
        return InfillResult(
            id=item.id,
            status="error",
            filled_text=None,
            gaps=[],
            error=str(e),
        )
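
# Local run (assumed entrypoint; adjust the module path to wherever this file lives):
#   uvicorn app.main:app --host 0.0.0.0 --port 8000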