pythonprincess committed on
Commit
196c49c
·
verified ·
1 Parent(s): 40a22b2

Upload 2 files

Browse files
models/gemma/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Gemma Conversational AI Model Package
2
+
models/gemma/gemma_utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/gemma/gemma_utils.py
2
+
3
+ """
4
+ Gemma Model Utilities for PENNY Project
5
+ Handles text generation using the Gemma-based core language model via Hugging Face Inference API.
6
+ Provides async generation with structured error handling and logging.
7
+ """
8
+
9
+ import os
10
+ import asyncio
11
+ import time
12
+ import httpx
13
+ from typing import Dict, Any, Optional
14
+
15
+ # --- Logging Imports ---
16
+ from app.logging_utils import log_interaction, sanitize_for_logging
17
+
18
+ # --- Configuration ---
19
+ HF_API_URL = "https://api-inference.huggingface.co/models/google/gemma-7b-it"
20
+ DEFAULT_TIMEOUT = 30.0 # Gemma can take longer to respond
21
+ MAX_RETRIES = 2
22
+ AGENT_NAME = "penny-core-agent"
23
+
24
+
25
def is_gemma_available() -> bool:
    """Report whether the Gemma backend can be used.

    Availability is determined solely by the presence of the HF_TOKEN
    environment variable (the Hugging Face API credential); no network
    check is performed.

    Returns:
        bool: True when HF_TOKEN is set to a non-empty value.
    """
    token = os.getenv("HF_TOKEN")
    return token is not None and token != ""
33
+
34
+
35
async def generate_response(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    tenant_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Runs text generation using Gemma via Hugging Face Inference API.

    Transient failures (timeouts, 5xx responses such as HF's 503 returned
    while the model is loading) are retried up to MAX_RETRIES times with a
    short backoff; client errors (4xx) fail immediately. Every outcome is
    logged via log_interaction, and callers always receive a dict — this
    function does not raise.

    Args:
        prompt: The conversational or instruction prompt.
        max_new_tokens: The maximum number of tokens to generate (default: 256).
        temperature: Controls randomness in generation (default: 0.7).
        tenant_id: Optional tenant identifier for logging.

    Returns:
        A dictionary containing:
        - response (str): The generated text
        - available (bool): Whether the service was available
        - error (str, optional): Error message if generation failed
        - response_time_ms (int, optional): Generation time in milliseconds
    """
    start_time = time.time()

    # Fail fast with a friendly fallback message when no API token is set.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="HF_TOKEN not configured",
            fallback_used=True
        )
        return {
            "response": "I'm having trouble accessing my language model right now. Please try again in a moment!",
            "available": False,
            "error": "HF_TOKEN not configured"
        }

    # Validate input: reject non-strings and empty/whitespace-only prompts
    # before spending an API call on them.
    if not isinstance(prompt, str) or not prompt.strip():
        log_interaction(
            intent="gemma_generate",
            tenant_id=tenant_id,
            success=False,
            error="Invalid prompt provided"
        )
        return {
            "response": "I didn't receive a valid prompt. Could you try again?",
            "available": True,
            "error": "Invalid input"
        }

    # Configure generation parameters.
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            # Greedy decoding at temperature 0, sampling otherwise.
            "do_sample": temperature > 0.0,
            "return_full_text": False
        }
    }

    headers = {
        "Authorization": f"Bearer {hf_token}",
        "Content-Type": "application/json"
    }

    # One client reused across all retry attempts (connection pooling),
    # instead of opening a fresh client per attempt.
    async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
        for attempt in range(MAX_RETRIES):
            try:
                response = await client.post(HF_API_URL, json=payload, headers=headers)
                response.raise_for_status()
                result = response.json()

                response_time_ms = int((time.time() - start_time) * 1000)

                # HF text-generation responds with a list of
                # {"generated_text": ...} objects on success.
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "").strip()

                    # Emit a separate event for slow generations (> 5s).
                    if response_time_ms > 5000:
                        log_interaction(
                            intent="gemma_generate_slow",
                            tenant_id=tenant_id,
                            success=True,
                            response_time_ms=response_time_ms,
                            details="Slow generation detected"
                        )

                    log_interaction(
                        intent="gemma_generate",
                        tenant_id=tenant_id,
                        success=True,
                        response_time_ms=response_time_ms,
                        prompt_preview=sanitize_for_logging(prompt[:100])
                    )

                    return {
                        "response": generated_text,
                        "available": True,
                        "response_time_ms": response_time_ms
                    }

                # Unexpected output format (not a non-empty list).
                log_interaction(
                    intent="gemma_generate",
                    tenant_id=tenant_id,
                    success=False,
                    error="Unexpected API response format",
                    response_time_ms=response_time_ms
                )

                return {
                    "response": "I got an unexpected response from my language model. Let me try to help you another way!",
                    "available": True,
                    "error": "Unexpected output format"
                }

            except httpx.TimeoutException:
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(1)  # Wait before retry
                    continue

                response_time_ms = int((time.time() - start_time) * 1000)
                log_interaction(
                    intent="gemma_generate",
                    tenant_id=tenant_id,
                    success=False,
                    error="API timeout after retries",
                    response_time_ms=response_time_ms
                )

                return {
                    "response": "I'm taking too long to respond. Please try again!",
                    "available": False,
                    "error": "Timeout",
                    "response_time_ms": response_time_ms
                }

            except httpx.HTTPStatusError as e:
                # Retry transient server-side failures — notably HF's 503
                # while the model is loading. Client errors (4xx) are
                # permanent and fail immediately.
                if e.response.status_code >= 500 and attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(1)
                    continue

                response_time_ms = int((time.time() - start_time) * 1000)
                log_interaction(
                    intent="gemma_generate",
                    tenant_id=tenant_id,
                    success=False,
                    error=f"HTTP {e.response.status_code}",
                    response_time_ms=response_time_ms
                )

                return {
                    "response": "I'm having trouble generating a response right now. Please try again!",
                    "available": False,
                    "error": f"HTTP {e.response.status_code}",
                    "response_time_ms": response_time_ms
                }

            except Exception as e:
                # Catch-all (connection errors, JSON decode failures, ...):
                # retry once, then surface a generic fallback.
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(1)
                    continue

                response_time_ms = int((time.time() - start_time) * 1000)
                log_interaction(
                    intent="gemma_generate",
                    tenant_id=tenant_id,
                    success=False,
                    error=str(e),
                    response_time_ms=response_time_ms,
                    fallback_used=True
                )

                return {
                    "response": "I'm having trouble generating a response right now. Please try again!",
                    "available": False,
                    "error": str(e),
                    "response_time_ms": response_time_ms
                }