# app.py — FastAPI service hosting microsoft/phi-2 for CPU-only text generation.
import logging
import os
import secrets
import uuid
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title="Phi-2 CPU Hosting API")
# Model configuration
MODEL_NAME = "microsoft/phi-2"
# Force CPU usage and disable CUDA.
# NOTE: the env var must be set before any CUDA context is created, which is
# why this sits at module top level ahead of the model load below.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")
# Load model and tokenizer at import time, so the (slow, multi-GB) download
# and initialisation happen once at startup rather than on the first request.
try:
    logger.info("Loading Phi-2 model and tokenizer...")
    # Explicitly set to CPU and float32 (no half precision on CPU)
    torch_dtype = torch.float32
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True  # allows the repo's custom code to run — acceptable for the official microsoft/phi-2 repo
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map="cpu"  # pin all weights to CPU, matching the env setup above
    )
    logger.info("Phi-2 model loaded successfully!")
except Exception as e:
    # Without a model every endpoint is useless, so log and fail fast:
    # re-raising aborts app startup instead of serving guaranteed 500s.
    logger.error(f"Failed to load model: {str(e)}")
    raise
# API key storage (in production, use a proper database)
# Maps key string -> {"name": str, "usage_count": int}; lost on restart.
API_KEYS = {}
# Request models
class GenerationRequest(BaseModel):
    """Body for the /generate endpoint: a prompt plus decoding parameters.

    All decoding fields are passed straight through to ``model.generate``.
    """
    prompt: str  # input text fed to the model
    max_length: int = 200  # total token budget; HF generate counts prompt + completion
    temperature: float = 0.7  # softmax temperature; only used when do_sample=True
    top_p: float = 0.9  # nucleus-sampling cutoff; only used when do_sample=True
    do_sample: bool = True  # False = greedy decoding
class APIKeyRequest(BaseModel):
    """Body for /create_api_key: a human-readable label for the new key."""
    name: str  # owner/label stored alongside the key; not used for auth
# Generation endpoint
@app.post("/generate/{api_key}")
async def generate_text(api_key: str, request: GenerationRequest):
    """Run Phi-2 on ``request.prompt`` and return the decoded completion.

    Args:
        api_key: key previously issued by /create_api_key.
            NOTE(review): keys in the URL path end up in access logs and
            browser history — consider moving to an Authorization header.
        request: prompt plus decoding parameters (see GenerationRequest).

    Returns:
        dict with ``generated_text`` (includes the prompt, since the full
        output sequence is decoded) and the caller's updated ``usage_count``.

    Raises:
        HTTPException: 401 for an unknown key, 500 if generation fails.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        # Inference only — no_grad skips autograd bookkeeping and saves memory.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=request.do_sample,
                pad_token_id=tokenizer.eos_token_id,  # phi-2 has no pad token; reuse EOS to silence warnings
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Update usage count
        API_KEYS[api_key]["usage_count"] += 1
        # Log only a key prefix: full API keys are secrets and must not be
        # written to log files. Lazy %-args avoid formatting when disabled.
        logger.info("Generated text for API key: %s...", api_key[:8])
        return {
            "generated_text": generated_text,
            "usage_count": API_KEYS[api_key]["usage_count"],
        }
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Generation error")
        raise HTTPException(status_code=500, detail=str(e))
# API key management endpoints
@app.post("/create_api_key")
async def create_api_key(request: APIKeyRequest):
    """Issue a new API key and register it in the in-memory store.

    Args:
        request: carries the human-readable label for the key.

    Returns:
        dict with the freshly minted ``api_key`` and the echoed ``name``.
    """
    # secrets.token_urlsafe is the idiomatic stdlib choice for bearer
    # credentials: cryptographically strong, URL-safe, ~43 chars of entropy
    # (uuid4 is also urandom-backed but shorter and not meant as a secret).
    new_key = secrets.token_urlsafe(32)
    API_KEYS[new_key] = {
        "name": request.name,
        "usage_count": 0,
    }
    # The label is not sensitive, so logging it in full is fine.
    logger.info("Created new API key for %s", request.name)
    return {"api_key": new_key, "name": request.name}
@app.get("/list_api_keys")
async def list_api_keys():
    """Return every registered key with its name and usage count.

    NOTE(review): this endpoint is unauthenticated and returns the raw key
    strings — anyone who can reach it can obtain valid credentials. Restrict
    access or mask the keys before any production use.
    """
    return {"api_keys": API_KEYS}
@app.delete("/revoke_api_key/{api_key}")
async def revoke_api_key(api_key: str):
    """Remove an API key from the in-memory store.

    Args:
        api_key: the key to revoke (path parameter).

    Returns:
        {"status": "success"} once the key is deleted.

    Raises:
        HTTPException: 404 if the key is not registered.
    """
    # Guard clause keeps the happy path unindented.
    if api_key not in API_KEYS:
        raise HTTPException(status_code=404, detail="API key not found")
    del API_KEYS[api_key]
    # Log only a prefix — full API keys are secrets and must stay out of logs.
    logger.info("Revoked API key: %s...", api_key[:8])
    return {"status": "success"}
if __name__ == "__main__":
    # Local/dev entry point. The deferred import keeps uvicorn optional for
    # deployments that mount `app` via an external ASGI server instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)