Hugging Face Spaces build log excerpt — the Space reported build errors for the application source below.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import logging | |
| from typing import List | |
| import os | |
| import uuid | |
| import torch | |
# Configure logging: module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Phi-2 CPU Hosting API")

# Model configuration: Hugging Face Hub identifier for the model to serve.
MODEL_NAME = "microsoft/phi-2"

# Force CPU usage and disable CUDA.
# Hiding all CUDA devices must happen before torch first touches CUDA;
# torch initializes CUDA lazily, so setting the env var here is effective.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
torch.set_default_device("cpu")
# Load model and tokenizer at import time. This blocks startup (and may
# download several GB from the Hugging Face Hub on first run); any failure
# is logged and re-raised so the process does not start with no model.
try:
    logger.info("Loading Phi-2 model and tokenizer...")
    # Explicitly set to CPU and float32 (Phi-2 ships fp16 weights; fp16 on
    # CPU is slow/unsupported for many ops, so float32 is the safe choice).
    torch_dtype = torch.float32
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        # NOTE(review): trust_remote_code executes code from the model repo —
        # acceptable only because the repo is pinned to microsoft/phi-2.
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        # Keep every layer on CPU; matches the CUDA-disabling setup above.
        device_map="cpu"
    )
    logger.info("Phi-2 model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {str(e)}")
    # Re-raise so startup fails loudly instead of serving without a model.
    raise
# API key storage (in production, use a proper database).
# In-memory only: keys are lost on restart and not shared across workers.
# Maps key string -> {"name": str, "usage_count": int}.
API_KEYS = {}
# Request models
class GenerationRequest(BaseModel):
    """Body of a text-generation request."""

    prompt: str            # input text fed to the model
    max_length: int = 200  # total token budget (prompt + completion)
    temperature: float = 0.7  # sampling temperature; only used when do_sample
    top_p: float = 0.9        # nucleus-sampling cutoff; only used when do_sample
    do_sample: bool = True    # False -> greedy decoding
class APIKeyRequest(BaseModel):
    """Body of an API-key creation request."""

    name: str  # human-readable label for the key owner
# Generation endpoint.
# BUG FIX: the function was never registered on the app — the comment calls
# it an endpoint but there was no route decorator, so uvicorn served an app
# with no paths. Register it explicitly.
@app.post("/generate")
async def generate_text(api_key: str, request: GenerationRequest):
    """Generate a completion for ``request.prompt`` with Phi-2.

    Args:
        api_key: A key previously issued by the key-management endpoint
            (passed as a query parameter).
        request: Prompt and decoding parameters.

    Returns:
        dict with ``generated_text`` (prompt + completion, special tokens
        stripped) and the caller's updated ``usage_count``.

    Raises:
        HTTPException: 401 for an unknown key, 500 on generation failure.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        # Inference only — no_grad avoids building an autograd graph.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=request.do_sample,
                # Phi-2 has no dedicated pad token; reuse EOS to silence warnings.
                pad_token_id=tokenizer.eos_token_id,
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Update usage count for billing/metering.
        API_KEYS[api_key]["usage_count"] += 1
        # Log only a prefix — the full key is a secret and must not land in logs.
        logger.info("Generated text for API key: %s...", api_key[:8])
        return {
            "generated_text": generated_text,
            "usage_count": API_KEYS[api_key]["usage_count"],
        }
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# API key management endpoints.
# BUG FIX: registered the handler on the app — it previously had no route
# decorator and was unreachable over HTTP.
@app.post("/api-keys")
async def create_api_key(request: APIKeyRequest):
    """Issue a new API key for the named consumer.

    Returns:
        dict with the freshly generated ``api_key`` and the owner ``name``.
    """
    # uuid4 is random but not a vetted secret generator; for production,
    # prefer secrets.token_urlsafe(). Kept as-is to preserve key format.
    new_key = str(uuid.uuid4())
    API_KEYS[new_key] = {
        "name": request.name,
        "usage_count": 0
    }
    logger.info(f"Created new API key for {request.name}")
    return {"api_key": new_key, "name": request.name}
# BUG FIX: registered the handler on the app — it previously had no route
# decorator and was unreachable over HTTP.
@app.get("/api-keys")
async def list_api_keys():
    """Return all issued API keys with their metadata.

    SECURITY NOTE(review): this exposes the raw secret keys to any caller —
    in production, require admin auth and return only names/usage counts.
    Behavior kept unchanged here.
    """
    return {"api_keys": API_KEYS}
# BUG FIX: registered the handler on the app — it previously had no route
# decorator and was unreachable over HTTP.
@app.delete("/api-keys/{api_key}")
async def revoke_api_key(api_key: str):
    """Delete an issued API key.

    Raises:
        HTTPException: 404 when the key does not exist.
    """
    if api_key in API_KEYS:
        del API_KEYS[api_key]
        logger.info(f"Revoked API key: {api_key}")
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="API key not found")
# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)