Patryk Studzinski
Fix: correct parameter name model_name_or_path
a1c0774
raw
history blame
5.67 kB
import os
import time
import importlib
from fastapi import FastAPI, HTTPException, Depends, Body
from typing import Optional
from pydantic import ValidationError
from app.models.huggingface_service import HuggingFaceTextGenerationService
from fastapi.middleware.cors import CORSMiddleware
from app.schemas.schemas import EnhancedDescriptionResponse
from app.auth.placeholder_auth import get_authenticated_user
# MCP imports removed
app = FastAPI(
    title="Modular Car Description Enhancer",
    description="AI-powered service for enhancing descriptions for multiple domains with Auth0 JWT authentication",
    version="2.0.0",
)

# CORS configuration.
# Browsers send the Origin header without a trailing slash, and CORS origins
# are compared as exact strings, so a FRONTEND_URL configured as
# "https://example.com/" would otherwise never match. dict.fromkeys() drops the
# duplicate entry that appears when FRONTEND_URL falls back to the first
# localhost origin, while preserving the original order.
_frontend_origin = os.getenv("FRONTEND_URL", "http://localhost:5173").rstrip("/")
app.add_middleware(
    CORSMiddleware,
    allow_origins=list(dict.fromkeys([
        "http://localhost:5173",
        "http://localhost:5174",
        _frontend_origin,
    ])),
    allow_credentials=True,
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)

# Global service initialization.
# The service object is created at import time; the model itself is loaded
# later by the awaited hf_service.initialize() call in the startup handler.
MODEL_PATH_IN_CONTAINER = "/app/pretrain_model"
hf_service = HuggingFaceTextGenerationService(
    model_name_or_path=MODEL_PATH_IN_CONTAINER,
    device="cpu",
)
@app.on_event("startup")
async def startup_event():
    """Initialize the shared HuggingFace service when the app starts.

    Progress and failures are printed to stdout. Any exception from
    ``hf_service.initialize()`` is re-raised after being reported, so the
    failure propagates to the framework's startup sequence instead of being
    swallowed.
    """
    print("Starting up and initializing HuggingFace service...")
    try:
        await hf_service.initialize()
    except Exception as exc:
        print(f"An unexpected error occurred during HuggingFace service initialization: {exc}")
        raise
    else:
        print(f"HuggingFace service initialized successfully from {MODEL_PATH_IN_CONTAINER}.")
# --- Helper function to load domain logic ---
def get_domain_config(domain: str):
    """Load the ``domain_config`` object for a single domain.

    The configuration is read from the module ``app.domains.<domain>.config``.

    Args:
        domain: Name of one domain package (e.g. ``"cars"``). It comes
            straight from the request body, so it is treated as untrusted.

    Returns:
        The ``domain_config`` attribute of that module.

    Raises:
        HTTPException: 404 when ``domain`` is empty or contains a dot, when
            the module cannot be imported, or when it has no
            ``domain_config`` attribute.
    """
    detail = f"Domain '{domain}' not found or not configured correctly."
    # The value is interpolated into an import path. A dot would let a client
    # address modules nested below app.domains (importing, and therefore
    # executing, their packages) rather than one domain package, so only
    # single, non-empty path segments are accepted.
    if not domain or "." in domain:
        raise HTTPException(status_code=404, detail=detail)
    try:
        module = importlib.import_module(f"app.domains.{domain}.config")
        return module.domain_config
    except (ImportError, AttributeError) as err:
        raise HTTPException(status_code=404, detail=detail) from err
# --- API Endpoints ---
@app.get("/")
async def read_root():
    # Landing route: returns a fixed greeting that points clients at the
    # generated interactive documentation under /docs. (No docstring on
    # purpose: FastAPI would publish it as the route description.)
    greeting = "Welcome to the Modular Description Enhancer API! Go to /docs for documentation."
    return {"message": greeting}
@app.get("/health")
async def health_check():
    # Health endpoint. "status" is always "ok" when the handler runs;
    # "model_initialized" reports whether hf_service currently has a
    # pipeline object (presumably populated by hf_service.initialize() —
    # that code is not visible here).
    pipeline_loaded = hf_service.pipeline is not None
    payload = {"status": "ok", "model_initialized": pipeline_loaded}
    return payload
@app.post("/enhance-description", response_model=EnhancedDescriptionResponse)
async def enhance_description(
    domain: str = Body(..., embed=True),
    data: dict = Body(..., embed=True),
    user: Optional[dict] = Depends(get_authenticated_user)
):
    """
    Generate an enhanced description for a given domain and data.

    - **domain**: The name of the domain (e.g., 'cars').
    - **data**: A dictionary with the data for the description.

    Errors: **404** unknown or misconfigured domain, **422** `data` does not
    match the domain schema, **500** text generation failed.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # (or become negative) if the wall clock is adjusted mid-request.
    start_time = time.perf_counter()

    # --- 1. Load Domain Configuration ---
    domain_config = get_domain_config(domain)
    try:
        DomainSchema = domain_config["schema"]
        create_prompt = domain_config["create_prompt"]
    except (KeyError, TypeError) as err:
        # A config lacking either entry is as unusable as a missing config
        # module; report it the same way get_domain_config does instead of
        # letting a bare KeyError become an unexplained 500.
        raise HTTPException(
            status_code=404,
            detail=f"Domain '{domain}' not found or not configured correctly.",
        ) from err

    # --- 2. Validate Input Data ---
    try:
        validated_data = DomainSchema(**data)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid data for domain '{domain}': {e}") from e

    # --- 3. Prompt Construction ---
    chat_messages = create_prompt(validated_data)

    # --- 4. Text Generation ---
    try:
        generated_description = await hf_service.generate_text(
            chat_template_messages=chat_messages,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
        )
    except Exception as e:
        # NOTE(review): the exception text is echoed to the client in the 500
        # detail; confirm that exposing internal error messages is acceptable.
        print(f"Unexpected error during text generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during text generation: {str(e)}") from e

    # --- 5. Post-processing ---
    # MCP guardrails and post-processing were removed; the generated text is
    # returned unchanged.
    final_description = generated_description

    generation_time = time.perf_counter() - start_time
    user_email = user['email'] if user else "anonymous"
    return EnhancedDescriptionResponse(
        description=final_description,
        # NOTE(review): hard-coded model id, while the weights are loaded from
        # MODEL_PATH_IN_CONTAINER — confirm the two stay in sync.
        model_used="speakleash/Bielik-1.5B-v3.0-Instruct",
        generation_time=round(generation_time, 2),
        user_email=user_email,
    )
@app.post("/generate")
async def generate_text_only(
    chat_template_messages: str = Body(..., embed=True),
    max_new_tokens: int = 150,
    temperature: float = 0.75,
    top_p: float = 0.9
):
    """
    Generates raw text based on provided chat template messages.
    This endpoint is intended for internal use by the MCP service.
    """
    # NOTE(review): this route declares no authentication dependency, and the
    # sampling parameters (including max_new_tokens) are plain, unbounded
    # parameters — confirm it is only reachable by trusted internal callers.
    # NOTE(review): chat_template_messages is typed str here, whereas
    # /enhance-description forwards whatever create_prompt returns — confirm
    # hf_service.generate_text accepts both forms.
    try:
        text = await hf_service.generate_text(
            chat_template_messages=chat_template_messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as exc:
        print(f"Unexpected error during raw text generation: {exc}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during text generation: {str(exc)}")
    return {"generated_text": text}
@app.get("/user/me")
async def get_user_info(user: Optional[dict] = Depends(get_authenticated_user)):
    """Get current authenticated user information"""
    # The dependency can yield a falsy value (the same dependency is declared
    # Optional[dict] on /enhance-description, and this handler checks for it),
    # so the annotation reflects that and the case is rejected explicitly.
    if not user:
        # RFC 9110 requires every 401 response to carry a WWW-Authenticate
        # challenge. NOTE(review): "Bearer" is assumed from the JWT mention in
        # the app description — confirm against get_authenticated_user.
        raise HTTPException(
            status_code=401,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {
        "user_id": user['user_id'],
        "email": user['email'],
        "name": user.get('name', 'Unknown'),
    }