Spaces:
Running
Running
| import os | |
| import time | |
| import importlib | |
| from fastapi import FastAPI, HTTPException, Depends, Body | |
| from typing import Optional | |
| from pydantic import ValidationError | |
| from app.models.huggingface_service import HuggingFaceTextGenerationService | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from app.schemas.schemas import EnhancedDescriptionResponse | |
| from app.auth.placeholder_auth import get_authenticated_user | |
| # MCP imports removed | |
# Application object. title/description/version are the metadata FastAPI
# publishes in the generated OpenAPI schema.
app = FastAPI(
    version="2.0.0",
    title="Modular Car Description Enhancer",
    description="AI-powered service for enhancing descriptions for multiple domains with Auth0 JWT authentication",
)
# CORS: allow two fixed localhost origins plus FRONTEND_URL (which falls back
# to http://localhost:5173 when the variable is unset), with credentials,
# GET/POST methods and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST", "GET"],
    allow_credentials=True,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        os.getenv("FRONTEND_URL", "http://localhost:5173"),
    ],
)
# Module-level text-generation service shared by the endpoints below. It is
# only constructed here; startup_event awaits its initialize() and
# health_check reports whether its `pipeline` attribute has been set.
MODEL_PATH_IN_CONTAINER = "/app/pretrain_model"
hf_service = HuggingFaceTextGenerationService(
    device="cpu",
    model_name_or_path=MODEL_PATH_IN_CONTAINER,
)
async def startup_event():
    """Initialize the shared ``hf_service``, printing progress to stdout.

    Any exception from ``hf_service.initialize()`` is printed and then
    re-raised unchanged, so the failure propagates to whoever awaited this.

    NOTE(review): no ``@app.on_event("startup")`` / lifespan registration is
    visible for this coroutine in this view — confirm it is actually hooked
    into application startup.
    """
    print("Starting up and initializing HuggingFace service...")
    try:
        await hf_service.initialize()
    except Exception as exc:
        print(f"An unexpected error occurred during HuggingFace service initialization: {exc}")
        raise
    else:
        print(f"HuggingFace service initialized successfully from {MODEL_PATH_IN_CONTAINER}.")
# --- Helper function to load domain logic ---
def get_domain_config(domain: str):
    """Return the ``domain_config`` object of ``app.domains.<domain>.config``.

    Args:
        domain: Domain package name as sent by the client (e.g. ``"cars"``).

    Returns:
        Whatever the module exposes as ``domain_config``; callers in this file
        index it with ``"schema"`` and ``"create_prompt"``.

    Raises:
        HTTPException: 404 when ``domain`` is not a single identifier, the
            config module cannot be imported, or it has no ``domain_config``.
            Note that an ImportError raised while the config module itself is
            executing is reported the same way.
    """
    detail = f"Domain '{domain}' not found or not configured correctly."
    # `domain` comes straight from the request body. Accept only a single
    # identifier segment so callers cannot use dots to reach arbitrary nested
    # modules under `app.domains`.
    if not domain.isidentifier():
        raise HTTPException(status_code=404, detail=detail)
    try:
        module = importlib.import_module(f"app.domains.{domain}.config")
        return module.domain_config
    except (ImportError, AttributeError) as err:
        raise HTTPException(status_code=404, detail=detail) from err
# --- API Endpoints ---
async def read_root():
    """Return a static welcome payload that points clients at ``/docs``."""
    welcome = "Welcome to the Modular Description Enhancer API! Go to /docs for documentation."
    return {"message": welcome}
async def health_check():
    """Report a constant ``"ok"`` status plus whether ``hf_service.pipeline`` is set."""
    initialized = hf_service.pipeline is not None
    return {"status": "ok", "model_initialized": initialized}
async def enhance_description(
    domain: str = Body(..., embed=True),
    data: dict = Body(..., embed=True),
    user: Optional[dict] = Depends(get_authenticated_user),
):
    """
    Generate an enhanced description for a given domain and data.
    - **domain**: The name of the domain (e.g., 'cars').
    - **data**: A dictionary with the data for the description.
    """
    # Error mapping (implementation notes, kept out of the docstring because
    # FastAPI publishes the docstring as the endpoint description):
    #   404 - unknown or misconfigured domain
    #   422 - `data` rejected by the domain schema
    #   500 - text generation failed
    # perf_counter is monotonic, so the reported generation time cannot be
    # skewed by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    # --- 1. Load Domain Configuration ---
    domain_config = get_domain_config(domain)
    try:
        DomainSchema = domain_config["schema"]
        create_prompt = domain_config["create_prompt"]
    except (KeyError, TypeError) as e:
        # A config that exists but lacks the expected entries is the same
        # "not configured correctly" case reported for a missing config.
        raise HTTPException(
            status_code=404,
            detail=f"Domain '{domain}' not found or not configured correctly.",
        ) from e

    # --- 2. Validate Input Data ---
    try:
        validated_data = DomainSchema(**data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}") from e

    # --- 3. Prompt Construction ---
    chat_messages = create_prompt(validated_data)

    # --- 4. Text Generation ---
    try:
        generated_description = await hf_service.generate_text(
            chat_template_messages=chat_messages,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
        )
    except Exception as e:
        print(f"Unexpected error during text generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during text generation: {str(e)}") from e

    # --- 5. Response --- no guardrail / post-processing step: the model
    # output is returned as-is.
    generation_time = time.perf_counter() - start_time
    user_email = user['email'] if user else "anonymous"
    return EnhancedDescriptionResponse(
        description=generated_description,
        # NOTE(review): hard-coded model id, while hf_service is constructed
        # from MODEL_PATH_IN_CONTAINER — confirm the two stay in sync.
        model_used="speakleash/Bielik-1.5B-v3.0-Instruct",
        generation_time=round(generation_time, 2),
        user_email=user_email,
    )
async def generate_text_only(
    chat_template_messages: str = Body(..., embed=True),
    max_new_tokens: int = 150,
    temperature: float = 0.75,
    top_p: float = 0.9
):
    """
    Generates raw text based on provided chat template messages.
    This endpoint is intended for internal use by the MCP service.
    """
    # Returns {"generated_text": ...}; any generation failure becomes a 500.
    # NOTE(review): unlike enhance_description, this endpoint takes no
    # authentication dependency and max_new_tokens has no upper bound —
    # confirm it is not reachable by untrusted clients.
    # NOTE(review): chat_template_messages is declared as str here, whereas
    # enhance_description passes whatever create_prompt() returns to the same
    # generate_text argument — confirm which type generate_text expects.
    try:
        generated_text = await hf_service.generate_text(
            chat_template_messages=chat_template_messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        print(f"Unexpected error during raw text generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during text generation: {str(e)}") from e
    return {"generated_text": generated_text}
async def get_user_info(user: Optional[dict] = Depends(get_authenticated_user)):
    """Get current authenticated user information.

    Responds 401 when the auth dependency yields no user; otherwise returns
    the user's id, email, and name ("Unknown" when no name is present).
    """
    # The dependency can yield None (hence Optional), so guard explicitly.
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    # NOTE(review): assumes the auth dependency's dict always carries
    # 'user_id' and 'email' (a missing key would surface as a 500) — confirm.
    return {
        "user_id": user['user_id'],
        "email": user['email'],
        "name": user.get('name', 'Unknown'),
    }