mila2030 committed on
Commit
e23fa9f
·
verified ·
1 Parent(s): fe52361

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +118 -53
handler.py CHANGED
@@ -1,87 +1,152 @@
1
- # handler.py — warm-up + retry, always returns text or clear debug
 
 
 
 
 
2
 
3
- from typing import Any, Dict
4
- import os, time, socket, requests
5
  import google.generativeai as genai
6
 
7
# Runtime configuration — every knob is overridable via environment variables.
MODEL = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")      # Gemini model id
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))       # sampling temperature
TOP_P = float(os.getenv("TOP_P", "0.95"))                  # nucleus-sampling cutoff
MAX_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "512"))    # response length cap
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")
12
 
 
 
 
 
 
 
13
  def _extract_text(resp: Any) -> str:
14
- if getattr(resp, "text", None): return resp.text
 
 
15
  try:
16
  for c in getattr(resp, "candidates", []) or []:
17
  content = getattr(c, "content", None)
18
  for p in getattr(content, "parts", []) or []:
19
  t = getattr(p, "text", None)
20
- if t: return t
 
21
  except Exception:
22
  pass
23
  return ""
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
class EndpointHandler:
    """Hugging Face Inference Endpoint handler backed by the Gemini SDK.

    Construction never raises: a missing API key is recorded in
    ``self.init_error`` and surfaced from ``__call__`` instead.
    """

    def __init__(self, path: str = ""):
        self.api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            self.init_error = "Missing GEMINI_API_KEY in Endpoint → Settings → Environment Variables."
            print("[handler] INIT ERROR:", self.init_error, flush=True)
            return
        self.init_error = None

        # Configure the SDK and build the model once per process.
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(MODEL, system_instruction=SYSTEM_PROMPT)

        # Warm-up with a tiny prompt so the first real request isn't cold.
        try:
            started = time.time()
            warm = self.model.generate_content("ping", generation_config={"max_output_tokens": 4})
            print("[handler] warmup ok in", round((time.time() - started) * 1000), "ms; text=",
                  bool(_extract_text(warm)), flush=True)
        except Exception as exc:
            # Non-fatal; we still serve, but log the failure for diagnosis.
            print("[handler] warmup failed:", repr(exc), flush=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request; always returns a dict with "text" and "debug"."""
        if self.init_error:
            return {"text": "", "debug": {"where": "init", "error": self.init_error}}

        # Normalize the incoming payload into a single user string.
        inputs = data.get("inputs")
        user_text = ""
        if isinstance(inputs, str):
            user_text = inputs.strip()
        elif isinstance(inputs, dict) and "messages" in inputs:
            # Pick the most recent user-authored message.
            for message in reversed(inputs["messages"] or []):
                if (message.get("role") or "user").lower() == "user":
                    user_text = str(message.get("content", "")).strip()
                    break

        if not user_text:
            return {"text": "", "debug": {"where": "handler", "note": "Empty prompt"}}

        gen_cfg = {"temperature": TEMPERATURE, "top_p": TOP_P, "max_output_tokens": MAX_TOKENS}

        # Up to three attempts with linear backoff for transient failures.
        last_exc = None
        for attempt in range(3):
            try:
                started = time.time()
                resp = self.model.generate_content(user_text, generation_config=gen_cfg)
                elapsed_ms = round((time.time() - started) * 1000)
                txt = _extract_text(resp)
                if txt:
                    return {"text": txt, "debug": {"latency_ms": elapsed_ms, "attempt": attempt + 1}}
                # No text came back — surface finish reasons when available.
                try:
                    fr = [getattr(c, "finish_reason", None) for c in (resp.candidates or [])]
                except Exception:
                    fr = []
                return {"text": "", "debug": {"where": "gemini_empty", "finish_reasons": fr, "latency_ms": elapsed_ms}}
            except Exception as exc:
                last_exc = repr(exc)
                time.sleep(0.4 * (attempt + 1))  # linear backoff before retrying

        return {"text": "", "debug": {"where": "gemini_exception", "exception": last_exc}}
 
1
+ # handler.py — Hugging Face custom handler compatible with your local Flask shapes
2
+ # - Env: GEMINI_API_KEY (or GOOGLE_API_KEY), GEMINI_MODEL, TEMPERATURE, TOP_P, MAX_OUTPUT_TOKENS, SYSTEM_PROMPT
3
+ # - Input shapes supported:
4
+ # 1) {"inputs": "hello", "parameters": {...}}
5
+ # 2) {"message": "hello", "history": [...], "parameters": {...}} # your Flask shape
6
+ # 3) {"inputs": {"messages":[{"role":"user","content":"hello"}]}, "parameters": {...}}
7
 
8
+ from typing import Any, Dict, List, Optional
9
+ import os, time
10
  import google.generativeai as genai
11
 
12
# ---------- Config (env) ----------
# All knobs come from environment variables so the endpoint can be tuned
# without a code change.
API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
MODEL_NAME = os.getenv("GEMINI_MODEL", "models/gemini-2.5-pro")  # keep your local default
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
TOP_P = float(os.getenv("TOP_P", "0.95"))
MAX_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "1024"))
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")

# Default generation settings; per-request "parameters" may override these.
GEN_CFG = dict(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_output_tokens=MAX_TOKENS,
)
25
+
26
  def _extract_text(resp: Any) -> str:
27
+ # google.generativeai responses usually expose .text; keep a fallback
28
+ if getattr(resp, "text", None):
29
+ return resp.text
30
  try:
31
  for c in getattr(resp, "candidates", []) or []:
32
  content = getattr(c, "content", None)
33
  for p in getattr(content, "parts", []) or []:
34
  t = getattr(p, "text", None)
35
+ if t:
36
+ return t
37
  except Exception:
38
  pass
39
  return ""
40
 
41
+ def _to_gemini_history(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
42
+ """
43
+ Accepts your local history shape:
44
+ [{ "role":"user"|"model", "parts":[{"text":"..."}] }, ...]
45
+ …or an OpenAI-ish/messages list:
46
+ [{ "role":"user"|"assistant", "content":"..."}, ...]
47
+ Produces Gemini-compatible history.
48
+ """
49
+ out: List[Dict[str, Any]] = []
50
+ for m in history or []:
51
+ role = (m.get("role") or "").lower()
52
+ if "parts" in m: # already Gemini-ish
53
+ parts = m.get("parts") or []
54
+ out.append({"role": role if role in ("user","model") else "user", "parts": parts})
55
+ else:
56
+ content = m.get("content", "")
57
+ if not isinstance(content, str):
58
+ try:
59
+ content = str(content)
60
+ except Exception:
61
+ content = ""
62
+ # map "assistant" -> "model"
63
+ role = "model" if role == "assistant" else ("user" if role != "model" else "model")
64
+ out.append({"role": role, "parts": [{"text": content}]})
65
+ return out
66
+
67
+ def _pick_user_text(payload: Dict[str, Any]) -> Optional[str]:
68
+ # 1) {"inputs": "text"}
69
+ if isinstance(payload.get("inputs"), str):
70
+ return payload["inputs"].strip()
71
+
72
+ # 2) {"message":"text"}
73
+ msg = payload.get("message")
74
+ if isinstance(msg, str) and msg.strip():
75
+ return msg.strip()
76
+
77
+ # 3) {"inputs":{"messages":[...]}} pick last user message
78
+ x = payload.get("inputs")
79
+ if isinstance(x, dict):
80
+ msgs = x.get("messages") or []
81
+ for m in reversed(msgs):
82
+ if (m.get("role","user").lower()) == "user":
83
+ c = m.get("content","")
84
+ return c if isinstance(c,str) else str(c)
85
+
86
+ return None
87
+
88
class EndpointHandler:
    """Custom Inference Endpoint handler that proxies requests to Gemini.

    Construction never raises: a missing API key is recorded in
    ``self._init_error`` and reported from ``__call__``.
    """

    def __init__(self, path: str = ""):
        # Without a key we still construct successfully so the endpoint boots
        # and the error is visible to the client instead of crashing startup.
        if not API_KEY:
            self._init_error = "Missing GEMINI_API_KEY/GOOGLE_API_KEY in Environment Variables."
            print("[handler:init] ERROR:", self._init_error, flush=True)
            return

        self._init_error = None
        genai.configure(api_key=API_KEY)
        self.model = genai.GenerativeModel(MODEL_NAME, system_instruction=SYSTEM_PROMPT)

        # Warm-up (non-fatal if it fails)
        try:
            started = time.time()
            self.model.generate_content("ping", generation_config={"max_output_tokens": 4})
            elapsed = round((time.time() - started) * 1000)
            print("[handler:init] warm-up OK in", elapsed, "ms", flush=True)
        except Exception as exc:
            print("[handler:init] warm-up failed (non-fatal):", repr(exc), flush=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request; always returns {"text": ...}, with debug info on failure."""
        # If init failed, surface that to the client.
        init_error = getattr(self, "_init_error", None)
        if init_error:
            return {"text": "", "debug": {"where": "init", "error": init_error}}

        # Merge optional per-request parameter overrides onto the env defaults.
        overrides = data.get("parameters") or {}
        gen_cfg = {
            "temperature": float(overrides.get("temperature", GEN_CFG["temperature"])),
            "top_p": float(overrides.get("top_p", GEN_CFG["top_p"])),
            "max_output_tokens": int(overrides.get("max_output_tokens", GEN_CFG["max_output_tokens"])),
        }

        # Determine the user text from any supported payload shape.
        user_text = _pick_user_text(data) or ""
        if not user_text:
            return {"text": "", "debug": {"where": "input", "error": "Empty prompt"}}

        # Optional multi-turn history (local Flask shape supported).
        gemini_history = _to_gemini_history(data.get("history") or [])

        # Two attempts to ride out transient/cold-start hiccups.
        last_err = None
        for _attempt in range(2):
            try:
                if gemini_history:
                    session = self.model.start_chat(history=gemini_history)
                    resp = session.send_message(user_text, generation_config=gen_cfg)
                else:
                    # Stateless single-turn call.
                    resp = self.model.generate_content(user_text, generation_config=gen_cfg)

                reply = _extract_text(resp)
                if reply:
                    return {"text": reply}

                # Empty reply: report finish reasons when they are available.
                reasons = []
                try:
                    reasons = [getattr(c, "finish_reason", None) for c in (resp.candidates or [])]
                except Exception:
                    pass
                return {"text": "", "debug": {"where": "empty_text", "finish_reasons": reasons}}
            except Exception as exc:
                last_err = repr(exc)
                time.sleep(0.35)  # small backoff

        return {"text": "", "debug": {"where": "exception", "exception": last_err}}