mila2030 committed on
Commit
aacab3f
·
verified ·
1 Parent(s): 505b185

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +95 -114
handler.py CHANGED
@@ -1,31 +1,52 @@
1
- import os
 
2
  from typing import Any, Dict, List, Union
 
3
  import google.generativeai as genai
4
 
5
- DEFAULT_MODEL = os.getenv("GEMINI_MODEL", "models/gemini-2.5-pro")
6
- DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
7
- DEFAULT_TOP_P = float(os.getenv("TOP_P", "0.95"))
8
- DEFAULT_MAX_OUTPUT = int(os.getenv("MAX_OUTPUT_TOKENS", "1024"))
9
- DEFAULT_CANDIDATE_COUNT = int(os.getenv("CANDIDATE_COUNT", "1"))
10
- DEFAULT_SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")
11
- USE_HISTORY = os.getenv("USE_HISTORY", "true").lower() in {"1", "true", "yes"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  class EndpointHandler:
14
- def __init__(self, model_dir: str, *args, **kwargs):
15
  api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
16
  if not api_key:
17
- self._init_error = "Missing GEMINI_API_KEY in Endpoint Environment Variables."
 
18
  return
19
  self._init_error = None
20
  genai.configure(api_key=api_key)
21
 
22
- # Proper system instruction
23
- self.model = genai.GenerativeModel(
24
- DEFAULT_MODEL,
25
- system_instruction=DEFAULT_SYSTEM_PROMPT
26
- )
27
 
28
- # Slightly relaxed safety (optional)
29
  self.safety_settings = None
30
  try:
31
  from google.generativeai.types import HarmBlockThreshold, HarmCategory
@@ -35,104 +56,64 @@ class EndpointHandler:
35
  HarmCategory.HARM_CATEGORY_SEXUAL: HarmBlockThreshold.BLOCK_ONLY_HIGH,
36
  HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH,
37
  }
38
- except Exception:
39
- pass
40
-
41
- # ---- Helpers ----
42
- def _to_gemini_history(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
43
- # Only user/model roles are valid for chat history.
44
- out = []
45
- for m in messages:
46
- role = (m.get("role") or "user").lower()
47
- if role == "assistant":
48
- role = "model"
49
- if role not in ("user", "model"):
50
- role = "user"
51
- out.append({"role": role, "parts": [{"text": str(m.get("content", ""))}]})
52
- return out
53
-
54
- def _extract_text(self, resp: Any) -> Dict[str, Any]:
55
- # 1) Standard
56
- if getattr(resp, "text", None):
57
- return {"text": resp.text}
58
- # 2) Candidates/parts
59
- try:
60
- cands = getattr(resp, "candidates", None) or []
61
- for c in cands:
62
- if getattr(c, "content", None) and getattr(c.content, "parts", None):
63
- for p in c.content.parts:
64
- t = getattr(p, "text", None)
65
- if t:
66
- return {"text": t}
67
- except Exception:
68
- pass
69
- # 3) Diagnostics fallback
70
- diag = {}
71
- try:
72
- if cands:
73
- fr = [getattr(c, "finish_reason", None) for c in cands]
74
- diag["finish_reasons"] = fr
75
- except Exception:
76
- pass
77
- return {"text": "I couldn’t generate a response.", "debug": diag or {"note": "empty model text"}}
78
-
79
- def _gen_cfg(self, payload: Dict[str, Any]) -> Dict[str, Any]:
80
- params = payload.get("parameters") or {}
81
- return {
82
- "temperature": float(params.get("temperature", DEFAULT_TEMPERATURE)),
83
- "top_p": float(params.get("top_p", DEFAULT_TOP_P)),
84
- "max_output_tokens": int(params.get("max_output_tokens", DEFAULT_MAX_OUTPUT)),
85
- "candidate_count": int(params.get("candidate_count", DEFAULT_CANDIDATE_COUNT)),
86
- }
87
-
88
- # ---- Main entry ----
89
- def __call__(self, data: Union[Dict[str, Any], List[Dict[str, Any]]]):
90
  if self._init_error:
91
- return {"error": self._init_error}
 
92
  try:
93
- if isinstance(data, list):
94
- return [self._handle_one(d) for d in data]
95
- return self._handle_one(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  except Exception as e:
97
- return {"error": str(e)}
98
-
99
- def _handle_one(self, payload: Dict[str, Any]) -> Dict[str, Any]:
100
- cfg = self._gen_cfg(payload)
101
- data_inputs = payload.get("inputs")
102
-
103
- # A) Chat: {"inputs":{"messages":[...]}}
104
- if isinstance(data_inputs, dict) and "messages" in data_inputs:
105
- msgs = data_inputs["messages"] or []
106
- # Stateless mode if USE_HISTORY=false
107
- if not USE_HISTORY:
108
- # use only the last user message
109
- last_user = next((m for m in reversed(msgs) if (m.get("role") or "user").lower() == "user"), None)
110
- text = (last_user or {}).get("content", "") if last_user else ""
111
- resp = self.model.generate_content(text, generation_config=cfg, safety_settings=self.safety_settings)
112
- return self._extract_text(resp)
113
-
114
- # With history
115
- last_user = next((m for m in reversed(msgs) if (m.get("role") or "user").lower() == "user"), None)
116
- user_text = (last_user or {}).get("content", "")
117
- history_msgs = [m for m in msgs if m is not last_user]
118
- chat = self.model.start_chat(history=self._to_gemini_history(history_msgs))
119
- resp = chat.send_message(user_text, generation_config=cfg, safety_settings=self.safety_settings)
120
- return self._extract_text(resp)
121
-
122
- # B) Plain text: {"inputs":"..."}
123
- if isinstance(data_inputs, str):
124
- prompt = data_inputs.strip()
125
- if not prompt:
126
- return {"text": "Empty prompt."}
127
- resp = self.model.generate_content(prompt, generation_config=cfg, safety_settings=self.safety_settings)
128
- return self._extract_text(resp)
129
-
130
- # C) Fallbacks
131
- if "messages" in (payload or {}):
132
- msgs = payload["messages"] or []
133
- last_user = next((m for m in reversed(msgs) if (m.get("role") or "user").lower() == "user"), None)
134
- text = (last_user or {}).get("content", "")
135
- resp = self.model.generate_content(text, generation_config=cfg, safety_settings=self.safety_settings)
136
- return self._extract_text(resp)
137
-
138
- return {"text": "Empty prompt."}
 
1
+ # handler.py — HF-compliant, stateless Gemini proxy
2
+
3
  from typing import Any, Dict, List, Union
4
+ import os
5
  import google.generativeai as genai
6
 
7
+ # Config via HF Endpoint → Settings → Environment Variables
8
+ MODEL = os.getenv("GEMINI_MODEL", "gemini-1.5-flash") # safe default
9
+ TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
10
+ TOP_P = float(os.getenv("TOP_P", "0.95"))
11
+ MAX_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "512"))
12
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant.")
13
+
14
+ def _extract_text(resp: Any) -> str:
15
+ # 1) standard property
16
+ if getattr(resp, "text", None):
17
+ return resp.text
18
+ # 2) candidates/parts
19
+ try:
20
+ for c in getattr(resp, "candidates", []) or []:
21
+ content = getattr(c, "content", None)
22
+ for p in getattr(content, "parts", []) or []:
23
+ t = getattr(p, "text", None)
24
+ if t:
25
+ return t
26
+ except Exception:
27
+ pass
28
+ return ""
29
+
30
+ def _last_user_from_messages(msgs: List[Dict[str, Any]]) -> str:
31
+ for m in reversed(msgs or []):
32
+ if (m.get("role") or "user").lower() == "user":
33
+ return str(m.get("content", "")).strip()
34
+ return ""
35
 
36
  class EndpointHandler:
37
+ def __init__(self, path: str = ""):
38
  api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
39
  if not api_key:
40
+ self._init_error = "Missing GEMINI_API_KEY/GOOGLE_API_KEY in Endpoint Environment Variables."
41
+ print("[handler:init] ERROR:", self._init_error, flush=True)
42
  return
43
  self._init_error = None
44
  genai.configure(api_key=api_key)
45
 
46
+ # Proper system prompt
47
+ self.model = genai.GenerativeModel(MODEL, system_instruction=SYSTEM_PROMPT)
 
 
 
48
 
49
+ # Optional: slightly relaxed safety to avoid silent blocks of normal prompts
50
  self.safety_settings = None
51
  try:
52
  from google.generativeai.types import HarmBlockThreshold, HarmCategory
 
56
  HarmCategory.HARM_CATEGORY_SEXUAL: HarmBlockThreshold.BLOCK_ONLY_HIGH,
57
  HarmCategory.HARM_CATEGORY_DANGEROUS: HarmBlockThreshold.BLOCK_ONLY_HIGH,
58
  }
59
+ except Exception as e:
60
+ print("[handler:init] safety settings skipped:", repr(e), flush=True)
61
+
62
+ print(f"[handler:init] OK MODEL={MODEL}", flush=True)
63
+
64
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  if self._init_error:
66
+ return {"text": "", "debug": {"error": self._init_error}}
67
+
68
  try:
69
+ print("[handler:call] payload:", data, flush=True)
70
+ # HF guarantees top-level `inputs`
71
+ inputs = data.get("inputs")
72
+
73
+ # Accept either:
74
+ # A) {"inputs":"plain text"}
75
+ # B) {"inputs":{"messages":[{"role":"user","content":"..."}]}}
76
+ # (compat) Also accept top-level "messages" if present.
77
+ if isinstance(inputs, str):
78
+ user_text = inputs.strip()
79
+ elif isinstance(inputs, dict) and "messages" in inputs:
80
+ user_text = _last_user_from_messages(inputs.get("messages"))
81
+ elif "messages" in data:
82
+ user_text = _last_user_from_messages(data.get("messages"))
83
+ else:
84
+ user_text = ""
85
+
86
+ if not user_text:
87
+ return {"text": "", "debug": {"note": "Empty prompt."}}
88
+
89
+ gen_cfg = {
90
+ "temperature": TEMPERATURE,
91
+ "top_p": TOP_P,
92
+ "max_output_tokens": MAX_TOKENS,
93
+ }
94
+
95
+ print("[handler:call] generate_content:", repr(user_text[:160]), flush=True)
96
+ resp = self.model.generate_content(
97
+ user_text,
98
+ generation_config=gen_cfg,
99
+ safety_settings=self.safety_settings
100
+ )
101
+ print("[handler:call] raw resp:", repr(resp), flush=True)
102
+
103
+ text = _extract_text(resp)
104
+ if text:
105
+ return {"text": text}
106
+
107
+ # Diagnostics if empty
108
+ debug = {}
109
+ try:
110
+ fr = [getattr(c, "finish_reason", None) for c in (resp.candidates or [])]
111
+ if fr:
112
+ debug["finish_reasons"] = fr
113
+ except Exception:
114
+ pass
115
+ return {"text": "", "debug": debug or {"note": "Empty model text"}}
116
+
117
  except Exception as e:
118
+ print("[handler:call] EXC:", repr(e), flush=True)
119
+ return {"text": "", "debug": {"exception": str(e)}}