mila2030 committed · Commit cc24ffb · verified · 1 Parent(s): e23fa9f

Update handler.py

Files changed (1):
  1. handler.py +48 -57
handler.py CHANGED
@@ -1,30 +1,27 @@
-# handler.py — Hugging Face custom handler compatible with your local Flask shapes
-# - Env: GEMINI_API_KEY (or GOOGLE_API_KEY), GEMINI_MODEL, TEMPERATURE, TOP_P, MAX_OUTPUT_TOKENS, SYSTEM_PROMPT
-# - Input shapes supported:
-#   1) {"inputs": "hello", "parameters": {...}}
-#   2) {"message": "hello", "history": [...], "parameters": {...}}  # your Flask shape
-#   3) {"inputs": {"messages":[{"role":"user","content":"hello"}]}, "parameters": {...}}
+# handler.py — custom Hugging Face handler compatible with your Flask shape
+# Env needed in Endpoint Settings Environment Variables:
+#   GEMINI_API_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxx
+# Optional:
+#   GEMINI_MODEL=gemini-2.5-pro
+#   TEMPERATURE=0.7
+#   TOP_P=0.95
+#   MAX_OUTPUT_TOKENS=1024
+#   SYSTEM_PROMPT=You are a helpful assistant.
 
 from typing import Any, Dict, List, Optional
 import os, time
 import google.generativeai as genai
 
-# ---------- Config (env) ----------
 API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
-MODEL_NAME = os.getenv("GEMINI_MODEL", "models/gemini-2.5-pro")  # keep your local default
+MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.5-pro")
 TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
 TOP_P = float(os.getenv("TOP_P", "0.95"))
 MAX_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "1024"))
 SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")
 
-GEN_CFG = {
-    "temperature": TEMPERATURE,
-    "top_p": TOP_P,
-    "max_output_tokens": MAX_TOKENS,
-}
+GEN_CFG = {"temperature": TEMPERATURE, "top_p": TOP_P, "max_output_tokens": MAX_TOKENS}
 
 def _extract_text(resp: Any) -> str:
-    # google.generativeai responses usually expose .text; keep a fallback
     if getattr(resp, "text", None):
         return resp.text
     try:
@@ -32,69 +29,69 @@ def _extract_text(resp: Any) -> str:
             content = getattr(c, "content", None)
             for p in getattr(content, "parts", []) or []:
                 t = getattr(p, "text", None)
-                if t:
-                    return t
+                if t: return t
     except Exception:
         pass
     return ""
 
 def _to_gemini_history(history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
     """
-    Accepts your local history shape:
-      [{ "role":"user"|"model", "parts":[{"text":"..."}] }, ...]
-    …or an OpenAI-ish/messages list:
-      [{ "role":"user"|"assistant", "content":"..."}, ...]
-    Produces Gemini-compatible history.
+    Supports your Flask history: [{role:'user'|'model', parts:[{text:'...'}]}]
+    and OpenAI-style: [{role:'user'|'assistant', content:'...'}]
     """
     out: List[Dict[str, Any]] = []
     for m in history or []:
         role = (m.get("role") or "").lower()
-        if "parts" in m:  # already Gemini-ish
+        if "parts" in m:  # already Gemini-like
             parts = m.get("parts") or []
-            out.append({"role": role if role in ("user","model") else "user", "parts": parts})
+            role = role if role in ("user","model") else "user"
+            out.append({"role": role, "parts": parts})
         else:
             content = m.get("content", "")
             if not isinstance(content, str):
-                try:
-                    content = str(content)
-                except Exception:
-                    content = ""
-            # map "assistant" -> "model"
+                content = str(content)
             role = "model" if role == "assistant" else ("user" if role != "model" else "model")
             out.append({"role": role, "parts": [{"text": content}]})
     return out
 
-def _pick_user_text(payload: Dict[str, Any]) -> Optional[str]:
-    # 1) {"inputs": "text"}
-    if isinstance(payload.get("inputs"), str):
-        return payload["inputs"].strip()
-
-    # 2) {"message":"text"}
-    msg = payload.get("message")
-    if isinstance(msg, str) and msg.strip():
-        return msg.strip()
-
-    # 3) {"inputs":{"messages":[...]}} pick last user message
-    x = payload.get("inputs")
+def _pick_user_text(data: Dict[str, Any]) -> Optional[str]:
+    # Supports: {"message":"..."}, or {"inputs":"..."}, or {"inputs":{"messages":[...]}}
+    if isinstance(data.get("message"), str) and data["message"].strip():
+        return data["message"].strip()
+    if isinstance(data.get("inputs"), str) and data["inputs"].strip():
+        return data["inputs"].strip()
+    x = data.get("inputs")
     if isinstance(x, dict):
         msgs = x.get("messages") or []
         for m in reversed(msgs):
             if (m.get("role","user").lower()) == "user":
                 c = m.get("content","")
                 return c if isinstance(c,str) else str(c)
-
     return None
 
 class EndpointHandler:
     def __init__(self, path: str = ""):
         if not API_KEY:
-            self._init_error = "Missing GEMINI_API_KEY/GOOGLE_API_KEY in Environment Variables."
+            self._init_error = "Missing GEMINI_API_KEY/GOOGLE_API_KEY."
             print("[handler:init] ERROR:", self._init_error, flush=True)
             return
 
         self._init_error = None
         genai.configure(api_key=API_KEY)
-        self.model = genai.GenerativeModel(MODEL_NAME, system_instruction=SYSTEM_PROMPT)
+
+        # Try model name, then a prefixed fallback
+        try:
+            self.model = genai.GenerativeModel(MODEL_NAME, system_instruction=SYSTEM_PROMPT)
+        except Exception as e1:
+            alt = MODEL_NAME.replace("models/", "") if MODEL_NAME.startswith("models/") else f"models/{MODEL_NAME}"
+            try:
+                self.model = genai.GenerativeModel(alt, system_instruction=SYSTEM_PROMPT)
+                print(f"[handler:init] Fallback model name used: {alt}", flush=True)
+            except Exception as e2:
+                self._init_error = f"Model init failed: {repr(e1)} | fallback: {repr(e2)}"
+                print("[handler:init] ERROR:", self._init_error, flush=True)
+                return
+
         # Warm-up (non-fatal if it fails)
         try:
             t0 = time.time()
@@ -104,11 +101,10 @@
         print("[handler:init] warm-up failed (non-fatal):", repr(e), flush=True)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        # If init failed, surface that to the client
         if getattr(self, "_init_error", None):
-            return {"text":"", "debug":{"where":"init", "error": self._init_error}}
+            return {"text": "", "debug": {"where": "init", "error": self._init_error}}
 
-        # Parameters override (optional)
+        # Per-request overrides
        params = data.get("parameters") or {}
        gen_cfg = {
            "temperature": float(params.get("temperature", GEN_CFG["temperature"])),
@@ -116,16 +112,13 @@
             "max_output_tokens": int(params.get("max_output_tokens", GEN_CFG["max_output_tokens"])),
         }
 
-        # Determine the user text
         user_text = _pick_user_text(data) or ""
         if not user_text:
-            return {"text":"", "debug":{"where":"input", "error":"Empty prompt"}}
+            return {"text": "", "debug": {"where": "input", "error": "Empty prompt"}}
 
-        # Build history if provided (your local shape supported)
-        history = data.get("history") or []
-        gemini_history = _to_gemini_history(history)
+        gemini_history = _to_gemini_history(data.get("history") or [])
 
-        # Try chat w/ a tiny retry to dodge transient cold-start
+        # Try once; tiny retry handles cold-start/transient
         last_err = None
         for attempt in range(2):
             try:
@@ -133,20 +126,18 @@
                     chat = self.model.start_chat(history=gemini_history)
                     resp = chat.send_message(user_text, generation_config=gen_cfg)
                 else:
-                    # stateless call
                     resp = self.model.generate_content(user_text, generation_config=gen_cfg)
                 txt = _extract_text(resp)
                 if txt:
                     return {"text": txt}
-                # no text — return finish reasons if present
                 fin = []
                 try:
                     fin = [getattr(c, "finish_reason", None) for c in (resp.candidates or [])]
                 except Exception:
                     pass
-                return {"text":"", "debug":{"where":"empty_text", "finish_reasons": fin}}
+                return {"text": "", "debug": {"where": "empty_text", "finish_reasons": fin}}
             except Exception as e:
                 last_err = repr(e)
-                time.sleep(0.35)  # small backoff
+                time.sleep(0.35)
 
-        return {"text":"", "debug":{"where":"exception", "exception": last_err}}
+        return {"text": "", "debug": {"where": "exception", "exception": last_err}}