Commit 94aee85
Parent(s): f2a2584
Update llm.py

llm.py CHANGED

@@ -21,15 +21,11 @@ api_key = os.getenv("MODEL_API_KEY")
 client = openai.OpenAI(api_key=api_key)
 bedrock_runtime = boto3.client(
     "bedrock-runtime",
-    region_name="us-
-    aws_access_key_id="
-    aws_secret_access_key="AWS_SECRET_ACCESS_KEY"
+    region_name="us-east-1",
+    aws_access_key_id=os.getenv("AWS_ACCESS_ID"),
+    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
 )

-# Ollama configuration
-OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
-
-
 # ──────────────────────────────────────────────────────────────
 # Model switcher
 # ──────────────────────────────────────────────────────────────
@@ -170,140 +166,140 @@ def chat(messages, persona):
         total_tok = len(text.split())

         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "meta":
-        print("Using meta (LLaMA): ", MODEL_STRING)
-        t0 = time.time()
+    # elif provider == "meta":
+    #     print("Using meta (LLaMA): ", MODEL_STRING)
+    #     t0 = time.time()

-        # Add system prompt for better behavior
-        system_prompt = ""
+    #     # Add system prompt for better behavior
+    #     system_prompt = ""

-        # Format conversation properly for Llama3
-        formatted_prompt = "<|begin_of_text|>"
+    #     # Format conversation properly for Llama3
+    #     formatted_prompt = "<|begin_of_text|>"

-        # Add system prompt
-        formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
+    #     # Add system prompt
+    #     formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"

-        # Add conversation history
-        for msg in messages:
-            if msg["role"] == "user":
-                formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
-            elif msg["role"] == "assistant":
-                formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
+    #     # Add conversation history
+    #     for msg in messages:
+    #         if msg["role"] == "user":
+    #             formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
+    #         elif msg["role"] == "assistant":
+    #             formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"

-        # Add final assistant prompt
-        formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
-
-        response = bedrock_runtime.invoke_model(
-            modelId=MODEL_STRING,
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps(
-                {
-                    "prompt": formatted_prompt,
-                    "max_gen_len": 512,  # Shorter responses
-                    "temperature": 0.3,  # Lower temperature for more focused responses
-                }
-            ),
-        )
-
-        dt = time.time() - t0
-        body = json.loads(response["body"].read())
-        text = body.get("generation", "").strip()
-        total_tok = len(text.split())
-
-        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "mistral":
-        print("Using mistral: ", MODEL_STRING)
-        t0 = time.time()
-
-        prompt = messages[-1]["content"]
-        formatted_prompt = f"<s>[INST] {prompt} [/INST]"
-
-        response = bedrock_runtime.invoke_model(
-            modelId=MODEL_STRING,
-            contentType="application/json",
-            accept="application/json",
-            body=json.dumps(
-                {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
-            ),
-        )
-
-        dt = time.time() - t0
-        body = json.loads(response["body"].read())
-
-        text = body["outputs"][0]["text"].strip()
-        total_tok = len(text.split())
-
-        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-    elif provider == "ollama":
-        print("Using ollama: ", MODEL_STRING)
-        t0 = time.time()
+    #     # Add final assistant prompt
+    #     formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
+
+    #     response = bedrock_runtime.invoke_model(
+    #         modelId=MODEL_STRING,
+    #         contentType="application/json",
+    #         accept="application/json",
+    #         body=json.dumps(
+    #             {
+    #                 "prompt": formatted_prompt,
+    #                 "max_gen_len": 512,  # Shorter responses
+    #                 "temperature": 0.3,  # Lower temperature for more focused responses
+    #             }
+    #         ),
+    #     )
+
+    #     dt = time.time() - t0
+    #     body = json.loads(response["body"].read())
+    #     text = body.get("generation", "").strip()
+    #     total_tok = len(text.split())
+
+    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    # elif provider == "mistral":
+    #     print("Using mistral: ", MODEL_STRING)
+    #     t0 = time.time()
+
+    #     prompt = messages[-1]["content"]
+    #     formatted_prompt = f"<s>[INST] {prompt} [/INST]"
+
+    #     response = bedrock_runtime.invoke_model(
+    #         modelId=MODEL_STRING,
+    #         contentType="application/json",
+    #         accept="application/json",
+    #         body=json.dumps(
+    #             {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
+    #         ),
+    #     )
+
+    #     dt = time.time() - t0
+    #     body = json.loads(response["body"].read())
+
+    #     text = body["outputs"][0]["text"].strip()
+    #     total_tok = len(text.split())
+
+    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    # elif provider == "ollama":
+    #     print("Using ollama: ", MODEL_STRING)
+    #     t0 = time.time()

-        # Format messages for Ollama API with system prompt
-        ollama_messages = []
+    #     # Format messages for Ollama API with system prompt
+    #     ollama_messages = []

-        # Add system prompt for better behavior
-        system_prompt = ""
-        ollama_messages.append({
-            "role": "system",
-            "content": system_prompt
-        })
+    #     # Add system prompt for better behavior
+    #     system_prompt = ""
+    #     ollama_messages.append({
+    #         "role": "system",
+    #         "content": system_prompt
+    #     })

-        for msg in messages:
-            ollama_messages.append({
-                "role": msg["role"],
-                "content": msg["content"]
-            })
+    #     for msg in messages:
+    #         ollama_messages.append({
+    #             "role": msg["role"],
+    #             "content": msg["content"]
+    #         })

-        # Make request to Ollama API
-        response = requests.post(
-            f"{OLLAMA_BASE_URL}/api/chat",
-            json={
-                "model": MODEL_STRING,
-                "messages": ollama_messages,
-                "stream": False,
-                "options": {
-                    "temperature": 0.3,  # Lower temperature for more focused responses
-                    # "num_predict": 4000,  # Much higher limit for longer responses
-                    "top_p": 0.9,
-                    "repeat_penalty": 1.1
-                }
-            },
-            timeout=60
-        )
+    #     # Make request to Ollama API
+    #     response = requests.post(
+    #         f"{OLLAMA_BASE_URL}/api/chat",
+    #         json={
+    #             "model": MODEL_STRING,
+    #             "messages": ollama_messages,
+    #             "stream": False,
+    #             "options": {
+    #                 "temperature": 0.3,  # Lower temperature for more focused responses
+    #                 # "num_predict": 4000,  # Much higher limit for longer responses
+    #                 "top_p": 0.9,
+    #                 "repeat_penalty": 1.1
+    #             }
+    #         },
+    #         timeout=60
+    #     )

-        dt = time.time() - t0
+    #     dt = time.time() - t0

-        if response.status_code == 200:
-            result = response.json()
-            text = result["message"]["content"].strip()
-            total_tok = len(text.split())
-            return text, dt, total_tok, (total_tok / dt if dt else total_tok)
-        else:
-            raise Exception(f"Ollama API error: {response.status_code} - {response.text}")
+    #     if response.status_code == 200:
+    #         result = response.json()
+    #         text = result["message"]["content"].strip()
+    #         total_tok = len(text.split())
+    #         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
+    #     else:
+    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")


 # ──────────────────────────────────────────────────────────────
 # Diagnostics / CLI test
 # ──────────────────────────────────────────────────────────────
 def check_credentials():
-    # Check if using Ollama (no API key required)
-    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
-        # Test Ollama connection
-        try:
-            response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
-            if response.status_code == 200:
-                print("Ollama connection successful")
-                return True
-            else:
-                print(f"Ollama connection failed: {response.status_code}")
-                return False
-        except Exception as e:
-            print(f"Ollama connection failed: {e}")
-            return False
+    # # Check if using Ollama (no API key required)
+    # if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
+    #     # Test Ollama connection
+    #     try:
+    #         response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
+    #         if response.status_code == 200:
+    #             print("Ollama connection successful")
+    #             return True
+    #         else:
+    #             print(f"Ollama connection failed: {response.status_code}")
+    #             return False
+    #     except Exception as e:
+    #         print(f"Ollama connection failed: {e}")
+    #         return False

     # Check if using Bedrock providers (anthropic, meta, mistral, deepseek)
-    bedrock_providers = ["anthropic"
+    bedrock_providers = ["anthropic"]
     if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
         # Test AWS Bedrock connection by trying to invoke a simple model
         try:
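Note on the first hunk: the hardcoded credentials are replaced by values read from the environment (AWS_ACCESS_ID and AWS_SECRET_ACCESS_KEY) and the region is pinned to us-east-1. A minimal alternative sketch, assuming the standard AWS variable names are exported: boto3's default credential chain already reads AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_DEFAULT_REGION on its own, so no key material needs to be passed explicitly.

    import os
    import boto3

    # Sketch only: rely on boto3's default credential chain. boto3 picks up
    # AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION from
    # the environment itself, so the explicit kwargs become unnecessary.
    os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")
    bedrock_runtime = boto3.client("bedrock-runtime")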
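For reference, the meta (Llama 3) branch that this commit comments out can be exercised as a standalone snippet. A sketch only, assuming Bedrock access is already configured; meta.llama3-8b-instruct-v1:0 is an illustrative model ID, since the commit does not name one.

    import json
    import boto3

    bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

    def llama3_generate(messages, model_id="meta.llama3-8b-instruct-v1:0"):
        # Build the Llama 3 chat template the commented-out branch used.
        prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n<|eot_id|>\n"
        for msg in messages:
            prompt += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n{msg['content']}<|eot_id|>\n"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n"

        response = bedrock_runtime.invoke_model(
            modelId=model_id,
            contentType="application/json",
            accept="application/json",
            body=json.dumps({"prompt": prompt, "max_gen_len": 512, "temperature": 0.3}),
        )
        body = json.loads(response["body"].read())
        return body.get("generation", "").strip()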
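The ollama branch is likewise only commented out, not deleted. A standalone sketch of the same non-streaming /api/chat call, assuming a local Ollama server; "llama3" is a placeholder model tag.

    import os
    import requests

    OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

    def ollama_chat(messages, model="llama3"):
        # Mirrors the commented-out branch: non-streaming chat request.
        response = requests.post(
            f"{OLLAMA_BASE_URL}/api/chat",
            json={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {"temperature": 0.3, "top_p": 0.9, "repeat_penalty": 1.1},
            },
            timeout=60,
        )
        response.raise_for_status()
        return response.json()["message"]["content"].strip()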
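The diff cuts off inside the try: of the remaining Bedrock check in check_credentials, so the actual probe body is not shown. Purely as an assumed sketch of one way to write such a check: the "bedrock" control-plane client exposes list_foundation_models, which fails fast when credentials or region are invalid.

    import boto3

    def bedrock_credentials_ok(region="us-east-1"):
        # Assumed probe, not the commit's code: a cheap control-plane call
        # that raises if the credentials or region are invalid.
        try:
            boto3.client("bedrock", region_name=region).list_foundation_models()
            return True
        except Exception as e:
            print(f"Bedrock connection failed: {e}")
            return False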