File size: 5,585 Bytes
a3bb57e
 
 
 
 
 
 
 
 
 
 
 
 
760a824
a3bb57e
760a824
a3bb57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
760a824
 
 
 
 
 
 
a3bb57e
 
760a824
a3bb57e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
import re
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# =========================
# Config
# =========================
# Checkpoint to load, generation length cap, and Hugging Face access token.
# All three are read from the environment once, at import time.
MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-3-4b-it")
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "12"))
HF_TOKEN = os.environ.get("HF_TOKEN")

# To change the intents without editing the code, set the INTENTS
# environment variable to a comma-separated list, for example:
# INTENTS="greeting,pricing,complaint,booking,follow_up,other"
INTENTS_ENV = os.environ.get(
    "INTENTS",
    "same_path,change_path,greeting,pricing,booking,complaint,follow_up,other",
)
# Split on commas, trim surrounding whitespace, and drop empty entries.
ALLOWED_INTENTS = [
    label
    for label in map(str.strip, INTENTS_ENV.split(","))
    if label
]

# Placeholders assigned by lifespan() when the application starts.
processor = None
model = None


# =========================
# Schemas
# =========================
class IntentRequest(BaseModel):
    # Request body accepted by POST /intent.
    # Text to classify; the /intent route rejects blank values with HTTP 400.
    message: str
    # Optional per-request label set; when missing or empty the route falls
    # back to the module-level ALLOWED_INTENTS.
    intents: Optional[List[str]] = None
    # Optional replacement for the default classifier system prompt built in
    # build_prompt().
    system_prompt: Optional[str] = None


class IntentResponse(BaseModel):
    # Response body returned by POST /intent.
    # Label chosen by normalize_intent() from the model output.
    intent: str
    # Decoded model text before it was normalized to a label.
    raw_output: str
    # The MODEL_ID value the service is configured with.
    model: str


# =========================
# Helpers
# =========================
def normalize_intent(text: str, allowed_intents: List[str]) -> str:
    """Map raw model output onto one of the allowed intent labels.

    Matching is case-insensitive and proceeds in two passes:

    1. If the cleaned output is exactly one of the labels, that label wins,
       regardless of its position in ``allowed_intents``.
    2. Otherwise the first label (in list order) that occurs in the output
       as a whole token is returned.

    Blank labels are ignored, because an empty pattern would match any text.

    Args:
        text: Raw text decoded from the model.
        allowed_intents: Candidate labels, in priority order.

    Returns:
        The matching entry of ``allowed_intents`` exactly as supplied by the
        caller, or the literal ``"other"`` when nothing matches.
    """
    # Remove any markdown/code fences or stray backticks around the answer.
    cleaned = text.strip().lower().replace("```", "").replace("`", "").strip()

    # Pair every usable label with its comparison form (trimmed, lowercased).
    candidates = [
        (intent, intent.strip().lower())
        for intent in allowed_intents
        if intent and intent.strip()
    ]

    # Pass 1: an exact answer must not lose to a shorter label listed earlier
    # (e.g. "follow-up" vs. "follow").
    for intent, key in candidates:
        if cleaned == key:
            return intent

    # Pass 2: the model returned a sentence that contains an intent somewhere
    # in the text. Lookarounds are used instead of \b so that labels starting
    # or ending with a non-word character can still be matched.
    for intent, key in candidates:
        if re.search(rf"(?<!\w){re.escape(key)}(?!\w)", cleaned):
            return intent

    # Fallback when the output mentions none of the labels.
    return "other"


def build_prompt(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> List[dict]:
    """Assemble the two-turn chat (system + user) sent to the model.

    Args:
        user_message: Text to classify; becomes the user turn.
        allowed_intents: Labels listed in the default system prompt.
        custom_system_prompt: When non-empty, used verbatim as the system
            turn instead of the default classifier instructions.

    Returns:
        A list of two message dicts, each with a ``role`` and a ``content``
        list holding a single ``{"type": "text", "text": ...}`` part.
    """
    labels = ", ".join(allowed_intents)

    if custom_system_prompt:
        instructions = custom_system_prompt
    else:
        instructions = (
            "You are an intent classifier.\n"
            f"Choose exactly one intent from this list: {labels}.\n"
            "Return only the intent label, with no explanation, no punctuation, and no extra words."
        )

    def text_turn(role: str, body: str) -> dict:
        # One chat message carrying a single text part.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    return [text_turn("system", instructions), text_turn("user", user_message)]


def run_intent_classification(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> tuple[str, str]:
    """Run one classification pass through the loaded model.

    Reads the module-level ``model`` and ``processor``, which lifespan()
    assigns at startup.

    Args:
        user_message: Text to classify.
        allowed_intents: Labels the output is normalized against.
        custom_system_prompt: Optional override for the system prompt.

    Returns:
        ``(intent, raw_output)``: the normalized label and the decoded model
        text it was derived from.
    """
    chat = build_prompt(user_message, allowed_intents, custom_system_prompt)

    encoded = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Number of prompt tokens; used below to slice the prompt off the output.
    prompt_length = encoded["input_ids"].shape[-1]

    # Sampling is disabled; temperature/top_p are passed as None
    # (presumably to neutralize sampling defaults from the checkpoint's
    # generation config; confirm).
    generation_options = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": False,
        "temperature": None,
        "top_p": None,
    }

    # CPU inference: lifespan() loads the model with device_map="cpu".
    with torch.inference_mode():
        output_ids = model.generate(**encoded, **generation_options)

    # Keep only the tokens that follow the prompt in the first sequence.
    completion_ids = output_ids[0][prompt_length:]
    raw_text = processor.decode(completion_ids, skip_special_tokens=True).strip()

    return normalize_intent(raw_text, allowed_intents), raw_text


# =========================
# Lifespan
# =========================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the processor and model before the app serves requests.

    Assigns the module-level ``processor`` and ``model``. The startup line is
    printed first, then a missing ``HF_TOKEN`` aborts startup.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is unset or empty.
    """
    global model, processor

    print(f"[startup] Loading model: {MODEL_ID}")

    if not HF_TOKEN:
        raise RuntimeError("HF_TOKEN is missing. Add it in Hugging Face Space Secrets.")

    # The same access token is passed to both from_pretrained() calls.
    hub_auth = {"token": HF_TOKEN}

    processor = AutoProcessor.from_pretrained(MODEL_ID, **hub_auth)

    loaded = Gemma3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        device_map="cpu",
        **hub_auth,
    )
    # Switch to evaluation mode and publish the model for request handlers.
    model = loaded.eval()

    print("[startup] Model loaded successfully.")
    yield
    print("[shutdown] App is shutting down.")


# Application object; startup/shutdown work is delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Gemma Intent Classifier API",
)


# =========================
# Routes
# =========================
@app.get("/")
def root():
    # Static payload confirming the service is up; it does not inspect the
    # model or processor.
    return dict(
        status="ok",
        message="Gemma Intent Classifier API is running.",
    )


@app.get("/health")
def health():
    # Fixed "healthy" status plus the configured model id; the loaded model
    # itself is not probed here.
    return dict(status="healthy", model=MODEL_ID)


@app.post("/intent", response_model=IntentResponse)
def classify_intent(payload: IntentRequest):
    # Classify payload.message into one intent label.
    #
    # Responses:
    #   200 - IntentResponse with the normalized label, the raw model text,
    #         and the configured MODEL_ID.
    #   400 - message is blank, or no usable intent label is available.
    #   500 - the classification call raised; detail carries str(error).
    if not payload.message or not payload.message.strip():
        raise HTTPException(status_code=400, detail="message is required")

    # Clean caller-supplied labels the same way the INTENTS env var is parsed
    # (trim, drop blanks), so empty labels never reach the prompt or the
    # matcher. A missing or all-blank list falls back to the defaults.
    requested = [
        label.strip()
        for label in (payload.intents or [])
        if label and label.strip()
    ]
    allowed_intents = requested or ALLOWED_INTENTS

    if not allowed_intents:
        raise HTTPException(status_code=400, detail="No intents provided")

    # Only the model call is guarded; logging and response building follow.
    try:
        intent, raw_output = run_intent_classification(
            user_message=payload.message.strip(),
            allowed_intents=allowed_intents,
            custom_system_prompt=payload.system_prompt
        )
    except Exception as e:
        print(f"[error] {repr(e)}")
        # NOTE(review): str(e) exposes internal error text to the client;
        # kept because existing callers may rely on it. Consider a generic
        # message.
        raise HTTPException(status_code=500, detail=str(e)) from e

    print("========== REQUEST ==========")
    print(f"message: {payload.message}")
    print(f"allowed_intents: {allowed_intents}")
    print("========== RESPONSE =========")
    print(f"raw_output: {raw_output}")
    print(f"intent: {intent}")
    print("================================")

    return IntentResponse(
        intent=intent,
        raw_output=raw_output,
        model=MODEL_ID
    )