Prithvik-1 commited on
Commit
ceb778d
·
verified ·
1 Parent(s): d3bf1f9

Upload models/msp/api/api_server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/msp/api/api_server.py +216 -0
models/msp/api/api_server.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""
FastAPI server for serving Mistral 7B fine-tuned models
"""

import os
import sys
from pathlib import Path
from typing import Optional, Dict, Any

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Make the msp package root importable so `inference` resolves as a sibling
# package of `api`; this must run before the local import below.
sys.path.insert(0, str(Path(__file__).parent.parent))
from inference.inference_mistral7b import load_local_model, generate_with_local_model, get_device_info
18
+
19
# Configuration - Resolve model path relative to msp root
# (<msp root>/mistral7b-finetuned-ahb2apb, two levels up from this file).
_MODEL_BASE = Path(__file__).parent.parent / "mistral7b-finetuned-ahb2apb"
DEFAULT_MODEL_PATH = str(_MODEL_BASE)

# Global model and tokenizer (loaded once at startup by the `load_model`
# startup hook; every endpoint reads these module-level globals).
model = None
tokenizer = None
device_info = None

app = FastAPI(
    title="Mistral 7B AHB2APB API",
    description="API for serving the fine-tuned Mistral 7B model for AHB2APB conversion",
    version="1.0.0"
)

# Enable CORS
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is an
# invalid combination per the CORS spec — Starlette will not echo a wildcard
# origin on credentialed requests, so browsers reject them. Confirm whether
# credentialed cross-origin calls are actually needed; if so, list explicit
# origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
42
+
43
# Request/Response models
class GenerateRequest(BaseModel):
    """Request payload for /api/generate and /api/generate/batch."""
    # Prompt text fed to the model; generation settings are optional and the
    # endpoints fall back to the same defaults (512 tokens / temperature 0.7).
    prompt: str
    max_length: Optional[int] = 512
    temperature: Optional[float] = 0.7
48
+
49
class GenerateResponse(BaseModel):
    """Response payload for /api/generate: generated text plus the settings used."""
    response: str       # generated text returned by the model
    model: str          # model path that served the request (MODEL_PATH or default)
    max_length: int     # effective max_length after default substitution
    temperature: float  # effective temperature after default substitution
54
+
55
class HealthResponse(BaseModel):
    """Response payload for /health."""
    status: str        # "healthy" once the model is loaded, otherwise "error"
    model_loaded: bool
    device: str        # device from get_device_info(); "unknown" before startup
    model_path: str    # path the server was configured to load from
60
+
61
@app.on_event("startup")
async def load_model():
    """Load the model when the server starts.

    Reads the model path from the MODEL_PATH environment variable (falling
    back to DEFAULT_MODEL_PATH) and populates the module-level ``model``,
    ``tokenizer`` and ``device_info`` globals that the endpoints read.
    Exits the process if loading fails — the server cannot serve any
    request without a model.
    """
    global model, tokenizer, device_info

    model_path = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)

    print(f"Loading model from: {model_path}")
    print("=" * 70)

    try:
        device_info = get_device_info()
        model, tokenizer = load_local_model(model_path)
        print(f"\n✓ Model loaded successfully on {device_info['device']}!")
        print(f"✓ Server ready to accept requests")
        print("=" * 70)
    except Exception as e:
        # Fail fast instead of serving 503s forever.
        print(f"\n✗ Error loading model: {e}")
        print("=" * 70)
        sys.exit(1)
81
+
82
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded and which device is serving it."""
    loaded = model is not None
    configured_path = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)
    return HealthResponse(
        status="healthy" if loaded else "error",
        model_loaded=loaded,
        device=device_info["device"] if device_info else "unknown",
        model_path=configured_path,
    )
91
+
92
@app.get("/")
async def root():
    """Root endpoint with API information"""
    active_model = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)
    info = {
        "name": "Mistral 7B AHB2APB API",
        "version": "1.0.0",
        "status": "running",
        "model": active_model,
    }
    info["endpoints"] = {
        "health": "/health",
        "generate": "/api/generate",
        "docs": "/docs",
    }
    return info
106
+
107
@app.post("/api/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """
    Generate text from a prompt using the fine-tuned model.

    Returns the generated text together with the effective generation
    settings. Responds 503 if the model has not finished loading and 500
    if generation itself fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Resolve optional settings once so the generation call and the echoed
    # response cannot drift apart.
    max_length = request.max_length or 512
    temperature = request.temperature or 0.7

    try:
        response = generate_with_local_model(
            model=model,
            tokenizer=tokenizer,
            prompt=request.prompt,
            max_length=max_length,
            temperature=temperature
        )

        return GenerateResponse(
            response=response,
            model=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
            max_length=max_length,
            temperature=temperature
        )
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
132
+
133
@app.post("/api/generate/batch")
async def generate_batch(requests: list[GenerateRequest]):
    """
    Generate text from multiple prompts (batch processing).

    Prompts are processed sequentially, each with its own optional
    settings (defaults: max_length 512, temperature 0.7). Responds 503 if
    the model is not loaded; any per-item failure aborts the whole batch
    with a 500.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        responses = []
        for req in requests:
            response = generate_with_local_model(
                model=model,
                tokenizer=tokenizer,
                prompt=req.prompt,
                max_length=req.max_length or 512,
                temperature=req.temperature or 0.7
            )
            responses.append({
                "response": response,
                "prompt": req.prompt
            })

        return {"results": responses}
    except Exception as e:
        # Chain the original exception so server logs keep the root cause.
        raise HTTPException(status_code=500, detail=f"Batch generation error: {str(e)}") from e
159
+
160
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Start Mistral 7B API server")
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help=f"Path to fine-tuned model (default: {DEFAULT_MODEL_PATH})"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)"
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload (for development)"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes (default: 1)"
    )

    args = parser.parse_args()

    # Set model path as environment variable for the startup event;
    # worker processes inherit the environment, so each one loads the
    # same model path.
    os.environ["MODEL_PATH"] = args.model_path

    print(f"\n🚀 Starting Mistral 7B AHB2APB API Server")
    print(f" Model: {args.model_path}")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}")
    print(f" Workers: {args.workers}")
    print(f" Reload: {args.reload}\n")

    # Change to the api directory so the "api_server:app" import string
    # resolves regardless of where the script was launched from.
    # (os is already imported at module level — no local re-import needed.)
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    uvicorn.run(
        "api_server:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        # uvicorn does not support multiple workers together with reload,
        # so force a single worker in dev mode.
        workers=1 if args.reload else args.workers
    )
216
+