import subprocess, sys, importlib.metadata as im


def ensure_accelerate(min_version="1.7.0"):
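    """Ensure accelerate >= min_version is installed, upgrading via pip if needed."""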
    try:
        from packaging.version import Version

        cur = Version(im.version("accelerate"))
        if cur < Version(min_version):
            raise RuntimeError(f"accelerate {cur} < {min_version}")
        print(f"✅ accelerate {cur} OK")
    except Exception:
        print("📦 installing accelerate …")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--upgrade",
             f"accelerate>={min_version}"]
        )


ensure_accelerate()

import torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "eduard76/Llama3-8b-good-new"
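
# Allow every scaled-dot-product-attention backend (FlashAttention, math,
# memory-efficient) so PyTorch can pick the fastest available kernel.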
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

print("🔹 loading model in bfloat16 …")
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
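
# Some checkpoints ship without an EOS id configured; fall back to the
# "</s>" token so generation can terminate.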
if tok.eos_token_id is None:
    tok.eos_token_id = tok.convert_tokens_to_ids("</s>")
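
# device_map="auto" places the weights across available devices and requires
# accelerate, which the startup check above guarantees.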
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
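

# The prompt template below assumes the checkpoint was fine-tuned on a
# <|user|> / <|assistant|> chat format; adjust it if your fine-tune differs.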
def chat_fn(message, history):
    prompt = (
        f"<|user|>\n{message.strip()}\n<|assistant|>\n"
        "Answer the question clearly and concisely; SAY you don't know "
        "if you are not fine-tuned with data related to that question.\n"
    )
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
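        # Greedy decoding (do_sample=False) is deterministic, so sampling
        # knobs like temperature are not needed here.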
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            eos_token_id=tok.eos_token_id,
            no_repeat_ngram_size=6,
            repetition_penalty=1.15,
        )
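    # Decode only the newly generated tokens, skipping the echoed prompt.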
    response = tok.decode(output[0][inputs.input_ids.shape[1]:],
                          skip_special_tokens=True)
    return response.strip()
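

# ChatInterface passes (message, history) to chat_fn; history goes unused
# because every turn is answered independently.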
demo = gr.ChatInterface(
    chat_fn,
    title="🦙 Llama3-8B – my virtual Architect",
)

if __name__ == "__main__":
    demo.launch()