import subprocess
import sys
import importlib.metadata as im

# Upgrade accelerate if the installed version is too old
def ensure_accelerate(min_version="1.7.0"):
    try:
        from packaging.version import Version
        cur = Version(im.version("accelerate"))
        if cur < Version(min_version):
            raise Exception(f"accelerate {cur} < {min_version}")
        print(f"✅ accelerate {cur} OK")
    except Exception:
        print("🔄 installing accelerate …")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "--upgrade",
             f"accelerate>={min_version}"]
        )

ensure_accelerate()

# ────────────────────────────────────────────────────────────────
# Load model + tokenizer + SDPA
# ────────────────────────────────────────────────────────────────
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "eduard76/Llama3-8b-good-new"  # ← my fine-tuned model

# Enable scaled-dot-product attention (SDPA) backends
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

print("🔹 loading model in bfloat16 …")
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Set eos_token_id manually if it is missing
# (Llama 3 tokenizers use <|eot_id|> as the end-of-turn token)
if tok.eos_token_id is None:
    tok.eos_token_id = tok.convert_tokens_to_ids("<|eot_id|>")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # torch_dtype=torch.float16,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# Optional: torch.compile if you want it even faster after the first call
# model = torch.compile(model)

# ────────────────────────────────────────────────────────────────
# Direct generation (without pipeline)
# ────────────────────────────────────────────────────────────────
def build_prompt(message: str) -> str:
    return (
        "<|user|>\n"
        f"{message.strip()}\n"
        "<|assistant|>\n"
        "Answer the question clearly and concisely; say you don't know "
        "if you were not fine-tuned with data related to the question.\n"
    )

def chat_fn(message, history):
    inputs = tok(build_prompt(message), return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,  # greedy decoding; temperature would be ignored
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,  # silence the missing-pad-token warning
            no_repeat_ngram_size=6,
            repetition_penalty=1.15,
        )
    # Decode only the newly generated tokens, skipping the prompt
    response = tok.decode(output[0][inputs.input_ids.shape[1]:],
                          skip_special_tokens=True)
    return response.strip()

# ────────────────────────────────────────────────────────────────
# Gradio UI
# ────────────────────────────────────────────────────────────────
demo = gr.ChatInterface(
    chat_fn,
    title="🦙 Llama3-8B – my virtual Architect",
)

if __name__ == "__main__":
    demo.launch()
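
# ────────────────────────────────────────────────────────────────
# Optional: streaming variant (a minimal sketch, not part of the
# original app). gr.ChatInterface also accepts a generator function,
# so swapping chat_fn for chat_fn_stream below would stream tokens
# to the UI via transformers' TextIteratorStreamer; chat_fn_stream
# is a hypothetical name for this illustration.
# ────────────────────────────────────────────────────────────────
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def chat_fn_stream(message, history):
#     inputs = tok(build_prompt(message), return_tensors="pt").to(model.device)
#     streamer = TextIteratorStreamer(tok, skip_prompt=True,
#                                     skip_special_tokens=True)
#     gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=1024,
#                       do_sample=False, eos_token_id=tok.eos_token_id,
#                       pad_token_id=tok.eos_token_id,
#                       no_repeat_ngram_size=6, repetition_penalty=1.15)
#     # generate() runs in a background thread; the streamer yields
#     # decoded chunks as they are produced
#     Thread(target=model.generate, kwargs=gen_kwargs).start()
#     text = ""
#     for chunk in streamer:
#         text += chunk
#         yield text.strip()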