"""HTTP API that serves text generation from the microsoft/phi-2 model on CPU.

The module loads the tokenizer and model once at import time, keeps API keys
in an in-process dictionary, and exposes endpoints to create, list and revoke
keys as well as to generate text with a valid key.
"""

import logging
import os
import uuid
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Phi-2 CPU Hosting API")

# Model configuration
MODEL_NAME = "microsoft/phi-2"

# Force CPU usage and disable CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")

# Load model and tokenizer
try:
    logger.info("Loading Phi-2 model and tokenizer...")

    # Explicitly set to CPU and float32
    torch_dtype = torch.float32

    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository; confirm this is still required for this model.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map="cpu",
    )
    logger.info("Phi-2 model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    raise

# API key storage (in production, use a proper database).
# Maps key string -> {"name": str, "usage_count": int}.
API_KEYS = {}


# Request models
class GenerationRequest(BaseModel):
    """Body of a text generation request."""

    prompt: str
    # Upper bound passed to ``model.generate(max_length=...)``.
    max_length: int = 200
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True


class APIKeyRequest(BaseModel):
    """Body of an API key creation request."""

    name: str


class _InvalidGenerationRequest(ValueError):
    """Raised when request parameters cannot produce a valid generation."""


def _generate_sync(request: GenerationRequest) -> str:
    """Tokenize the prompt, run the model and decode the first output sequence.

    This is blocking, CPU-bound work and is meant to run in a worker thread.

    Raises:
        _InvalidGenerationRequest: if the sampling parameters are out of range
            or the prompt already uses up ``max_length`` tokens.
    """
    generation_kwargs = {
        "max_length": request.max_length,
        "do_sample": request.do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    # Sampling parameters are only meaningful (and only validated) when
    # sampling is enabled; greedy decoding does not use them.
    if request.do_sample:
        if request.temperature <= 0:
            raise _InvalidGenerationRequest(
                "temperature must be > 0 when do_sample is true"
            )
        if not 0.0 <= request.top_p <= 1.0:
            raise _InvalidGenerationRequest(
                "top_p must be between 0 and 1 when do_sample is true"
            )
        generation_kwargs["temperature"] = request.temperature
        generation_kwargs["top_p"] = request.top_p

    inputs = tokenizer(request.prompt, return_tensors="pt")
    prompt_tokens = inputs["input_ids"].shape[-1]
    if prompt_tokens >= request.max_length:
        raise _InvalidGenerationRequest(
            f"Prompt is {prompt_tokens} tokens long but max_length is "
            f"{request.max_length}; increase max_length or shorten the prompt"
        )

    # no_grad is entered here, in the thread that actually runs the model.
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Generation endpoint
@app.post("/generate/{api_key}")
async def generate_text(api_key: str, request: GenerationRequest):
    """Generate text for ``request.prompt`` on behalf of ``api_key``.

    Returns:
        A dict with ``generated_text`` and the key's updated ``usage_count``.

    Raises:
        HTTPException: 401 if the key is unknown, 400 if the request
            parameters are invalid, 500 if generation fails.
    """
    # Hold on to the record itself so the usage update below still works if
    # the key is revoked while generation is in progress.
    key_record = API_KEYS.get(api_key)
    if key_record is None:
        raise HTTPException(status_code=401, detail="Invalid API key")

    try:
        # Generation is CPU-bound; run it off the event loop so other
        # requests are not stalled for its whole duration.
        generated_text = await run_in_threadpool(_generate_sync, request)
    except _InvalidGenerationRequest as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Update usage count
    key_record["usage_count"] += 1

    # Log the key's name rather than the key: the key is a credential.
    logger.info("Generated text for API key owner: %s", key_record["name"])

    return {
        "generated_text": generated_text,
        "usage_count": key_record["usage_count"],
    }


# API key management endpoints
# NOTE(review): none of the management endpoints below require
# authentication; confirm they are protected by other means before exposing
# this service.
@app.post("/create_api_key")
async def create_api_key(request: APIKeyRequest):
    """Create a new API key for ``request.name`` with a usage count of zero."""
    new_key = str(uuid.uuid4())
    API_KEYS[new_key] = {
        "name": request.name,
        "usage_count": 0,
    }
    logger.info("Created new API key for %s", request.name)
    return {"api_key": new_key, "name": request.name}


@app.get("/list_api_keys")
async def list_api_keys():
    """Return every stored API key together with its name and usage count."""
    return {"api_keys": API_KEYS}


@app.delete("/revoke_api_key/{api_key}")
async def revoke_api_key(api_key: str):
    """Remove ``api_key`` from the store.

    Raises:
        HTTPException: 404 if the key does not exist.
    """
    if API_KEYS.pop(api_key, None) is None:
        raise HTTPException(status_code=404, detail="API key not found")
    # The key is no longer valid at this point, so logging it is harmless.
    logger.info(f"Revoked API key: {api_key}")
    return {"status": "success"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)