Mr-Help commited on
Commit
a3bb57e
·
verified ·
1 Parent(s): fc3dca3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from contextlib import asynccontextmanager
4
+ from typing import List, Optional
5
+
6
+ import torch
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
10
+
11
+ # =========================
12
+ # Config
13
+ # =========================
14
+ MODEL_ID = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
15
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "12"))
16
+
17
+ # لو عايز تغير الانتنـتس من غير تعديل الكود:
18
+ # مثال:
19
+ # INTENTS="greeting,pricing,complaint,booking,follow_up,other"
20
+ INTENTS_ENV = os.getenv(
21
+ "INTENTS",
22
+ "same_path,change_path,greeting,pricing,booking,complaint,follow_up,other"
23
+ )
24
+ ALLOWED_INTENTS = [x.strip() for x in INTENTS_ENV.split(",") if x.strip()]
25
+
26
+ model = None
27
+ processor = None
28
+
29
+
30
+ # =========================
31
+ # Schemas
32
+ # =========================
33
+ class IntentRequest(BaseModel):
34
+ message: str
35
+ intents: Optional[List[str]] = None
36
+ system_prompt: Optional[str] = None
37
+
38
+
39
+ class IntentResponse(BaseModel):
40
+ intent: str
41
+ raw_output: str
42
+ model: str
43
+
44
+
45
+ # =========================
46
+ # Helpers
47
+ # =========================
48
+ def normalize_intent(text: str, allowed_intents: List[str]) -> str:
49
+ cleaned = text.strip().lower()
50
+
51
+ # شيل أي markdown/code fences أو علامات زيادة
52
+ cleaned = cleaned.replace("```", "").replace("`", "").strip()
53
+
54
+ # لو الموديل رجّع جملة فيها intent ضمن النص
55
+ for intent in allowed_intents:
56
+ if re.search(rf"\b{re.escape(intent.lower())}\b", cleaned):
57
+ return intent
58
+
59
+ # fallback
60
+ return "other"
61
+
62
+
63
+ def build_prompt(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> List[dict]:
64
+ intent_list = ", ".join(allowed_intents)
65
+
66
+ system_text = custom_system_prompt or (
67
+ "You are an intent classifier.\n"
68
+ f"Choose exactly one intent from this list: {intent_list}.\n"
69
+ "Return only the intent label, with no explanation, no punctuation, and no extra words."
70
+ )
71
+
72
+ return [
73
+ {
74
+ "role": "system",
75
+ "content": [{"type": "text", "text": system_text}]
76
+ },
77
+ {
78
+ "role": "user",
79
+ "content": [{"type": "text", "text": user_message}]
80
+ }
81
+ ]
82
+
83
+
84
+ def run_intent_classification(user_message: str, allowed_intents: List[str], custom_system_prompt: Optional[str]) -> tuple[str, str]:
85
+ global model, processor
86
+
87
+ messages = build_prompt(user_message, allowed_intents, custom_system_prompt)
88
+
89
+ inputs = processor.apply_chat_template(
90
+ messages,
91
+ add_generation_prompt=True,
92
+ tokenize=True,
93
+ return_dict=True,
94
+ return_tensors="pt",
95
+ )
96
+
97
+ # CPU inference
98
+ with torch.inference_mode():
99
+ generation = model.generate(
100
+ **inputs,
101
+ max_new_tokens=MAX_NEW_TOKENS,
102
+ do_sample=False,
103
+ temperature=None,
104
+ top_p=None,
105
+ )
106
+
107
+ input_len = inputs["input_ids"].shape[-1]
108
+ generated_tokens = generation[0][input_len:]
109
+ decoded = processor.decode(generated_tokens, skip_special_tokens=True).strip()
110
+
111
+ final_intent = normalize_intent(decoded, allowed_intents)
112
+ return final_intent, decoded
113
+
114
+
115
+ # =========================
116
+ # Lifespan
117
+ # =========================
118
+ @asynccontextmanager
119
+ async def lifespan(app: FastAPI):
120
+ global model, processor
121
+
122
+ print(f"[startup] Loading model: {MODEL_ID}")
123
+
124
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
125
+ model = Gemma3ForConditionalGeneration.from_pretrained(
126
+ MODEL_ID,
127
+ device_map="cpu"
128
+ ).eval()
129
+
130
+ print("[startup] Model loaded successfully.")
131
+ yield
132
+ print("[shutdown] App is shutting down.")
133
+
134
+
135
+ app = FastAPI(
136
+ title="Gemma Intent Classifier API",
137
+ version="1.0.0",
138
+ lifespan=lifespan
139
+ )
140
+
141
+
142
+ # =========================
143
+ # Routes
144
+ # =========================
145
+ @app.get("/")
146
+ def root():
147
+ return {
148
+ "status": "ok",
149
+ "message": "Gemma Intent Classifier API is running."
150
+ }
151
+
152
+
153
+ @app.get("/health")
154
+ def health():
155
+ return {
156
+ "status": "healthy",
157
+ "model": MODEL_ID
158
+ }
159
+
160
+
161
+ @app.post("/intent", response_model=IntentResponse)
162
+ def classify_intent(payload: IntentRequest):
163
+ if not payload.message or not payload.message.strip():
164
+ raise HTTPException(status_code=400, detail="message is required")
165
+
166
+ allowed_intents = payload.intents if payload.intents else ALLOWED_INTENTS
167
+
168
+ if not allowed_intents:
169
+ raise HTTPException(status_code=400, detail="No intents provided")
170
+
171
+ try:
172
+ intent, raw_output = run_intent_classification(
173
+ user_message=payload.message.strip(),
174
+ allowed_intents=allowed_intents,
175
+ custom_system_prompt=payload.system_prompt
176
+ )
177
+
178
+ print("========== REQUEST ==========")
179
+ print(f"message: {payload.message}")
180
+ print(f"allowed_intents: {allowed_intents}")
181
+ print("========== RESPONSE =========")
182
+ print(f"raw_output: {raw_output}")
183
+ print(f"intent: {intent}")
184
+ print("================================")
185
+
186
+ return IntentResponse(
187
+ intent=intent,
188
+ raw_output=raw_output,
189
+ model=MODEL_ID
190
+ )
191
+
192
+ except Exception as e:
193
+ print(f"[error] {repr(e)}")
194
+ raise HTTPException(status_code=500, detail=str(e))