Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import threading | |
| from collections import deque | |
| from typing import Optional, List | |
| import google.generativeai as genai | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
# =========================
# Config
# =========================
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Fail fast at import time: an unset or empty key makes every call fail.
    raise RuntimeError("GEMINI_API_KEY is not set in environment variables.")

genai.configure(api_key=GEMINI_API_KEY)

# Models to try, listed in order of preference (earlier = preferred).
MODEL_POOL = ["gemma-3-4b-it", "gemma-3-12b-it"]

# Per-model cap on requests counted in this process's in-memory tracker
# within one sliding window.
LOCAL_RPM_LIMIT_PER_MODEL = 30
# Sliding-window length, in seconds.
WINDOW_SECONDS = 60

app = FastAPI(title="Gemma Intent API", version="1.0.0")
# =========================
# Simple in-memory rate tracker
# =========================
# One deque of request timestamps (seconds from time.time()) per model,
# appended in chronological order.
_request_history = dict((model_name, deque()) for model_name in MODEL_POOL)
# Serialises every read/modify of _request_history across threads.
_request_lock = threading.Lock()
def _cleanup_old_requests(model_name: str, now_ts: float) -> None:
    """Evict timestamps older than WINDOW_SECONDS from *model_name*'s history.

    Performs no locking itself; the callers in this module invoke it while
    holding ``_request_lock``. Raises KeyError for a model not in the pool.
    """
    history = _request_history[model_name]
    # Oldest entries sit at the left end, so stop at the first fresh one.
    while history:
        age = now_ts - history[0]
        if age > WINDOW_SECONDS:
            history.popleft()
        else:
            break
def get_model_request_count(model_name: str) -> int:
    """Return how many requests *model_name* has recorded in the current window.

    Expired timestamps are pruned first, under ``_request_lock``.
    """
    current = time.time()
    with _request_lock:
        _cleanup_old_requests(model_name, current)
        count = len(_request_history[model_name])
    return count
def register_model_request(model_name: str) -> int:
    """Record one request for *model_name* at the current time.

    Returns:
        The number of requests for that model in the window, including the
        one just recorded.
    """
    stamp = time.time()
    with _request_lock:
        _cleanup_old_requests(model_name, stamp)
        history = _request_history[model_name]
        history.append(stamp)
        return len(history)
def pick_model() -> str:
    """Choose the model to try first.

    Returns the first model (in MODEL_POOL order) that is still below the
    local limit. If every model is at or above the limit, returns the one
    with the fewest requests in the last window; ties keep pool order.
    """
    usage = [(name, get_model_request_count(name)) for name in MODEL_POOL]

    # Preferred path: earliest model still under the local cap.
    under_limit = [name for name, used in usage if used < LOCAL_RPM_LIMIT_PER_MODEL]
    if under_limit:
        return under_limit[0]

    # All saturated: least-used model. sorted() is stable, so equal counts
    # keep MODEL_POOL order.
    by_usage = sorted(usage, key=lambda pair: pair[1])
    return by_usage[0][0]
def get_fallback_models(primary_model: str) -> List[str]:
    """Return every pool model except *primary_model*, preserving MODEL_POOL order."""
    return list(filter(lambda name: name != primary_model, MODEL_POOL))
# =========================
# Request / Response Models
# =========================
class ChatRequest(BaseModel):
    """Incoming chat payload; only ``message`` is required."""

    # Text the user wants answered/classified.
    message: str
    # Instruction text placed before the user message when building the prompt.
    system_prompt: Optional[str] = (
        "You are an intent classification assistant. Return a short direct answer only."
    )
    # Sampling temperature forwarded to the model.
    temperature: Optional[float] = 0.1
    # Cap on generated tokens forwarded to the model.
    max_output_tokens: Optional[int] = 80
class ChatResponse(BaseModel):
    """Payload returned by chat() after a successful generation."""

    # chat() only builds this model on success, and sets this to True.
    success: bool
    # Name of the model in MODEL_POOL that produced the reply.
    model_used: str
    # Echo of the request's message field.
    input_message: str
    # Text produced by the model.
    reply: str
    # Local in-window request count for model_used, as returned by the tracker.
    requests_last_minute_for_model: int
    # Sum of local in-window counts over all models in MODEL_POOL.
    total_requests_last_minute_all_models: int
# =========================
# Helpers
# =========================
def total_requests_last_minute() -> int:
    """Return the summed in-window request count across all models in MODEL_POOL.

    Each model is read under its own lock acquisition, so the total is not an
    atomic snapshot.
    """
    total = 0
    for name in MODEL_POOL:
        total += get_model_request_count(name)
    return total
def build_prompt(system_prompt: str, user_message: str) -> str:
    """Combine system prompt and user message into one plain-text prompt.

    The layout is ``<system>``, a blank line, ``User: <message>``, then a
    trailing ``Assistant:`` cue for the model to continue from.
    """
    template = "{}\n\nUser: {}\nAssistant:"
    return template.format(system_prompt, user_message)
def is_rate_limit_error(exc: Exception) -> bool:
    """Heuristically decide whether *exc* signals an upstream rate/quota limit.

    Checks, in order:
      1. a ``code`` attribute equal to 429 (google.api_core exceptions carry
         the HTTP status there);
      2. the exception class name (``ResourceExhausted`` / ``TooManyRequests``);
      3. known phrases in the lower-cased message text. ``429`` is matched as
         a standalone number so IDs such as ``14290`` are not false hits.

    Args:
        exc: The exception raised while calling a model.

    Returns:
        True if the error looks like a rate-limit/quota error, otherwise False.
    """
    import re

    if getattr(exc, "code", None) == 429:
        return True
    if type(exc).__name__ in ("ResourceExhausted", "TooManyRequests"):
        return True

    msg = str(exc).lower()
    if re.search(r"\b429\b", msg):
        return True
    rate_limit_markers = (
        "quota",
        "rate limit",
        "resource exhausted",
        # Wording used by google.api_core's ResourceExhausted messages.
        "resource has been exhausted",
        "too many requests",
    )
    return any(marker in msg for marker in rate_limit_markers)
def generate_with_model(
    model_name: str,
    prompt: str,
    temperature: float,
    max_output_tokens: int
) -> str:
    """Send *prompt* to *model_name* and return the generated text, stripped.

    Errors raised by ``generate_content`` (network, quota, unknown model, ...)
    and anything unexpected while reading the reply propagate, so the caller
    can fall back to another model instead of reporting a fake success.

    Returns:
        The stripped reply text, or ``"Model returned an empty response."``
        when the response has no text part or its text is blank.
    """
    generation_config = genai.types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_p=0.95,
    )
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(
        prompt,
        generation_config=generation_config
    )
    try:
        # Per the google-generativeai docs, the `.text` quick accessor raises
        # ValueError when the response has no valid Part (e.g. blocked/empty).
        text = response.text
    except ValueError:
        return "Model returned an empty response."
    text = text.strip()
    # A whitespace-only reply is treated the same as no reply.
    return text or "Model returned an empty response."
def generate_reply_with_fallback(
    user_message: str,
    system_prompt: str,
    temperature: float,
    max_output_tokens: int
):
    """Generate a reply, trying the preferred model first and then the rest.

    Every attempt is recorded in the local tracker *before* the call. Two
    effects follow:
      * concurrent requests see each other's in-flight attempts;
      * a model that just failed (e.g. with a rate-limit error) counts as
        used, so ``pick_model`` stops steering later requests onto it.

    Returns:
        ``(reply, model_name, used_count)``. ``used_count`` is the local
        in-window count for ``model_name``, including this attempt.

    Raises:
        Exception: if every model fails. The last underlying error is chained
            as ``__cause__``.
    """
    prompt = build_prompt(system_prompt, user_message)
    primary_model = pick_model()
    candidate_models = [primary_model] + get_fallback_models(primary_model)
    last_error = None

    for model_name in candidate_models:
        local_count_before = get_model_request_count(model_name)
        print(f"[INFO] Trying model: {model_name}")
        print(f"[INFO] Local requests in last minute for {model_name}: {local_count_before}")

        # NOTE(review): presumably the upstream RPM quota is consumed by every
        # attempt, not only successful ones — confirm against Gemini quota docs.
        used_count = register_model_request(model_name)
        try:
            reply = generate_with_model(
                model_name=model_name,
                prompt=prompt,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
            )
        except Exception as e:
            # Any failure (rate limit or otherwise) moves on to the next model.
            last_error = e
            print(f"[WARN] Model failed: {model_name}")
            print(f"[WARN] Error: {str(e)}")
            if is_rate_limit_error(e):
                print(f"[WARN] Rate limit detected for {model_name}; trying next model")
            continue
        return reply, model_name, used_count

    raise Exception(f"All models failed. Last error: {last_error}") from last_error
# =========================
# Routes
# =========================
@app.get("/")
def home():
    """Report that the service is up, with the model pool and local per-model limit.

    Registered on ``app``. Without the decorator this handler was never
    routed, so the endpoint did not exist.
    """
    return {
        "status": "ok",
        "message": "Gemma Intent API is running",
        "models": MODEL_POOL,
        "local_rpm_limit_per_model": LOCAL_RPM_LIMIT_PER_MODEL
    }
@app.get("/stats")
def stats():
    """Return local in-window request counts per model and in total.

    Registered on ``app``. Without the decorator this handler was never routed.
    """
    return {
        "per_model_requests_last_minute": {
            model: get_model_request_count(model)
            for model in MODEL_POOL
        },
        "total_requests_last_minute": total_requests_last_minute()
    }
@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    """Generate a reply to ``req.message`` using the model pool with fallback.

    Missing optional fields are filled with local defaults before calling the
    model. The request, the chosen model and the reply are printed to stdout.

    Raises:
        HTTPException(400): ``message`` is empty or whitespace only.
        HTTPException(500): generation failed on every model. ``detail`` holds
            the error text, and the original exception is chained.
    """
    if not req.message or not req.message.strip():
        raise HTTPException(status_code=400, detail="message is required")

    print("\n========== NEW REQUEST ==========")
    print("Incoming message:")
    print(req.message)
    print(f"Total requests last minute (all models): {total_requests_last_minute()}")

    try:
        reply, model_used, used_count = generate_reply_with_fallback(
            user_message=req.message,
            system_prompt=req.system_prompt or "You are a helpful assistant.",
            temperature=req.temperature if req.temperature is not None else 0.1,
            max_output_tokens=req.max_output_tokens if req.max_output_tokens is not None else 80,
        )
        print(f"Model used: {model_used}")
        print(f"Requests last minute for model after call: {used_count}")
        print("Model reply:")
        print(reply)
        print("=================================\n")
        return ChatResponse(
            success=True,
            model_used=model_used,
            input_message=req.message,
            reply=reply,
            requests_last_minute_for_model=used_count,
            total_requests_last_minute_all_models=total_requests_last_minute()
        )
    except Exception as e:
        print("\nERROR:")
        print(str(e))
        print("=================================\n")
        # Chain the cause so server-side tracebacks show the real failure.
        raise HTTPException(status_code=500, detail=str(e)) from e