# Demo4Tony / app.py
import subprocess, sys, importlib.metadata as im
# Upgrade accelerate if the installed version is too old
def ensure_accelerate(min_version="1.7.0"):
    try:
        from packaging.version import Version
        cur = Version(im.version("accelerate"))
        if cur < Version(min_version):
            raise RuntimeError(f"accelerate {cur} < {min_version}")
        print(f"✅ accelerate {cur} OK")
    except Exception:
        print("🔄 installing accelerate …")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--upgrade",
             f"accelerate>={min_version}"]
        )
ensure_accelerate()
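# (Sketch) The same check-then-upgrade pattern works for any dependency;
# ensure_pkg is a hypothetical helper, not part of the original app.
def ensure_pkg(name, min_version):
    from packaging.version import Version
    try:
        if Version(im.version(name)) >= Version(min_version):
            return
    except im.PackageNotFoundError:
        pass
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "--upgrade", f"{name}>={min_version}"])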
# ────────────────────────────────────────────────────────────────
# Load model + tokenizer + SDPA
# ────────────────────────────────────────────────────────────────
import torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "eduard76/Llama3-8b-good-new"  # ← my fine-tuned model
# Enable scaled-dot-product attention (SDPA) backends
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
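# Note: these are process-wide toggles. On PyTorch ≥ 2.3 a backend can also be
# pinned per call (a sketch, not used below):
#   from torch.nn.attention import sdpa_kernel, SDPBackend
#   with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#       out = model(**inputs)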
print("🔹 loading model in bfloat16 …")
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# Manually set eos_token if it is missing
if tok.eos_token_id is None:
    tok.eos_token_id = tok.convert_tokens_to_ids("</s>")
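# Llama-style tokenizers often ship without a pad token either; reusing EOS
# keeps generate() from warning (an assumption about this checkpoint):
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token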
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # torch_dtype=torch.float16,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()
# Optional: torch.compile for extra speed after the first call
# model = torch.compile(model)
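# Note: Transformers can also be told to use SDPA explicitly at load time via
# attn_implementation="sdpa" in from_pretrained (left at the default here).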
# ────────────────────────────────────────────────────────────────
# Direct generation (no pipeline)
# ────────────────────────────────────────────────────────────────
def chat_fn(message, history):
    prompt = (
        f"<|user|>\n{message.strip()}\n<|assistant|>\n"
        "Answer the question clearly and concisely; SAY you don't know "
        "if you were not fine-tuned on data related to that question.\n"
    )
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding; temperature would be ignored
            eos_token_id=tok.eos_token_id,
            no_repeat_ngram_size=6,
            repetition_penalty=1.15,
        )
    response = tok.decode(output[0][inputs.input_ids.shape[1]:],
                          skip_special_tokens=True)
    return response.strip()
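# (Sketch) A streaming variant built on transformers' TextIteratorStreamer:
# generation runs in a background thread while tokens are yielded to Gradio,
# which treats a generator chat function as a streaming response. Swap this
# in for chat_fn below if you want incremental output; otherwise it is unused.
# The prompt is simplified here relative to chat_fn.
def chat_fn_stream(message, history):
    from threading import Thread
    from transformers import TextIteratorStreamer
    prompt = f"<|user|>\n{message.strip()}\n<|assistant|>\n"
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True,
                                    skip_special_tokens=True)
    Thread(target=model.generate,
           kwargs=dict(**inputs, max_new_tokens=1024, do_sample=False,
                       eos_token_id=tok.eos_token_id,
                       streamer=streamer)).start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial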
# ────────────────────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────────────────────
demo = gr.ChatInterface(
    chat_fn,
    title="🦙 Llama3-8B – my virtual Architect",
)
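# Enabling Gradio's queue can help on shared GPUs so concurrent requests are
# serialized instead of contending for memory (optional; a judgment call):
# demo.queue(max_size=16)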
if __name__ == "__main__":
    demo.launch()