File size: 3,412 Bytes
af32d57
155b5cb
af32d57
 
 
 
 
57995b2
b7105ef
af32d57
 
 
b7105ef
155b5cb
57995b2
af32d57
 
57995b2
b4623d2
 
 
 
af32d57
 
 
57995b2
 
b4623d2
 
57995b2
 
 
 
 
 
 
 
 
 
 
 
 
 
af32d57
 
 
 
 
 
155b5cb
af32d57
 
155b5cb
57995b2
af32d57
 
57995b2
af32d57
 
 
 
 
 
 
 
 
 
 
b4623d2
57995b2
 
 
 
 
 
 
 
 
 
 
af32d57
 
57995b2
 
af32d57
57995b2
 
 
 
 
af32d57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155b5cb
af32d57
 
 
 
 
 
 
b7105ef
155b5cb
af32d57
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from typing import List
import os
import uuid
import torch

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Phi-2 CPU Hosting API")

# Model configuration
# Hugging Face Hub model id loaded at startup.
MODEL_NAME = "microsoft/phi-2"

# Force CPU usage and disable CUDA
# Hiding every CUDA device BEFORE the model is loaded guarantees CPU-only
# execution even on machines that do have GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")

# Load model and tokenizer
# NOTE: this runs at import time -- the server cannot accept requests until
# the Phi-2 weights have been downloaded/loaded into RAM.
try:
    logger.info("Loading Phi-2 model and tokenizer...")
    
    # Explicitly set to CPU and float32
    torch_dtype = torch.float32
    
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        # NOTE(review): trust_remote_code executes model code fetched from the
        # hub -- consider pinning a revision for supply-chain safety.
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        device_map="cpu"
    )
    
    logger.info("Phi-2 model loaded successfully!")
except Exception as e:
    # A model that fails to load is fatal: log, then re-raise to abort startup.
    logger.error(f"Failed to load model: {str(e)}")
    raise

# API key storage (in production, use a proper database)
# Maps key string -> {"name": str, "usage_count": int}. In-memory only:
# all keys are lost on process restart.
API_KEYS = {}

# Request models
class GenerationRequest(BaseModel):
    """Body for the /generate endpoint: prompt plus sampling parameters
    passed straight through to ``model.generate()``."""

    prompt: str              # text fed verbatim to the tokenizer
    max_length: int = 200    # forwarded to generate() as max_length
    temperature: float = 0.7 # sampling temperature (used when do_sample=True)
    top_p: float = 0.9       # nucleus-sampling cutoff
    do_sample: bool = True   # False makes generation greedy/deterministic

class APIKeyRequest(BaseModel):
    """Body for /create_api_key: a human-readable label for the new key."""

    name: str  # owner/label stored alongside the generated key

# Generation endpoint
@app.post("/generate/{api_key}")
async def generate_text(api_key: str, request: GenerationRequest):
    """Run Phi-2 text generation for an authenticated caller.

    Args:
        api_key: caller's key; must exist in API_KEYS or a 401 is raised.
        request: prompt and sampling parameters (see GenerationRequest).

    Returns:
        dict with the decoded text (prompt included, since the full output
        sequence is decoded) and the caller's updated usage count.

    Raises:
        HTTPException: 401 for an unknown key, 500 on any generation failure.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        
        # Inference only -- disable autograd to save memory/time on CPU.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=request.do_sample,
                pad_token_id=tokenizer.eos_token_id
            )
        
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Update usage count
        API_KEYS[api_key]["usage_count"] += 1
        # Log only a key prefix: full API keys are secrets and must not
        # appear in log files.
        logger.info("Generated text for API key: %s...", api_key[:8])
        
        return {
            "generated_text": generated_text,
            "usage_count": API_KEYS[api_key]["usage_count"]
        }
    except Exception as e:
        # logger.exception records the traceback for debugging; the client
        # still receives the error message in the 500 detail.
        logger.exception("Generation error")
        raise HTTPException(status_code=500, detail=str(e))

# API key management endpoints
@app.post("/create_api_key")
async def create_api_key(request: APIKeyRequest):
    """Mint a random API key and register it under the caller-supplied name.

    Returns the new key and its label; the key starts with a zeroed usage
    counter in the in-memory API_KEYS store.
    """
    fresh_key = str(uuid.uuid4())
    API_KEYS[fresh_key] = {"name": request.name, "usage_count": 0}
    logger.info(f"Created new API key for {request.name}")
    return {"api_key": fresh_key, "name": request.name}

@app.get("/list_api_keys")
async def list_api_keys():
    """Return every registered API key with its name and usage count.

    NOTE(review): this route is unauthenticated and returns the raw secret
    keys themselves -- anyone who can reach the server can harvest every
    key. Consider returning only names/usage counts or protecting this
    endpoint; confirm the exposure is intended.
    """
    return {"api_keys": API_KEYS}

@app.delete("/revoke_api_key/{api_key}")
async def revoke_api_key(api_key: str):
    """Delete an API key from the in-memory store.

    Args:
        api_key: the key to revoke.

    Returns:
        {"status": "success"} when the key existed and was removed.

    Raises:
        HTTPException: 404 when the key is not registered.
    """
    if api_key in API_KEYS:
        del API_KEYS[api_key]
        # Log only a key prefix: full API keys are secrets and must not
        # appear in log files.
        logger.info("Revoked API key: %s...", api_key[:8])
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="API key not found")

if __name__ == "__main__":
    # Imported lazily so the module can be served by an external ASGI runner
    # (e.g. `uvicorn module:app`) without requiring uvicorn at import time.
    import uvicorn
    # 0.0.0.0 binds all interfaces -- appropriate in a container, but it
    # exposes the service to the whole network when run directly.
    uvicorn.run(app, host="0.0.0.0", port=8000)