import asyncio
import json
import logging
import os
import secrets
import uuid
from contextlib import asynccontextmanager
from typing import Dict, Optional

import httpx
import torch
import uvicorn
from fastapi import FastAPI, Request, BackgroundTasks, HTTPException, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from transformers import AutoTokenizer, AutoModelForCausalLM
| |
|
| | |
# Runtime configuration, resolved once at import time from the environment.

# Hugging Face model identifier served by this API.
MODEL_ID = "google/gemma-1.1-2b-it"

# Credentials: the Hugging Face access token and the bearer key clients present.
# NOTE(review): API_KEY falls back to a fixed, publicly visible default when the
# variable is unset; set it explicitly in any real deployment.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
API_KEY = os.environ.get("API_KEY", "default-key-123")

# Generation and serving parameters.
DEVICE = os.environ.get("DEVICE", "cpu")
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "450"))
PORT = int(os.environ.get("PORT", "7860"))
| |
|
| | |
# Process-wide logging: INFO and above, each record stamped with time,
# logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
| |
|
| | |
# Bearer-token extractor; verify_api_key receives its output via Depends().
security = HTTPBearer()

# In-memory registry of async generation jobs, keyed by job id.
# Nothing in this module removes entries once they are added.
jobs: Dict[str, dict] = dict()
| |
|
class AIGenerator:
    """Holder for the tokenizer/model pair, loaded on demand.

    Attributes are filled in by :meth:`load_model`:

    * ``tokenizer`` / ``model`` -- ``None`` until a load succeeds.
    * ``loaded`` -- ``True`` once both are ready for use.
    * ``load_error`` -- message from the most recent failed load, or ``None``.
    """

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.loaded = False
        self.load_error = None

    def load_model(self):
        """Load the tokenizer and model for ``MODEL_ID``.

        Returns:
            bool: ``True`` if the model is ready (including when it was
            already loaded). ``False`` if ``HF_TOKEN`` is empty or loading
            raised; the reason is stored in ``load_error``.
        """
        if self.loaded:
            return True

        logger.info(f"π Loading model: {MODEL_ID}")

        if not HF_TOKEN:
            logger.error("β HF_TOKEN is not set!")
            self.load_error = "HF_TOKEN environment variable is not set"
            return False

        try:
            logger.info("π₯ Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
            )

            # generate() is later given pad_token_id, so make sure one exists.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            logger.info("✅ Tokenizer loaded")

            logger.info("π₯ Loading model...")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                token=HF_TOKEN,
                device_map=None,
            )

            model = model.to(DEVICE)
            model.eval()
        except Exception as e:
            self.load_error = str(e)
            # logger.exception keeps the traceback alongside the message.
            logger.exception("β Model loading failed: %s", e)
            return False

        # Publish only after both pieces loaded, so a failed attempt never
        # leaves a tokenizer without a model on the instance.
        self.tokenizer = tokenizer
        self.model = model
        # Clear any message left over from an earlier failed attempt.
        self.load_error = None
        self.loaded = True
        logger.info("π Model loaded successfully!")
        return True


# Shared instance used by the lifespan hook and the request handlers.
generator = AIGenerator()
| |
|
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Authenticate a request by its bearer token.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        ``True`` when the presented token equals ``API_KEY``.

    Raises:
        HTTPException: 401 when the token does not match.
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    presented = credentials.credentials.encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload the model before the application starts serving.

    A failed preload does not abort startup; generate_text attempts the load
    again when a request needs the model.
    """
    logger.info("π Starting AI API Server...")
    logger.info(f"π Config: Model={MODEL_ID}, Device={DEVICE}, MaxTokens={MAX_TOKENS}")

    # load_model() reports failure through its return value (it catches its
    # own exceptions), so the result must be checked; the except clause is
    # kept only as a safety net.
    try:
        if not generator.load_model():
            logger.warning(
                "Model preloading failed, will load on first request: %s",
                generator.load_error,
            )
    except Exception as e:
        logger.warning(f"Model preloading failed, will load on first request: {e}")

    yield


app = FastAPI(lifespan=lifespan)
| |
|
def generate_text(prompt: str, max_tokens: int = None) -> str:
    """Generate a sampled continuation of ``prompt``.

    Args:
        prompt: Input text. It is tokenized with truncation to 512 tokens.
        max_tokens: Upper bound on newly generated tokens. ``None`` or ``0``
            falls back to ``MAX_TOKENS``.

    Returns:
        The newly generated text only (the prompt is not echoed back),
        with surrounding whitespace stripped.

    Raises:
        ValueError: If ``max_tokens`` is not a non-negative integer.
        RuntimeError: If the model is not loaded and loading fails.
        Exception: Any error from tokenization or generation is logged and
            re-raised unchanged.
    """
    try:
        if max_tokens is not None and (
            isinstance(max_tokens, bool)
            or not isinstance(max_tokens, int)
            or max_tokens < 0
        ):
            raise ValueError(
                f"max_tokens must be a non-negative integer, got {max_tokens!r}"
            )

        if not generator.loaded and not generator.load_model():
            raise RuntimeError(f"Model failed to load: {generator.load_error}")

        logger.info(f"π Generating text for prompt: '{prompt[:50]}...'")

        inputs = generator.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = generator.model.generate(
                **inputs,
                max_new_tokens=max_tokens or MAX_TOKENS,
                do_sample=True,
                top_p=0.9,
                temperature=0.8,
                pad_token_id=generator.tokenizer.pad_token_id,
                repetition_penalty=1.1,
            )

        # The output sequence is the prompt tokens followed by the new tokens.
        # Slicing by token count drops the prompt reliably; matching the
        # prompt as a substring of the decoded text fails when the prompt was
        # truncated or altered by tokenization, and would also delete any
        # later repetition of the prompt inside the generated text.
        new_tokens = outputs[0][prompt_length:]
        generated_text = generator.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        logger.info("✅ Generated %d characters", len(generated_text))
        return generated_text

    except Exception as e:
        logger.error(f"β Generation failed: {str(e)}")
        raise
| |
|
@app.post("/api/generate-sync")
async def generate_sync(
    request: Request,
    auth: bool = Depends(verify_api_key)
):
    """
    Synchronous text generation
    Body: {"prompt": "your text", "max_tokens": 100}

    Responds 400 when the body is not a JSON object, the prompt is missing or
    not a string, or max_tokens is not a non-negative integer. Responds 500
    when generation itself fails.
    """
    # Validate outside the generation try-block so client errors keep their
    # 400 status instead of being re-wrapped as 500.
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")

    max_tokens = data.get("max_tokens")
    if max_tokens is not None and (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 0
    ):
        raise HTTPException(status_code=400, detail="max_tokens must be a non-negative integer")

    logger.info(f"π₯ Sync request: '{prompt[:50]}...'")

    try:
        # generate_text blocks for the whole generation; run it on the default
        # executor so the event loop can keep serving other requests.
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None, generate_text, prompt, max_tokens
        )
    except Exception as e:
        logger.error(f"β Sync generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

    return JSONResponse({
        "status": "success",
        "result": generated_text,
        "prompt": prompt,
        "text_length": len(generated_text),
        "model": MODEL_ID
    })
| |
|
@app.post("/api/generate")
async def generate_async(
    request: Request,
    background_tasks: BackgroundTasks,
    auth: bool = Depends(verify_api_key)
):
    """
    Asynchronous text generation (for longer tasks)
    Body: {"prompt": "your text", "max_tokens": 100, "callback_url": "optional"}

    Registers a job in ``jobs``, schedules it as a background task and returns
    its id immediately. Responds 400 when the body is not a JSON object, the
    prompt is missing or not a string, max_tokens is not a non-negative
    integer, or callback_url is not an http(s) URL.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    prompt = data.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")

    max_tokens = data.get("max_tokens")
    if max_tokens is not None and (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 0
    ):
        raise HTTPException(status_code=400, detail="max_tokens must be a non-negative integer")

    # An empty/absent callback_url means "no callback", as before.
    callback_url = data.get("callback_url") or None
    if callback_url is not None and not (
        isinstance(callback_url, str)
        and callback_url.startswith(("http://", "https://"))
    ):
        raise HTTPException(status_code=400, detail="callback_url must be an http(s) URL")
    # NOTE(review): the server will POST to whatever host callback_url names;
    # consider an allow-list if callers are not fully trusted.

    job_id = str(uuid.uuid4())
    logger.info(f"π₯ Async request {job_id}")

    jobs[job_id] = {
        "status": "processing",
        "prompt": prompt
    }

    background_tasks.add_task(
        process_job_async,
        job_id,
        prompt,
        max_tokens,
        callback_url
    )

    return JSONResponse({
        "job_id": job_id,
        "status": "queued",
        "message": "Generation started",
        "model": MODEL_ID
    })
| |
|
async def process_job_async(job_id: str, prompt: str, max_tokens: int = None, callback_url: str = None):
    """Run one generation job in the background and record its outcome.

    On success ``jobs[job_id]`` is replaced with a "complete" record and, if
    ``callback_url`` is set, the result is POSTed to it. On failure
    ``jobs[job_id]`` is replaced with a "failed" record carrying the error
    message. Callback errors are logged and never change the job record.

    Args:
        job_id: Key of the job in ``jobs``.
        prompt: Text to generate from.
        max_tokens: Forwarded to ``generate_text``.
        callback_url: Optional URL notified after a successful generation.
    """
    try:
        logger.info(f"π Processing async job {job_id}")

        # generate_text blocks for the whole generation; run it on the default
        # executor so the event loop can still answer status and health
        # requests while the job is running.
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None, generate_text, prompt, max_tokens
        )

        jobs[job_id] = {
            "status": "complete",
            "result": generated_text,
            "prompt": prompt,
            "text_length": len(generated_text)
        }

        logger.info("✅ Completed async job %s", job_id)

    except Exception as e:
        error_msg = str(e)
        logger.error(f"β Async job {job_id} failed: {error_msg}")
        jobs[job_id] = {
            "status": "failed",
            "error": error_msg,
            "prompt": prompt
        }
        return

    # Best-effort notification: a failing callback must not mark the job failed.
    if callback_url:
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                await client.post(
                    callback_url,
                    json={
                        "job_id": job_id,
                        "status": "complete",
                        "result": generated_text,
                        "prompt": prompt
                    }
                )
        except Exception as e:
            logger.error(f"β Callback failed: {e}")
| |
|
@app.get("/api/status/{job_id}")
async def get_status(job_id: str, auth: bool = Depends(verify_api_key)):
    """Return the stored record for ``job_id``; 404 when the id is unknown."""
    try:
        job_record = jobs[job_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Job not found") from None
    return JSONResponse(job_record)
| |
|
@app.get("/health")
async def health_check():
    """Report service status together with the model and runtime settings."""
    payload = {
        "status": "healthy",
        "model_loaded": generator.loaded,
        "model": MODEL_ID,
        "device": DEVICE,
        "max_tokens": MAX_TOKENS,
    }
    return JSONResponse(payload)
| |
|
@app.get("/model-info")
async def model_info():
    """Describe the configured model, its load state and the last load error."""
    has_token = bool(HF_TOKEN)
    details = {
        "model": MODEL_ID,
        "loaded": generator.loaded,
        "error": generator.load_error,
        "device": DEVICE,
        "requires_auth": True,
        "token_available": has_token,
    }
    return JSONResponse(details)
| |
|
@app.get("/")
async def root():
    """Return a short API overview: version, model, state and endpoint list."""
    if generator.loaded:
        service_state = "operational"
    else:
        service_state = "model_loading"

    endpoint_index = {
        "generate_sync": "POST /api/generate-sync",
        "generate_async": "POST /api/generate",
        "check_status": "GET /api/status/{job_id}",
        "health": "GET /health",
        "model_info": "GET /model-info",
    }

    overview = {
        "message": "π€ AI Text Generation API",
        "version": "1.0",
        "model": MODEL_ID,
        "status": service_state,
        "endpoints": endpoint_index,
        "usage": 'curl -X POST /api/generate-sync -H "Authorization: Bearer YOUR_KEY" -d \'{"prompt":"Hello"}\'',
    }
    return JSONResponse(overview)
| |
|
# Script entry point: serve the app on all interfaces at the configured port.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="info")