Prithvik-1 committed on
Commit
91c60e6
·
verified ·
1 Parent(s): ceb778d

Upload models/msp/api/api_server_secure.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/msp/api/api_server_secure.py +194 -0
models/msp/api/api_server_secure.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Secure FastAPI server for serving Mistral 7B fine-tuned models
4
+ Includes API key authentication like commercial services
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import secrets
10
+ from typing import Optional
11
+ from fastapi import FastAPI, HTTPException, Header, Depends
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
14
+ from pydantic import BaseModel
15
+ import uvicorn
16
+ from pathlib import Path
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+ from inference.inference_mistral7b import load_local_model, generate_with_local_model, get_device_info
19
+ import torch
20
+
21
# Configuration - Resolve model path relative to msp root
# (two directory levels above this file: <root>/api/<this file>).
_MODEL_BASE = Path(__file__).parent.parent / "mistral7b-finetuned-ahb2apb"
DEFAULT_MODEL_PATH = str(_MODEL_BASE)

# API Key authentication
# Set of accepted API keys; rebound by load_api_keys().
API_KEYS = set()
# Relative path, so it is resolved against the process's current working
# directory at the time load_api_keys() runs.
API_KEY_FILE = "api_keys.txt"
28
+
29
+ # Load or generate API keys
30
def load_api_keys():
    """Load API keys from ``API_KEY_FILE``, creating it with a new key if absent.

    The file holds one key per line; surrounding whitespace and blank lines
    are ignored. The resulting set replaces the module-level ``API_KEYS``.

    If the file does not exist, a random URL-safe key is generated, written
    to a newly created file that only the owner can read or write, and
    printed once so the operator can copy it.
    """
    global API_KEYS

    try:
        # EAFP: open directly instead of exists()-then-open, which is racy.
        with open(API_KEY_FILE, 'r', encoding='utf-8') as f:
            API_KEYS = {line.strip() for line in f if line.strip()}
    except FileNotFoundError:
        # Generate a default API key
        default_key = secrets.token_urlsafe(32)
        # The file contains credentials: create it with mode 0o600 rather
        # than the umask-dependent default so other local users cannot read it.
        fd = os.open(API_KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            f.write(default_key + '\n')
        API_KEYS = {default_key}
        print(f"\n🔑 Generated default API key: {default_key}")
        print(f" Save this key! Store it in: {API_KEY_FILE}")

    print(f"✓ Loaded {len(API_KEYS)} API key(s)")
47
+
48
def verify_api_key(
    api_key: str = Header(None, alias="X-API-Key"),
    legacy_api_key: str = Header(None, alias="api-key"),
):
    """Verify the API key supplied in the request headers.

    Used as a FastAPI dependency. The key is read from the ``X-API-Key``
    header, which is the header named in the 401 error message. Without an
    explicit alias FastAPI derives the header name from the parameter name
    (``api_key`` -> ``api-key``), so the documented header was never read.
    The derived ``api-key`` header is still accepted as a fallback so that
    existing clients keep working.

    Args:
        api_key: Value of the ``X-API-Key`` header, or ``None`` if absent.
        legacy_api_key: Value of the ``api-key`` header, or ``None`` if absent.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 if no key was supplied, 403 if the key is not in
            ``API_KEYS``.
    """
    if api_key is None:
        api_key = legacy_api_key
    if api_key is None:
        raise HTTPException(
            status_code=401,
            detail="API key required. Add header: 'X-API-Key: your-api-key'"
        )

    # Compare in constant time per key so response timing does not leak how
    # much of a guessed key matched. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    is_known = any(
        secrets.compare_digest(supplied, known.encode("utf-8"))
        for known in API_KEYS
    )
    if not is_known:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
58
+
59
# Global model and tokenizer
# All three start as None and are assigned by the startup() handler;
# request handlers treat None as "model not loaded".
model = None
tokenizer = None
# Value returned by get_device_info(); handlers read its "device" entry.
device_info = None
63
+
64
app = FastAPI(
    title="Mistral 7B AHB2APB API (Secure)",
    description="Secure API for serving the fine-tuned Mistral 7B model for AHB2APB conversion",
    version="1.0.0"
)

# Enable CORS
# Allowed origins can be restricted without a code change by setting
# CORS_ALLOW_ORIGINS to a comma-separated list (e.g.
# "https://a.example,https://b.example"). Unset or empty keeps the
# previous allow-all behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed (cookie) cross-origin requests are only enabled for an
    # explicit origin list: combining them with a wildcard origin would let
    # any site make credentialed calls. API-key header auth does not need it.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security scheme
security = HTTPBearer(auto_error=False)
81
+
82
+ # Request/Response models
83
class GenerateRequest(BaseModel):
    """Request body accepted by the POST /api/generate endpoint."""

    # Prompt text forwarded to generate_with_local_model().
    prompt: str
    # Forwarded as ``max_length``; defaults to 512 when omitted.
    max_length: Optional[int] = 512
    # Forwarded as ``temperature``; defaults to 0.7 when omitted.
    temperature: Optional[float] = 0.7
87
+
88
class GenerateResponse(BaseModel):
    """Response body returned by the POST /api/generate endpoint."""

    # Value returned by generate_with_local_model().
    response: str
    # Model path in effect (MODEL_PATH environment variable or the default).
    model: str
    # Effective max_length used for this request.
    max_length: int
    # Effective temperature used for this request.
    temperature: float
93
+
94
class HealthResponse(BaseModel):
    """Response body returned by the GET /health endpoint."""

    # "healthy" when a model is loaded, otherwise "error".
    status: str
    # True when the global model is not None.
    model_loaded: bool
    # device_info["device"] when available, otherwise "unknown".
    device: str
    # Model path in effect (MODEL_PATH environment variable or the default).
    model_path: str
    # The health endpoint always reports "enabled".
    authentication: str
100
+
101
@app.on_event("startup")
async def startup():
    """Startup hook: read the API keys, then load the model and tokenizer.

    The model path is taken from the ``MODEL_PATH`` environment variable,
    falling back to ``DEFAULT_MODEL_PATH``. On success the module-level
    ``model``, ``tokenizer`` and ``device_info`` are populated. Any exception
    while loading is reported and the process exits with status 1.
    """
    global model, tokenizer, device_info

    load_api_keys()

    separator = "=" * 70
    path_to_load = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)

    print(f"\nLoading model from: {path_to_load}")
    print(separator)

    try:
        device_info = get_device_info()
        loaded = load_local_model(path_to_load)
        model, tokenizer = loaded
        print(f"\n✓ Model loaded successfully on {device_info['device']}!")
        print(f"✓ API server ready (authentication enabled)")
        print(separator)
    except Exception as e:
        # Without a model the server cannot do anything useful, so abort.
        print(f"\n✗ Error loading model: {e}")
        sys.exit(1)
121
+
122
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded (no auth required)."""
    is_loaded = model is not None

    if device_info:
        device_name = device_info["device"]
    else:
        device_name = "unknown"

    if is_loaded:
        state = "healthy"
    else:
        state = "error"

    return HealthResponse(
        status=state,
        model_loaded=is_loaded,
        device=device_name,
        model_path=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
        authentication="enabled",
    )
132
+
133
@app.get("/")
async def root():
    """Describe the API: name, version, active model path and endpoints."""
    endpoint_map = {
        "health": "/health",
        "generate": "/api/generate (requires API key)",
        "docs": "/docs",
    }
    active_model = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)

    info = {
        "name": "Mistral 7B AHB2APB API (Secure)",
        "version": "1.0.0",
        "status": "running",
        "authentication": "API key required",
        "model": active_model,
        "endpoints": endpoint_map,
    }
    return info
148
+
149
@app.post("/api/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, api_key: str = Depends(verify_api_key)):
    """Generate text from a prompt (requires API key).

    Args:
        request: Prompt plus optional ``max_length`` and ``temperature``.
        api_key: The key accepted by ``verify_api_key``; unused here, it is
            declared only so the dependency runs before the handler.

    Returns:
        A ``GenerateResponse`` with the generated text and the effective
        model path, ``max_length`` and ``temperature``.

    Raises:
        HTTPException: 503 if the model or tokenizer is not loaded, 500 if
            generation (or building the response) raises.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Fall back to the defaults only when the client sent null. Using ``or``
    # here would also replace an explicit 0 / 0.0 with the default, silently
    # ignoring what the caller asked for.
    max_length = 512 if request.max_length is None else request.max_length
    temperature = 0.7 if request.temperature is None else request.temperature

    try:
        # NOTE(review): this call presumably blocks while the model runs; as
        # it executes inside an ``async def`` handler it would hold the event
        # loop for the duration — confirm whether that is acceptable.
        response = generate_with_local_model(
            model=model,
            tokenizer=tokenizer,
            prompt=request.prompt,
            max_length=max_length,
            temperature=temperature
        )

        return GenerateResponse(
            response=response,
            model=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
            max_length=max_length,
            temperature=temperature
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
172
+
173
if __name__ == "__main__":
    # Script entry point: parse CLI options, export the model path for the
    # startup hook, and launch uvicorn.
    import argparse

    parser = argparse.ArgumentParser(description="Start Secure Mistral 7B API server")
    for flag, options in (
        ("--model-path", {"type": str, "default": DEFAULT_MODEL_PATH}),
        ("--host", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 8000}),
        ("--reload", {"action": "store_true"}),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # The startup hook reads the model path from the environment.
    os.environ["MODEL_PATH"] = args.model_path

    print(f"\n🔒 Starting Secure Mistral 7B AHB2APB API Server")
    print(f" Model: {args.model_path}")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}\n")

    # Change to api directory for proper module resolution
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    uvicorn.run(
        "api_server_secure:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
    )
194
+