File size: 7,938 Bytes
e5ba726
 
 
 
 
 
 
 
8ecbd6b
e5ba726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbd6b
e5ba726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbd6b
e5ba726
8ecbd6b
e5ba726
 
8ecbd6b
 
 
 
e5ba726
 
8ecbd6b
 
e5ba726
8ecbd6b
e5ba726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbd6b
 
 
 
 
e5ba726
 
 
 
 
 
8ecbd6b
 
 
e5ba726
 
 
 
8ecbd6b
e5ba726
 
 
 
 
 
 
8ecbd6b
e5ba726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ecbd6b
e5ba726
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import os
import time
import logging
import asyncio
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor

# -------------------------
# Configuration (via env)
# -------------------------
def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int, or *default* parsed when unset."""
    return int(os.getenv(name, default))


# Hugging Face repo id handed to transformers.pipeline() at startup.
# NOTE(review): the default names a GGUF repo; transformers usually needs a
# regular checkpoint (or an explicit gguf_file argument) — confirm it loads.
REPO_ID = os.getenv("REPO_ID", "unsloth/gemma-3-270m-it-GGUF")
# Size of the thread pool that runs the blocking generation calls.
MAX_WORKERS = _env_int("MAX_WORKERS", "2")
# How many /generate requests may hold the model at the same time.
MAX_CONCURRENT_REQUESTS = _env_int("MAX_CONCURRENT_REQUESTS", "1")
# Request budget per client IP per minute, fed to the rate limiter.
RATE_LIMIT_PER_MIN = _env_int("RATE_LIMIT_PER_MIN", "60")
# "*" for any origin, otherwise a comma-separated list of CORS origins.
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*")
# Request timeout budget; presumably seconds — verify where it is consumed.
REQUEST_TIMEOUT = _env_int("REQUEST_TIMEOUT", "120")
# Names mirror llama-cpp-python options, but the model is a transformers
# pipeline: N_CTX and N_THREADS are not read by the startup code in this file,
# and N_GPU_LAYERS > 0 only switches the pipeline's device_map to "auto".
N_CTX = _env_int("N_CTX", "2048")
N_THREADS = _env_int("N_THREADS", "4")
N_GPU_LAYERS = _env_int("N_GPU_LAYERS", "0")

# -------------------------
# Logging
# -------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gemma_api")

# -------------------------
# FastAPI app
# -------------------------
app = FastAPI(title="Gemma 3 270M ThreadPool API")

# ALLOWED_ORIGINS is "*" or a comma-separated list. Entries are stripped and
# empty ones dropped, so "a.com, b.com" yields "b.com" rather than " b.com"
# (which would never equal a browser-sent Origin header).
_raw_origins = ALLOWED_ORIGINS.strip()
if _raw_origins == "*":
    origins = ["*"]
else:
    origins = [entry.strip() for entry in _raw_origins.split(",") if entry.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------------
# Request / Response Models
# -------------------------
class Message(BaseModel):
    """One chat turn."""
    role: str  # e.g. "system" / "user" / "assistant"; not constrained by this model
    content: str

class GenerationRequest(BaseModel):
    """Body of POST /generate.

    Supply ``messages`` or ``prompt``. A non-empty ``messages`` list takes
    precedence; a bare ``prompt`` is wrapped with a default system message.
    """
    messages: Optional[List[Message]] = None  # full chat history
    prompt: Optional[str] = None  # single user turn
    max_new_tokens: int = Field(50, ge=1, le=500)  # capped at 500 to keep responses fast
    temperature: float = Field(0.7, ge=0.0, le=2.0)  # NOTE(review): 0.0 is accepted even with do_sample=True; confirm the backend tolerates it
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    do_sample: bool = Field(True)
    # Speed optimization parameters
    num_beams: int = Field(1, ge=1, le=4)  # 1 = no beam search; with the default do_sample=True this samples, it is not greedy
    early_stopping: bool = Field(True)  # presumably only relevant to beam search (num_beams > 1)
    use_cache: bool = Field(True)

class GenerationResponse(BaseModel):
    """Body returned by POST /generate."""
    generated_text: str
    model: str  # the configured REPO_ID
    runtime_seconds: float  # elapsed seconds, rounded to 3 decimals by the handler

# -------------------------
# Global objects
# -------------------------
# The transformers text-generation pipeline; stays None until on_startup() assigns it.
LLM_MODEL: Optional[Any] = None
# Thread pool that runs the blocking pipeline call off the event loop.
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Limits how many /generate requests run the model concurrently.
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives bind
# to the loop current at construction — confirm the runtime version.
model_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

# -------------------------
# Rate limiting (simple token-bucket per IP)
# -------------------------
class RateLimiter:
    """Per-key token-bucket rate limiter.

    Each key owns a bucket holding at most ``per_minute`` tokens that refills
    continuously at ``per_minute`` tokens per 60 seconds; every allowed call
    spends one token. Buckets that have refilled completely carry no state a
    fresh bucket wouldn't, so they are swept periodically to keep ``storage``
    from growing with every distinct key ever seen.
    """

    # Minimum number of seconds between sweeps of idle buckets.
    PRUNE_INTERVAL = 60.0

    def __init__(self, per_minute: int):
        self.per_minute = per_minute
        # key -> {"tokens": remaining tokens, "ts": time.monotonic() of last update}
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.lock = asyncio.Lock()
        self._last_prune = time.monotonic()

    def _refilled(self, rec: Dict[str, Any], now: float) -> float:
        """Return *rec*'s token count after refilling it up to *now* (capped at per_minute)."""
        elapsed = now - rec["ts"]
        return min(self.per_minute, rec["tokens"] + (elapsed / 60.0) * self.per_minute)

    def _prune(self, now: float) -> None:
        """Drop buckets that are full again, at most once per PRUNE_INTERVAL."""
        if now - self._last_prune < self.PRUNE_INTERVAL:
            return
        self._last_prune = now
        idle = [k for k, rec in self.storage.items() if self._refilled(rec, now) >= self.per_minute]
        for k in idle:
            del self.storage[k]

    async def allow(self, key: str) -> bool:
        """Spend one token for *key*; return True if the call is allowed."""
        # monotonic() so wall-clock adjustments cannot produce a negative refill.
        now = time.monotonic()
        async with self.lock:
            self._prune(now)
            rec = self.storage.get(key)
            if rec is None:
                # A new key starts with a full bucket minus this call's token.
                self.storage[key] = {"tokens": self.per_minute - 1, "ts": now}
                return True
            rec["tokens"] = self._refilled(rec, now)
            rec["ts"] = now
            if rec["tokens"] >= 1:
                rec["tokens"] -= 1
                return True
            return False

# Shared limiter consulted by the /generate endpoint, keyed by client IP.
rate_limiter = RateLimiter(RATE_LIMIT_PER_MIN)

# -------------------------
# Utility functions
# -------------------------

# Chat messages (role/content dicts) are passed to the pipeline unchanged; there is no manual prompt-string builder.

def generate_sync(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
    """Run one blocking generation with the loaded pipeline and return the reply text.

    Args:
        messages: chat history as ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: upper bound on newly generated tokens.
        temperature: sampling temperature; ``0`` is treated as greedy decoding.
        top_p: nucleus-sampling threshold, only sent when sampling.
        do_sample: sample when True, otherwise decode deterministically.
        num_beams: number of beams (1 = no beam search).
        early_stopping: beam-search stopping flag, only sent when ``num_beams > 1``.
        use_cache: forwarded to generation unchanged.

    Returns:
        The ``content`` of the last message when the pipeline returns a chat
        transcript, otherwise the raw ``generated_text`` value.

    Raises:
        RuntimeError: if called before the model has been loaded.
    """
    if LLM_MODEL is None:
        raise RuntimeError("Model is not loaded")

    # transformers' temperature warper requires a strictly positive temperature
    # when sampling, but the API accepts 0.0 — interpret that as greedy decoding
    # instead of failing the request.
    sampling = do_sample and temperature > 0
    generation_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "num_beams": num_beams,
        "use_cache": use_cache,
    }
    if sampling:
        # Only meaningful while sampling; otherwise they are ignored (with warnings).
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    if num_beams > 1:
        # early_stopping only affects beam search.
        generation_kwargs["early_stopping"] = early_stopping

    response = LLM_MODEL(messages, **generation_kwargs)

    generated = response[0]["generated_text"]
    # Chat-format input yields the whole transcript; the reply is its last message.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return generated

async def generate_async(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
    """Await ``generate_sync`` on the shared thread pool so the event loop stays free.

    Takes exactly the arguments of ``generate_sync`` and returns its result.
    """
    # The asyncio docs prefer get_running_loop() over get_event_loop() inside coroutines.
    running_loop = asyncio.get_running_loop()
    # run_in_executor forwards positional arguments itself, so no lambda wrapper is needed.
    pending = running_loop.run_in_executor(
        executor,
        generate_sync,
        messages,
        max_new_tokens,
        temperature,
        top_p,
        do_sample,
        num_beams,
        early_stopping,
        use_cache,
    )
    return await pending

# -------------------------
# Startup
# -------------------------
@app.on_event("startup")
async def on_startup():
    """Load the text-generation pipeline into ``LLM_MODEL`` and warm it up.

    Raises:
        RuntimeError: if loading or the warm-up generation fails; the original
            exception is chained as ``__cause__``.
    """
    global LLM_MODEL

    try:
        logger.info("Loading model from %s...", REPO_ID)
        LLM_MODEL = pipeline(
            "text-generation",
            model=REPO_ID,
            # N_GPU_LAYERS is only used as an on/off switch for GPU placement here.
            device_map="auto" if N_GPU_LAYERS > 0 else "cpu",
        )
        logger.info("Model loaded successfully.")
    except Exception as e:
        # logger.exception records the traceback; logger.error alone dropped it.
        logger.exception("Failed to load model %s", REPO_ID)
        raise RuntimeError(f"Model loading failed: {e}") from e

    # Warm up the model with a dummy request for faster first inference.
    try:
        logger.info("Warming up model...")
        dummy_messages = [{"role": "user", "content": "Hello"}]
        LLM_MODEL(
            dummy_messages,
            max_new_tokens=5,
            temperature=0.1,
        )
        logger.info("Model warmed up successfully.")
    except Exception as e:
        logger.exception("Warm-up generation failed for model %s", REPO_ID)
        raise RuntimeError(f"Model warm-up failed: {e}") from e

# -------------------------
# Endpoints
# -------------------------
@app.get("/")
async def root():
    """Liveness banner: confirms the service is up and names the configured model."""
    payload = dict(status="Gemma 3 API is running 🎉", model=REPO_ID)
    return payload

@app.get("/health")
async def health():
    """Health probe reporting whether the pipeline has been loaded yet."""
    model_ready = LLM_MODEL is not None
    return dict(status="ok", model_loaded=model_ready)

@app.get("/metrics")
async def metrics():
    """Report concurrency settings and the live state of the model semaphore.

    ``current_semaphore_locked`` is kept unchanged for existing consumers:
    despite its name it carries the semaphore's private ``_value`` counter (the
    free permits), or None where that attribute is absent. ``semaphore_locked``
    is the boolean the old key's name suggests, taken from the public
    ``Semaphore.locked()`` API.
    """
    return {
        "model": REPO_ID,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "current_semaphore_locked": getattr(model_semaphore, "_value", None),
        "semaphore_locked": model_semaphore.locked(),
        "threadpool_workers": MAX_WORKERS,
    }

@app.post("/generate", response_model=GenerationResponse)
async def generate(req: GenerationRequest, request: Request):
    """Generate a reply for a chat history or a single prompt.

    A non-empty ``messages`` list takes precedence over ``prompt``; a bare
    prompt is wrapped with a fixed system message.

    Raises:
        HTTPException: 429 when the client IP is over its rate limit, 400 when
            neither ``messages`` nor ``prompt`` is given, 504 when waiting for
            a concurrency slot plus generation exceeds ``REQUEST_TIMEOUT``.
    """
    client_ip = request.client.host if request.client else "unknown"
    allowed = await rate_limiter.allow(client_ip)
    if not allowed:
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded")

    # Chat-format messages (role/content dicts) for the transformers pipeline.
    if req.messages:
        chat_messages = [{"role": msg.role, "content": msg.content} for msg in req.messages]
    elif req.prompt:
        chat_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": req.prompt}
        ]
    else:
        raise HTTPException(status_code=400, detail="Provide either 'messages' or 'prompt'.")

    async def _run() -> str:
        # Queueing for the semaphore counts against the timeout as well, which
        # is what the 504 message ("... or concurrency queue full") describes.
        async with model_semaphore:
            return await generate_async(
                chat_messages,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=req.do_sample,
                num_beams=req.num_beams,
                early_stopping=req.early_stopping,
                use_cache=req.use_cache
            )

    # A non-positive REQUEST_TIMEOUT disables the limit rather than failing every request.
    timeout = REQUEST_TIMEOUT if REQUEST_TIMEOUT > 0 else None

    start = time.perf_counter()
    try:
        # Previously REQUEST_TIMEOUT was never applied, so this 504 path was unreachable.
        generated_text = await asyncio.wait_for(_run(), timeout=timeout)
    except asyncio.TimeoutError:
        # NOTE: the worker thread cannot be interrupted; a timed-out generation
        # keeps running in the executor even though its semaphore slot is freed.
        raise HTTPException(status_code=504, detail="Generation timed out or concurrency queue full") from None

    runtime = time.perf_counter() - start

    return GenerationResponse(
        generated_text=generated_text,
        model=REPO_ID,
        runtime_seconds=round(runtime, 3)
    )