Commit · 0544709
1 Parent(s): fd9d58a
update model_operations.py for new llms

Files changed: src/backend/model_operations.py (+44 -16)
src/backend/model_operations.py
CHANGED
@@ -164,7 +164,7 @@ class SummaryGenerator:
         using_replicate_api = False
         replicate_api_models = ['snowflake', 'llama-3.1-405b']
         using_pipeline = False
-        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo']
+        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3']
 
         for replicate_api_model in replicate_api_models:
             if replicate_api_model in self.model_id.lower():
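This hunk only widens the routing table: `self.model_id` is matched by substring against these lists to decide whether a model is served through Replicate or through a local transformers pipeline, so adding 'llama-3.3' sends Llama 3.3 checkpoints down the pipeline path. A minimal sketch of that dispatch, equivalent to the loops in this method (the model id below is a made-up example):

pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3']
replicate_api_models = ['snowflake', 'llama-3.1-405b']

model_id = "meta-llama/Llama-3.3-70B-Instruct"  # made-up example, not from this file
using_pipeline = any(name in model_id.lower() for name in pipeline_models)
using_replicate_api = any(name in model_id.lower() for name in replicate_api_models)
print(using_pipeline, using_replicate_api)  # True False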
@@ -222,6 +222,7 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using Grok API
         elif 'grok' in self.model_id.lower(): # xai
             XAI_API_KEY = os.getenv("XAI_API_KEY")
             client = OpenAI(
@@ -241,6 +242,7 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using Vertex AI API for Gemini models
         elif 'gemini' in self.model_id.lower():
             vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
             model = GenerativeModel(
@@ -249,7 +251,7 @@ class SummaryGenerator:
             )
             generation_config = {
                 "temperature": 0,
-                "max_output_tokens":
+                "max_output_tokens": 500
             }
             safety_settings = [
                 SafetySetting(
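The only functional change in this hunk is pinning max_output_tokens to 500. For context, a minimal sketch of how such a config feeds a Vertex AI GenerativeModel.generate_content call, assuming a placeholder model name and prompt and omitting the safety settings used in the file:

import os
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
model = GenerativeModel("gemini-1.5-flash-002")  # placeholder model name
generation_config = {"temperature": 0, "max_output_tokens": 500}
response = model.generate_content("Summarize: ...", generation_config=generation_config)
print(response.text)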
@@ -277,6 +279,8 @@ class SummaryGenerator:
             result = response.text
             print(result)
             return result
+
+        # Using Replicate API
         elif using_replicate_api:
             print("using replicate")
             if 'snowflake' in self.model_id.lower():
@@ -306,6 +310,7 @@ class SummaryGenerator:
             print(response)
             return response
 
+        # Using Anthropic API for Claude models
         elif 'claude' in self.model_id.lower(): # using anthropic api
             print('using Anthropic API')
             client = anthropic.Anthropic()
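This hunk only adds a section comment; the branch itself builds an Anthropic client from the environment. A hedged sketch of the kind of call that presumably follows in this branch (model name, token limit, and prompts below are placeholders, not taken from this file):

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
response = client.messages.create(
    model="claude-3-5-sonnet-20241022",   # placeholder model name
    max_tokens=250,                       # placeholder limit
    system="You are a summarizer.",       # placeholder prompts
    messages=[{"role": "user", "content": "Summarize: ..."}],
)
print(response.content[0].text)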
@@ -331,6 +336,7 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using Cohere API
         elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
             co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
             response = co.chat(
@@ -345,6 +351,7 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using MistralAI API
         elif 'mistral-large' in self.model_id.lower():
             api_key = os.environ["MISTRAL_API_KEY"]
             client = Mistral(api_key=api_key)
@@ -369,6 +376,7 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using Deepseek API
         elif 'deepseek' in self.model_id.lower():
             client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
             response = client.chat.completions.create(
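The Deepseek branch reuses the OpenAI client against DeepSeek's OpenAI-compatible endpoint; the hunk cuts off at the chat.completions.create( call. A minimal sketch of what such a call looks like with that client (model name and messages are placeholders, not from this file):

import os
from openai import OpenAI

client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
response = client.chat.completions.create(
    model="deepseek-chat",  # placeholder model name
    messages=[{"role": "user", "content": "Summarize the passage ..."}],
    temperature=0,
)
print(response.choices[0].message.content)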
@@ -385,20 +393,21 @@ class SummaryGenerator:
             print(result)
             return result
 
-        # Using HF
+        # Using HF pipeline or local checkpoints
         elif self.local_model is None and self.local_pipeline is None:
             if using_pipeline:
                 self.local_pipeline = pipeline(
                     "text-generation",
                     model=self.model_id,
                     tokenizer=AutoTokenizer.from_pretrained(self.model_id),
-                    torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() else "auto",
+                    torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() or 'llama-3.3' in self.model_id.lower() else "auto",
                     device_map="auto",
                     trust_remote_code=True
                 )
             else:
                 if 'ragamuffin' in self.model_id.lower():
                     self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
+
                 else:
                     self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
                 print("Tokenizer loaded")
@@ -420,7 +429,12 @@ class SummaryGenerator:
                     # self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
                     #                                                         torch_dtype=torch.bfloat16, # forcing bfloat16 for now
                     #                                                         attn_implementation="flash_attention_2")
-
+                elif 'olmo' in self.model_id.lower():
+                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id)#torch_dtype="auto"
+
+                elif 'qwq-' in self.model_id.lower():
+                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype="auto", device_map="auto")
+
                 else:
                     self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")#torch_dtype="auto"
                 # print(self.local_model.device)
@@ -435,7 +449,7 @@ class SummaryGenerator:
             ]
             outputs = self.local_pipeline(
                 messages,
-                max_new_tokens=
+                max_new_tokens=256,
                 # return_full_text=False,
                 do_sample=False
             )
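With this change the pipeline path generates at most 256 new tokens, greedily. A self-contained sketch of the whole pipeline route, chat-style messages in and assistant text out, assuming a placeholder model id and a recent transformers release (which returns the full conversation under generated_text for chat input):

from transformers import AutoTokenizer, pipeline

model_id = "microsoft/Phi-3.5-mini-instruct"  # placeholder; any chat model in pipeline_models
generator = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=AutoTokenizer.from_pretrained(model_id),
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
messages = [
    {"role": "system", "content": "You are a summarizer."},  # placeholder prompts
    {"role": "user", "content": "Summarize: ..."},
]
outputs = generator(messages, max_new_tokens=256, do_sample=False)
# With chat-style input, generated_text holds the full conversation; the last turn is the reply.
result = outputs[0]["generated_text"][-1]["content"]
print(result)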
@@ -445,6 +459,8 @@ class SummaryGenerator:
 
         elif self.local_model: # cannot call API. using local model / pipeline
             print('Using local model')
+
+            # Set appropriate prompt based on model document
             if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                 messages=[
                     # gemma-1.1, mistral-7b does not accept system role
@@ -478,29 +494,41 @@ class SummaryGenerator:
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": user_prompt}
                 ]
-            prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
-
-            #
-
+            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+            # Tokenize inputs
+            if 'olmo' in self.model_id.lower():
+                input_ids = self.tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)#.to(self.device)
+            elif 'qwq' in self.model_id.lower():
+                input_ids = self.tokenizer([prompt], return_tensors="pt").to(self.device)
+            else:
+                input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+
+            # Generate outputs
             if 'granite' in self.model_id.lower():
                 self.local_model.eval()
                 outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
+            elif 'olmo' in self.model_id.lower():
+                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01)#top_k=50, top_p=0.95)
+            elif 'qwq' in self.model_id.lower():
+                outputs = self.local_model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
             else:
                 with torch.no_grad():
                     outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)#, pad_token_id=self.tokenizer.eos_token_id
             if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
                 outputs = outputs[:, input_ids['input_ids'].shape[1]:]
-            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
+            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower() or 'qwq-' in self.model_id.lower():
                 outputs = [
                     out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                 ]
-
+
+            # Decode outputs
             if 'qwen2-vl' in self.model_id.lower():
                 result = self.processor.batch_decode(
                     outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                 )[0]
-
-
+            elif 'olmo' in self.model_id.lower() or 'qwq' in self.model_id.lower():
+                result = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
             else:
                 result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
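Taken together, the new tokenize/generate/decode branches give OLMo and QwQ the same four-step flow the other local models already follow. A condensed sketch of that flow for the QwQ case, with placeholder model id and prompts (the real method also special-cases granite, glm, qwen2-vl, and others):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/QwQ-32B-Preview"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [
    {"role": "system", "content": "You are a summarizer."},  # placeholder prompts
    {"role": "user", "content": "Summarize: ..."},
]
# 1) Build the prompt string from the chat template.
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# 2) Tokenize.
input_ids = tokenizer([prompt], return_tensors="pt").to(model.device)
# 3) Generate (near-greedy, mirroring the temperature=0.01 used in the diff).
outputs = model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
# 4) Drop the prompt tokens and decode only the newly generated ones.
outputs = [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids.input_ids, outputs)]
result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(result)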
@@ -512,9 +540,9 @@ class SummaryGenerator:
                 result = result.split(messages[-1]['content'])[1].strip()
             elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
                 pass
+            elif 'olmo' in self.model_id.lower():
+                result = result.split("<|assistant|>\n")[-1]
             else:
-                # print(prompt)
-                # print('-'*50)
                 result = result.replace(prompt.strip(), '')
 
             print(result)
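For OLMo, the decoded text still contains the whole templated conversation, so this last branch keeps only what follows the final assistant tag. A tiny illustration on a made-up decoded string:

decoded = "<|user|>\nSummarize: ...\n<|assistant|>\nThe passage says ..."
result = decoded.split("<|assistant|>\n")[-1]
print(result)  # "The passage says ..."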