File size: 7,366 Bytes
7667556
d4f257f
 
 
 
7667556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4f257f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7667556
 
 
 
 
 
 
 
d4f257f
 
 
 
 
 
7667556
 
 
 
d4f257f
7667556
 
d4f257f
 
7667556
 
 
d4f257f
7667556
 
d4f257f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7667556
d4f257f
7667556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4f257f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7667556
 
 
 
 
 
d4f257f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7667556
 
 
 
 
 
 
 
 
 
d4f257f
7667556
 
d4f257f
7667556
d4f257f
 
 
7667556
 
d4f257f
 
 
7667556
 
 
 
 
d4f257f
7667556
d4f257f
 
 
7667556
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import os
import time
import threading
from collections import deque
from typing import Optional, List

import google.generativeai as genai
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# =========================
# Config
# =========================

# API key for the generative AI backend, read from the environment at import time.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Fail fast: a missing or empty key aborts the import instead of surfacing
# later on the first request.
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set in environment variables.")

genai.configure(api_key=GEMINI_API_KEY)

# List the models in your preferred order; pick_model() scans this list from
# the front, so earlier entries are chosen first.
MODEL_POOL = [
    "gemma-3-4b-it",
    "gemma-3-12b-it",
]

# Per-model request cap within one window, enforced only by the local tracker
# in this process.
LOCAL_RPM_LIMIT_PER_MODEL = 30
# Length in seconds of the window used when counting recent requests.
WINDOW_SECONDS = 60

app = FastAPI(title="Gemma Intent API", version="1.0.0")

# =========================
# Simple in-memory rate tracker
# =========================

# One deque of time.time() timestamps per model in MODEL_POOL. New entries are
# appended on the right and stale ones are popped from the left by
# _cleanup_old_requests().
_request_history = {model: deque() for model in MODEL_POOL}
# Held by get_model_request_count() and register_model_request() while they
# prune and read/append to _request_history.
_request_lock = threading.Lock()


def _cleanup_old_requests(model_name: str, now_ts: float) -> None:
    """Drop timestamps that have aged out of the window for one model.

    Entries are popped from the left of the model's deque for as long as the
    oldest one is more than ``WINDOW_SECONDS`` older than ``now_ts``. This
    function does not acquire ``_request_lock`` itself.

    Args:
        model_name: Key into ``_request_history``.
        now_ts: Reference time that stored timestamps are compared against.

    Raises:
        KeyError: If ``model_name`` has no entry in ``_request_history``.
    """
    history = _request_history[model_name]
    while history:
        # The left-most entry is the oldest; stop at the first one still
        # inside the window.
        if not (now_ts - history[0] > WINDOW_SECONDS):
            break
        history.popleft()


def get_model_request_count(model_name: str) -> int:
    """Return how many requests are currently recorded for a model.

    Stale timestamps are pruned first, so the result only counts entries
    that are still inside the window. Pruning and counting both happen while
    ``_request_lock`` is held.

    Args:
        model_name: Key into ``_request_history``.

    Returns:
        Number of timestamps left for ``model_name`` after pruning.
    """
    # Sample the clock before waiting on the lock.
    checked_at = time.time()
    with _request_lock:
        _cleanup_old_requests(model_name, checked_at)
        live_entries = _request_history[model_name]
        return len(live_entries)


def register_model_request(model_name: str) -> int:
    """Record one request for a model and return its updated count.

    While ``_request_lock`` is held, stale timestamps are pruned, the current
    ``time.time()`` value is appended to the model's history, and the new
    length is returned.

    Args:
        model_name: Key into ``_request_history``.

    Returns:
        Number of timestamps stored for ``model_name``, including the one
        just added.
    """
    stamp = time.time()
    with _request_lock:
        _cleanup_old_requests(model_name, stamp)
        history = _request_history[model_name]
        history.append(stamp)
        return len(history)


def pick_model() -> str:
    """Choose the model to try first for a new request.

    Picks the first model in ``MODEL_POOL`` whose local request count is
    still below ``LOCAL_RPM_LIMIT_PER_MODEL``. If every model has reached the
    limit, picks the one with the lowest count; ties go to the model that
    appears earlier in ``MODEL_POOL``.

    Returns:
        The selected model name.

    Raises:
        IndexError: If ``MODEL_POOL`` is empty.
    """
    usage = [(name, get_model_request_count(name)) for name in MODEL_POOL]

    # First model that is still under the local limit.
    under_limit = next(
        (name for name, used in usage if used < LOCAL_RPM_LIMIT_PER_MODEL),
        None,
    )
    if under_limit is not None:
        return under_limit

    # Everyone is at the limit: fall back to the least-used model. The sort is
    # stable, so pool order breaks ties.
    return sorted(usage, key=lambda entry: entry[1])[0][0]


def get_fallback_models(primary_model: str) -> List[str]:
    """Return every model in ``MODEL_POOL`` other than ``primary_model``.

    Pool order is preserved.

    Args:
        primary_model: Model name to leave out of the result.

    Returns:
        A new list of the remaining model names.
    """
    fallbacks: List[str] = []
    for candidate in MODEL_POOL:
        if candidate != primary_model:
            fallbacks.append(candidate)
    return fallbacks


# =========================
# Request / Response Models
# =========================

class ChatRequest(BaseModel):
    # Request body accepted by the POST /chat route.

    # User text to send to the model; the /chat handler rejects blank values.
    message: str
    # Instruction text that build_prompt() places ahead of the user message.
    system_prompt: Optional[str] = (
        "You are an intent classification assistant. "
        "Return a short direct answer only."
    )
    # Forwarded as `temperature` to the generation config.
    temperature: Optional[float] = 0.1
    # Forwarded as `max_output_tokens` to the generation config.
    max_output_tokens: Optional[int] = 80


class ChatResponse(BaseModel):
    # Response body returned by the POST /chat route.

    # Set to True by the /chat handler on its success path.
    success: bool
    # Name of the model whose reply is returned.
    model_used: str
    # Echo of the request's `message` field.
    input_message: str
    # Text returned by generate_with_model() for the chosen model.
    reply: str
    # Local request count for `model_used` after this call was registered.
    requests_last_minute_for_model: int
    # Sum of the local request counts over every model in MODEL_POOL.
    total_requests_last_minute_all_models: int

# =========================
# Helpers
# =========================

def total_requests_last_minute() -> int:
    """Return the combined local request count across all pooled models.

    Each model's count is read separately through
    ``get_model_request_count``, so the total is not taken under a single
    lock acquisition.
    """
    total = 0
    for name in MODEL_POOL:
        total += get_model_request_count(name)
    return total


def build_prompt(system_prompt: str, user_message: str) -> str:
    """Assemble the single text prompt sent to the model.

    Layout: the system prompt, a blank line, ``User: <message>``, and a
    trailing ``Assistant:`` cue on its own line.

    Args:
        system_prompt: Instruction text placed at the top of the prompt.
        user_message: Text inserted after the ``User:`` label.

    Returns:
        The combined prompt string.
    """
    # Only the template is parsed for placeholders, so braces inside the
    # arguments pass through untouched.
    template = "{system}\n\nUser: {user}\nAssistant:"
    return template.format(system=system_prompt, user=user_message)


def is_rate_limit_error(exc: Exception) -> bool:
    """Report whether an exception's message looks like a rate-limit error.

    The check is a case-insensitive substring search of ``str(exc)`` for a
    fixed set of markers; the exception's type is not inspected.

    Args:
        exc: Exception whose message is examined.

    Returns:
        True if any marker occurs in the lowercased message, else False.
    """
    text = str(exc).lower()
    for marker in (
        "429",
        "quota",
        "rate limit",
        "resource exhausted",
        "too many requests",
    ):
        if marker in text:
            return True
    return False


def generate_with_model(
    model_name: str,
    prompt: str,
    temperature: float,
    max_output_tokens: int
) -> str:
    """Run one generation call against a specific model and return its text.

    Args:
        model_name: Name passed to ``genai.GenerativeModel``.
        prompt: Full prompt text sent to ``generate_content``.
        temperature: Forwarded to the generation config.
        max_output_tokens: Forwarded to the generation config.

    Returns:
        ``response.text`` with surrounding whitespace stripped, or the fixed
        placeholder ``"Model returned an empty response."`` when reading or
        stripping the text raises.

    Raises:
        Exception: Anything raised while building the config, constructing
            the model, or calling ``generate_content`` propagates unchanged.
    """
    config = genai.types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_p=0.95,
    )
    response = genai.GenerativeModel(model_name).generate_content(
        prompt,
        generation_config=config,
    )

    # Best effort: any failure to obtain usable text is reported as a
    # placeholder reply rather than an error.
    try:
        text = response.text
        return text.strip()
    except Exception:
        return "Model returned an empty response."


def generate_reply_with_fallback(
    user_message: str,
    system_prompt: str,
    temperature: float,
    max_output_tokens: int
):
    """Generate a reply, trying each pooled model until one succeeds.

    The model chosen by ``pick_model`` is tried first, followed by the
    remaining models in pool order. A model counts as successful when
    ``generate_with_model`` returns without raising; only then is the request
    recorded in the local rate tracker.

    Args:
        user_message: Text from the user.
        system_prompt: Instruction text placed ahead of the user message.
        temperature: Forwarded to ``generate_with_model``.
        max_output_tokens: Forwarded to ``generate_with_model``.

    Returns:
        A ``(reply, model_name, used_count)`` tuple: the reply text, the model
        that produced it, and that model's local request count after this
        call was registered.

    Raises:
        RuntimeError: If every candidate model raised. The message carries
            the last error's text and the last error is attached as the
            cause.
    """
    prompt = build_prompt(system_prompt, user_message)

    primary_model = pick_model()
    candidate_models = [primary_model] + get_fallback_models(primary_model)

    last_error: Optional[Exception] = None

    for model_name in candidate_models:
        local_count_before = get_model_request_count(model_name)

        print(f"[INFO] Trying model: {model_name}")
        print(f"[INFO] Local requests in last minute for {model_name}: {local_count_before}")

        # Keep the try body to the upstream call so that only generation
        # failures trigger a fallback.
        try:
            reply = generate_with_model(
                model_name=model_name,
                prompt=prompt,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
            )
        except Exception as e:
            # Rate-limit errors and ordinary errors are handled the same way:
            # remember the failure and move on to the next model.
            last_error = e
            print(f"[WARN] Model failed: {model_name}")
            print(f"[WARN] Error: {str(e)}")
            continue

        used_count = register_model_request(model_name)
        return reply, model_name, used_count

    raise RuntimeError(f"All models failed. Last error: {last_error}") from last_error


# =========================
# Routes
# =========================

@app.get("/")
def home():
    # Liveness/info endpoint: returns a fixed status message together with
    # the configured model pool and the per-model local request limit.
    payload = {
        "status": "ok",
        "message": "Gemma Intent API is running",
    }
    payload["models"] = MODEL_POOL
    payload["local_rpm_limit_per_model"] = LOCAL_RPM_LIMIT_PER_MODEL
    return payload


@app.get("/stats")
def stats():
    # Usage endpoint: returns the local request count of each pooled model
    # and their combined total.
    per_model = {}
    for name in MODEL_POOL:
        per_model[name] = get_model_request_count(name)

    # The total is computed after the per-model counts, as separate reads.
    overall = total_requests_last_minute()
    return {
        "per_model_requests_last_minute": per_model,
        "total_requests_last_minute": overall,
    }


@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    # Main endpoint: validates the message, asks generate_reply_with_fallback()
    # for a reply, logs progress to stdout, and returns a ChatResponse.
    # A blank message yields HTTP 400; any exception raised while generating
    # or building the response is reported as HTTP 500 with the error text.
    if not (req.message and req.message.strip()):
        raise HTTPException(status_code=400, detail="message is required")

    print("\n========== NEW REQUEST ==========")
    print("Incoming message:")
    print(req.message)
    print(f"Total requests last minute (all models): {total_requests_last_minute()}")

    try:
        # Explicit null/empty values from the client fall back to these
        # handler-level defaults.
        system_prompt = req.system_prompt or "You are a helpful assistant."
        temperature = 0.1 if req.temperature is None else req.temperature
        token_budget = 80 if req.max_output_tokens is None else req.max_output_tokens

        outcome = generate_reply_with_fallback(
            user_message=req.message,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=token_budget,
        )
        reply, model_used, used_count = outcome

        print(f"Model used: {model_used}")
        print(f"Requests last minute for model after call: {used_count}")
        print("Model reply:")
        print(reply)
        print("=================================\n")

        response = ChatResponse(
            success=True,
            model_used=model_used,
            input_message=req.message,
            reply=reply,
            requests_last_minute_for_model=used_count,
            total_requests_last_minute_all_models=total_requests_last_minute()
        )
        return response

    except Exception as e:
        error_text = str(e)
        print("\nERROR:")
        print(error_text)
        print("=================================\n")
        raise HTTPException(status_code=500, detail=error_text)