"""FastAPI service exposing a Hugging Face causal language model for text generation.

Offers synchronous generation, background (job-based) generation with optional
HTTP callbacks, job status polling, and health/model information endpoints.
"""

import asyncio
import json
import logging
import os
import secrets
import threading
import uuid
from contextlib import asynccontextmanager
from typing import Dict, Optional, Tuple

import httpx
import torch
import uvicorn
from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configuration (everything except MODEL_ID can be overridden via environment variables)
MODEL_ID = "google/gemma-1.1-2b-it"
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Fallback key used when API_KEY is unset; startup logs a warning while it is in use.
_DEFAULT_API_KEY = "default-key-123"
API_KEY = os.getenv("API_KEY", _DEFAULT_API_KEY)
# Default number of new tokens per request, and the most a client may request.
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "450"))
# Prompts longer than this many tokens are truncated before generation.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "512"))
# Upper bound on stored async jobs; the oldest finished jobs are evicted beyond it.
MAX_STORED_JOBS = int(os.getenv("MAX_STORED_JOBS", "1000"))
DEVICE = os.getenv("DEVICE", "cpu")
PORT = int(os.getenv("PORT", "7860"))

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Security
security = HTTPBearer()

# Async job records keyed by job id. Only read/written from async handlers running
# on the event loop (generation itself runs in worker threads and never touches it).
jobs: Dict[str, dict] = {}


class AIGenerator:
    """Lazily loads and holds the tokenizer/model pair used for generation."""

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.loaded = False
        self.load_error: Optional[str] = None
        # Serializes loading so concurrent first requests don't load the model twice.
        self._load_lock = threading.Lock()
        # Serializes tokenize/generate/decode: the shared fast tokenizer reconfigures
        # its truncation state on each call and is not safe to use from several
        # threads at once, and parallel generations would contend for the same device.
        self.generate_lock = threading.Lock()

    def load_model(self) -> bool:
        """Load the tokenizer and model from the Hugging Face Hub (idempotent).

        Returns:
            True if the model is loaded (now or previously), False on failure;
            the failure reason is stored in ``self.load_error``.
        """
        if self.loaded:
            return True
        with self._load_lock:
            # Another thread may have finished loading while we waited for the lock.
            if self.loaded:
                return True
            return self._load_locked()

    def _load_locked(self) -> bool:
        """Perform the actual load; the caller must hold ``self._load_lock``."""
        logger.info(f"🚀 Loading model: {MODEL_ID}")

        if not HF_TOKEN:
            logger.error("❌ HF_TOKEN is not set!")
            self.load_error = "HF_TOKEN environment variable is not set"
            return False

        try:
            logger.info("📥 Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
            # generate() needs a pad token id; fall back to EOS when none is defined.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            logger.info("✅ Tokenizer loaded")

            logger.info("📥 Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                token=HF_TOKEN,
                device_map=None,
            )
            model = model.to(DEVICE)
            model.eval()
        except Exception as e:
            self.load_error = str(e)
            logger.error(f"❌ Model loading failed: {str(e)}")
            return False

        # Publish tokenizer and model before flipping `loaded`, so any reader that
        # sees loaded=True also finds both attributes set.
        self.tokenizer = tokenizer
        self.model = model
        self.load_error = None  # clear a stale error from an earlier failed attempt
        self.loaded = True
        logger.info("🎉 Model loaded successfully!")
        return True


# Global generator instance
generator = AIGenerator()


async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Reject the request with HTTP 401 unless the bearer token equals API_KEY.

    Uses a constant-time comparison so response timing does not reveal how much
    of the key matched.
    """
    presented = credentials.credentials.encode("utf-8")
    if not secrets.compare_digest(presented, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log configuration and preload the model before requests are served.

    The load runs synchronously here, so startup waits for it. A failed load is
    not fatal: ``generate_text`` retries loading on the first generation request.
    """
    logger.info("🚀 Starting AI API Server...")
    logger.info(f"📊 Config: Model={MODEL_ID}, Device={DEVICE}, MaxTokens={MAX_TOKENS}")
    if API_KEY == _DEFAULT_API_KEY:
        logger.warning("⚠️ API_KEY is not set; the insecure built-in default key is in use")

    # load_model() reports failure through its return value, not by raising.
    if not generator.load_model():
        logger.warning(
            f"Model preloading failed, will load on first request: {generator.load_error}"
        )
    yield


app = FastAPI(lifespan=lifespan)


def generate_text(prompt: str, max_tokens: Optional[int] = None) -> str:
    """Generate a continuation of ``prompt`` with the loaded model.

    This call blocks for the whole generation; async callers should run it in a
    worker thread (e.g. ``run_in_threadpool``) rather than on the event loop.

    Args:
        prompt: Input text; truncated to MAX_INPUT_TOKENS tokens.
        max_tokens: Maximum number of new tokens; None/0 means MAX_TOKENS.

    Returns:
        Only the newly generated text (prompt tokens are excluded), stripped.

    Raises:
        RuntimeError: If the model cannot be loaded.
        Exception: Any tokenizer/model error is logged and re-raised.
    """
    try:
        if not generator.load_model():
            raise RuntimeError(f"Model failed to load: {generator.load_error}")

        logger.info(f"📝 Generating text for prompt: '{prompt[:50]}...'")

        with generator.generate_lock:
            inputs = generator.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
            )
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = generator.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens or MAX_TOKENS,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.8,
                    pad_token_id=generator.tokenizer.pad_token_id,
                    repetition_penalty=1.1,
                )

            # A causal LM's output sequence starts with the input ids; decode only the
            # tokens after them. (Removing the prompt *string* fails whenever decoding
            # does not reproduce it exactly, e.g. after truncation, and also deleted
            # any later repeats of the prompt inside the output.)
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = generator.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

        logger.info(f"✅ Generated {len(generated_text)} characters")
        return generated_text

    except Exception as e:
        logger.error(f"❌ Generation failed: {str(e)}")
        raise


async def _read_json_object(request: Request) -> dict:
    """Parse the request body as a JSON object, raising HTTP 400 otherwise."""
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError subclass ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return data


def _parse_generation_params(data: dict) -> Tuple[str, Optional[int]]:
    """Validate ``prompt`` and ``max_tokens`` from a request body.

    Returns:
        (prompt, max_tokens) where max_tokens is None when unspecified (meaning
        MAX_TOKENS) and is otherwise capped at MAX_TOKENS so a client cannot
        request unbounded generation.

    Raises:
        HTTPException: 400 for a missing/non-string prompt or an invalid max_tokens.
    """
    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")

    max_tokens = data.get("max_tokens")
    if not max_tokens:  # absent/null/0 is treated as "use the default"
        return prompt, None
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be a positive integer")
    return prompt, min(max_tokens, MAX_TOKENS)


def _prune_finished_jobs() -> None:
    """Evict the oldest finished jobs once ``jobs`` exceeds MAX_STORED_JOBS.

    Relies on dict insertion order (reassigning an existing key keeps its
    position), so the first entries are the oldest jobs. Jobs still
    "processing" are never evicted.
    """
    excess = len(jobs) - MAX_STORED_JOBS
    if excess <= 0:
        return
    finished = [jid for jid, job in jobs.items() if job.get("status") != "processing"]
    for jid in finished[:excess]:
        jobs.pop(jid, None)


async def _send_callback(callback_url: str, payload: dict) -> None:
    """POST ``payload`` as JSON to ``callback_url``; failures are logged, never raised."""
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(callback_url, json=payload)
        if response.status_code >= 400:
            logger.error(f"❌ Callback failed: HTTP {response.status_code}")
    except Exception as e:
        logger.error(f"❌ Callback failed: {e}")


@app.post("/api/generate-sync")
async def generate_sync(
    request: Request,
    auth: bool = Depends(verify_api_key)
):
    """
    Synchronous text generation

    Body: {"prompt": "your text", "max_tokens": 100}
    """
    # Validation raises HTTPException(400) directly, outside the try below, so it
    # is not rewrapped as a 500.
    data = await _read_json_object(request)
    prompt, max_tokens = _parse_generation_params(data)

    logger.info(f"📥 Sync request: '{prompt[:50]}...'")

    try:
        # Generation is blocking; run it in a worker thread so the event loop keeps
        # serving other requests (health checks, status polls) meanwhile.
        generated_text = await run_in_threadpool(generate_text, prompt, max_tokens)
    except Exception as e:
        logger.error(f"❌ Sync generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse({
        "status": "success",
        "result": generated_text,
        "prompt": prompt,
        "text_length": len(generated_text),
        "model": MODEL_ID
    })


@app.post("/api/generate")
async def generate_async(
    request: Request,
    background_tasks: BackgroundTasks,
    auth: bool = Depends(verify_api_key)
):
    """
    Asynchronous text generation (for longer tasks)

    Body: {"prompt": "your text", "max_tokens": 100, "callback_url": "optional"}
    """
    data = await _read_json_object(request)
    prompt, max_tokens = _parse_generation_params(data)

    callback_url = data.get("callback_url") or None
    # The server itself POSTs to this URL, so at least require http(s).
    # NOTE(security): any authenticated client can still target internal hosts
    # (SSRF); add an allow-list if the deployment needs one.
    if callback_url is not None and (
        not isinstance(callback_url, str)
        or not callback_url.lower().startswith(("http://", "https://"))
    ):
        raise HTTPException(status_code=400, detail="callback_url must be an http(s) URL")

    job_id = str(uuid.uuid4())
    logger.info(f"📥 Async request {job_id}")

    jobs[job_id] = {
        "status": "processing",
        "prompt": prompt
    }
    _prune_finished_jobs()

    # Process in background (after the response is sent)
    background_tasks.add_task(
        process_job_async,
        job_id,
        prompt,
        max_tokens,
        callback_url
    )

    return JSONResponse({
        "job_id": job_id,
        "status": "queued",
        "message": "Generation started",
        "model": MODEL_ID
    })


async def process_job_async(
    job_id: str,
    prompt: str,
    max_tokens: Optional[int] = None,
    callback_url: Optional[str] = None,
):
    """Run one background generation job, record its outcome, and notify the callback.

    The job record in ``jobs`` becomes either "complete" (with the result) or
    "failed" (with the error). If ``callback_url`` is set, the outcome is POSTed
    to it in both cases; callback errors are logged and do not change job status.
    """
    logger.info(f"🔄 Processing async job {job_id}")

    try:
        # Keep the blocking generation off the event loop.
        generated_text = await run_in_threadpool(generate_text, prompt, max_tokens)
    except Exception as e:
        error_msg = str(e)
        logger.error(f"❌ Async job {job_id} failed: {error_msg}")
        jobs[job_id] = {
            "status": "failed",
            "error": error_msg,
            "prompt": prompt
        }
        if callback_url:
            await _send_callback(callback_url, {
                "job_id": job_id,
                "status": "failed",
                "error": error_msg,
                "prompt": prompt
            })
        return

    jobs[job_id] = {
        "status": "complete",
        "result": generated_text,
        "prompt": prompt,
        "text_length": len(generated_text)
    }
    logger.info(f"✅ Completed async job {job_id}")

    if callback_url:
        await _send_callback(callback_url, {
            "job_id": job_id,
            "status": "complete",
            "result": generated_text,
            "prompt": prompt
        })


@app.get("/api/status/{job_id}")
async def get_status(job_id: str, auth: bool = Depends(verify_api_key)):
    """Return the stored record for ``job_id`` (404 if unknown or already evicted)."""
    if job_id not in jobs:
        raise HTTPException(status_code=404, detail="Job not found")
    return JSONResponse(jobs[job_id])


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return JSONResponse({
        "status": "healthy",
        "model_loaded": generator.loaded,
        "model": MODEL_ID,
        "device": DEVICE,
        "max_tokens": MAX_TOKENS
    })


@app.get("/model-info")
async def model_info():
    """Model information, including the last load error (if any)."""
    return JSONResponse({
        "model": MODEL_ID,
        "loaded": generator.loaded,
        "error": generator.load_error,
        "device": DEVICE,
        "requires_auth": True,
        "token_available": bool(HF_TOKEN)
    })


@app.get("/")
async def root():
    """Root endpoint listing the available endpoints and a usage example."""
    return JSONResponse({
        "message": "🤖 AI Text Generation API",
        "version": "1.0",
        "model": MODEL_ID,
        "status": "operational" if generator.loaded else "model_loading",
        "endpoints": {
            "generate_sync": "POST /api/generate-sync",
            "generate_async": "POST /api/generate",
            "check_status": "GET /api/status/{job_id}",
            "health": "GET /health",
            "model_info": "GET /model-info"
        },
        "usage": 'curl -X POST /api/generate-sync -H "Authorization: Bearer YOUR_KEY" -d \'{"prompt":"Hello"}\''
    })


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=PORT,
        log_level="info"
    )