kripeshAlt committed on
Commit
af32d57
·
verified ·
1 Parent(s): a8a012c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -19
app.py CHANGED
@@ -1,28 +1,92 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from fastapi import FastAPI
4
  from pydantic import BaseModel
5
- import uvicorn
 
 
 
 
6
 
7
- # Load model and tokenizer
8
- tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-vl-1.3b-chat")
9
- model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-vl-1.3b-chat")
10
 
11
  # Initialize FastAPI app
12
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Define input schema
15
- class RequestBody(BaseModel):
16
  prompt: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Define the model inference function
19
- @app.post("/predict")
20
- async def predict(request: RequestBody):
21
- inputs = tokenizer(request.prompt, return_tensors="pt")
22
- outputs = model.generate(**inputs, max_new_tokens=100)
23
- result = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- return {"response": result}
25
 
26
- # For testing locally (not needed for Hugging Face)
27
  if __name__ == "__main__":
28
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from typing import List
import os
import uuid

# Configure module-level logging once at startup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="DeepSeek CPU Hosting API")

# Model configuration. MODEL_NAME can be overridden via the environment so a
# different checkpoint can be selected without editing the code; the default
# is unchanged for backward compatibility.
MODEL_NAME = os.environ.get("MODEL_NAME", "deepseek-ai/deepseek-llm-7b")
DEVICE = "cpu"  # Force CPU usage -- no GPU assumed on the hosting box.

# Load model and tokenizer eagerly at import time so the first request does
# not pay the (potentially multi-minute, for a 7B checkpoint) load cost.
try:
    logger.info("Loading model and tokenizer for %s ...", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model.to(DEVICE)
    logger.info("Model loaded successfully!")
except Exception as e:
    # Fail fast: the service is useless without the model, so log and re-raise
    # rather than limping along with undefined globals.
    logger.error("Failed to load model: %s", e)
    raise

# API key storage (in production, use a proper database)
API_KEYS = {}
33
 
34
# Request models
class GenerationRequest(BaseModel):
    """Payload for the generation endpoint: a prompt plus sampling controls."""

    prompt: str
    # Token budget for the generation call.
    max_length: int = 100
    # Sampling controls forwarded to model.generate().
    temperature: float = 0.7
    top_p: float = 0.9
40
+
41
class APIKeyRequest(BaseModel):
    """Payload for /create_api_key: a human-readable label for the new key."""

    name: str
43
+
44
# Generation endpoint
@app.post("/generate/{api_key}")
async def generate_text(api_key: str, request: GenerationRequest):
    """Generate a completion for ``request.prompt``, gated by a known API key.

    Returns ``{"generated_text": ...}`` on success; raises 401 for an unknown
    key and 500 if generation itself fails.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")

    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(DEVICE)
        # Use max_new_tokens instead of max_length: max_length counts the
        # prompt tokens too, so a prompt near or over the budget would leave
        # no room to generate (or error out). max_new_tokens bounds only the
        # completion, matching what callers of this API actually mean.
        outputs = model.generate(
            **inputs,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=request.top_p,
            do_sample=True,
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Track per-key usage: the key records carry a "usage_count" field
        # that was initialised but never incremented.
        API_KEYS[api_key]["usage_count"] += 1

        logger.info("Generated text for API key: %s", api_key)
        return {"generated_text": generated_text}
    except Exception as e:
        logger.error("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
66
+
67
# API key management endpoints
@app.post("/create_api_key")
async def create_api_key(request: APIKeyRequest):
    """Mint a fresh UUID4 key and register it under the caller's name."""
    key = str(uuid.uuid4())
    record = {"name": request.name, "usage_count": 0}
    API_KEYS[key] = record
    logger.info(f"Created new API key for {request.name}")
    return {"api_key": key, "name": request.name}
77
+
78
@app.get("/list_api_keys")
async def list_api_keys():
    """Return every registered key together with its metadata.

    NOTE(review): this endpoint is unauthenticated and exposes the raw key
    strings -- confirm that is intended before deploying publicly.
    """
    return {"api_keys": API_KEYS}
81
 
82
@app.delete("/revoke_api_key/{api_key}")
async def revoke_api_key(api_key: str):
    """Delete the given key, or raise 404 when it is unknown."""
    # Guard clause first: unknown keys are rejected before any mutation.
    if api_key not in API_KEYS:
        raise HTTPException(status_code=404, detail="API key not found")
    del API_KEYS[api_key]
    logger.info(f"Revoked API key: {api_key}")
    return {"status": "success"}
89
 
 
90
# Local development entry point; the hosting platform runs its own server.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)