"""Gemma intent-classifier API.

Loads a Gemma 3 model once at application startup and exposes a small
FastAPI service that maps a free-text message to exactly one label taken
from a configurable list of intents.
"""

import os
import re
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# =========================
# Config
# =========================
MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-4b-it")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.getenv("HF_TOKEN")

# To change the intents without editing the code, set the INTENTS
# environment variable to a comma-separated list.
# Example:
# INTENTS="greeting,pricing,complaint,booking,follow_up,other"
INTENTS_ENV = os.getenv(
    "INTENTS",
    "same_path,change_path,greeting,pricing,booking,complaint,follow_up,other",
)
ALLOWED_INTENTS = [x.strip() for x in INTENTS_ENV.split(",") if x.strip()]

# Label returned when the model output matches none of the allowed intents.
FALLBACK_INTENT = "other"

# Populated by `lifespan` at startup; None until the model has been loaded.
model = None
processor = None


# =========================
# Schemas
# =========================
class IntentRequest(BaseModel):
    """Request body for ``POST /intent``."""

    # Text to classify; rejected with HTTP 400 when blank.
    message: str
    # Optional per-request label set; ALLOWED_INTENTS is used when omitted.
    intents: Optional[List[str]] = None
    # Optional replacement for the built-in system prompt.
    system_prompt: Optional[str] = None


class IntentResponse(BaseModel):
    """Response body for ``POST /intent``."""

    # Label chosen by `normalize_intent`.
    intent: str
    # Decoded model output before normalization.
    raw_output: str
    # Identifier of the model that produced the output (MODEL_ID).
    model: str


# =========================
# Helpers
# =========================
def _clean_intents(intents: Optional[List[str]]) -> List[str]:
    """Strip labels, drop blank ones and remove duplicates, keeping order."""
    if not intents:
        return []
    stripped = (label.strip() for label in intents)
    return list(dict.fromkeys(label for label in stripped if label))


def normalize_intent(text: str, allowed_intents: List[str]) -> str:
    """Map raw model output onto one of the allowed intent labels.

    Matching is case-insensitive. An output that equals a label exactly wins;
    otherwise the first label (in list order) that occurs in the output as a
    standalone token is returned.

    Args:
        text: Decoded text produced by the model.
        allowed_intents: Candidate labels, in priority order.

    Returns:
        The matching entry of ``allowed_intents`` exactly as the caller
        spelled it, or ``"other"`` when nothing matches.
    """
    # Strip any markdown/code fences or stray backticks around the answer.
    cleaned = text.strip().lower().replace("`", "").strip()

    # Blank labels are skipped: an empty pattern would match any text and be
    # returned as a bogus classification.
    candidates = [
        (intent, intent.strip().lower())
        for intent in allowed_intents
        if intent.strip()
    ]

    # An exact answer must not be shadowed by a shorter label listed earlier
    # (e.g. "path" vs "same path").
    for intent, label in candidates:
        if cleaned == label:
            return intent

    # The model may have returned a sentence that contains the intent.
    # Lookarounds are used instead of \b so that labels starting or ending
    # with a non-word character (e.g. "c++") can still match.
    for intent, label in candidates:
        if re.search(rf"(?<!\w){re.escape(label)}(?!\w)", cleaned):
            return intent

    return FALLBACK_INTENT


def build_prompt(
    user_message: str,
    allowed_intents: List[str],
    custom_system_prompt: Optional[str],
) -> List[dict]:
    """Build the chat messages passed to the processor's chat template.

    Args:
        user_message: Text to classify.
        allowed_intents: Labels listed in the default system prompt.
        custom_system_prompt: When non-empty, used verbatim instead of the
            default system prompt; the intent list is then NOT injected.

    Returns:
        A two-element list holding the system message and the user message.
    """
    intent_list = ", ".join(allowed_intents)

    system_text = custom_system_prompt or (
        "You are an intent classifier.\n"
        f"Choose exactly one intent from this list: {intent_list}.\n"
        "Return only the intent label, with no explanation, no punctuation, and no extra words."
    )

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_text}],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": user_message}],
        },
    ]


def run_intent_classification(
    user_message: str,
    allowed_intents: List[str],
    custom_system_prompt: Optional[str],
) -> tuple[str, str]:
    """Classify one message with the loaded model.

    Args:
        user_message: Text to classify.
        allowed_intents: Labels the result is normalized against.
        custom_system_prompt: Optional replacement system prompt.

    Returns:
        ``(intent, raw_output)``: the normalized label and the decoded model
        output it was derived from.

    Raises:
        RuntimeError: If the model or processor has not been loaded.
    """
    if model is None or processor is None:
        raise RuntimeError("Model is not loaded yet.")

    messages = build_prompt(user_message, allowed_intents, custom_system_prompt)

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # CPU inference
    with torch.inference_mode():
        generation = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            # Sampling knobs are explicitly unset because decoding is greedy.
            # NOTE(review): presumably this silences warnings caused by the
            # model's default generation config; confirm.
            temperature=None,
            top_p=None,
        )

    # The generated sequence starts with the prompt tokens; keep only the
    # newly generated tail.
    input_len = inputs["input_ids"].shape[-1]
    generated_tokens = generation[0][input_len:]

    decoded = processor.decode(generated_tokens, skip_special_tokens=True).strip()
    return normalize_intent(decoded, allowed_intents), decoded


# =========================
# Lifespan
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor and model at startup and release them at shutdown.

    Raises:
        RuntimeError: If HF_TOKEN is not set.
    """
    global model, processor

    print(f"[startup] Loading model: {MODEL_ID}")

    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")

    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
    )

    model = Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map="cpu",
    ).eval()

    print("[startup] Model loaded successfully.")
    try:
        yield
    finally:
        # Runs even if the application exits with an error.
        print("[shutdown] App is shutting down.")
        model = None
        processor = None


app = FastAPI(
    title="Gemma Intent Classifier API",
    version="1.0.0",
    lifespan=lifespan,
)


# =========================
# Routes
# =========================
@app.get("/")
def root():
    """Return a static message confirming the API is running."""
    return {
        "status": "ok",
        "message": "Gemma Intent Classifier API is running.",
    }


@app.get("/health")
def health():
    """Return a static health payload that includes the configured model id."""
    return {
        "status": "healthy",
        "model": MODEL_ID,
    }


@app.post("/intent", response_model=IntentResponse)
def classify_intent(payload: IntentRequest):
    """Classify ``payload.message`` into one of the allowed intents.

    Raises:
        HTTPException: 400 when the message is blank or no intents are
            available; 500 when inference fails.
    """
    message = payload.message.strip()
    if not message:
        raise HTTPException(status_code=400, detail="message is required")

    # Caller-supplied labels are sanitized; if none survive, fall back to the
    # configured defaults (same as sending no intents at all).
    allowed_intents = _clean_intents(payload.intents) or ALLOWED_INTENTS
    if not allowed_intents:
        raise HTTPException(status_code=400, detail="No intents provided")

    # Only the inference call is guarded, so failures elsewhere are not
    # misreported as classification errors.
    try:
        intent, raw_output = run_intent_classification(
            user_message=message,
            allowed_intents=allowed_intents,
            custom_system_prompt=payload.system_prompt,
        )
    except Exception as exc:
        # Log the full error server-side but do not leak internals to clients.
        print(f"[error] {exc!r}")
        raise HTTPException(
            status_code=500,
            detail="Intent classification failed",
        ) from exc

    print("========== REQUEST ==========")
    print(f"message: {payload.message}")
    print(f"allowed_intents: {allowed_intents}")
    print("========== RESPONSE =========")
    print(f"raw_output: {raw_output}")
    print(f"intent: {intent}")
    print("================================")

    return IntentResponse(
        intent=intent,
        raw_output=raw_output,
        model=MODEL_ID,
    )