linvest21's picture
download
raw
5.62 kB
from __future__ import annotations
import argparse
from pathlib import Path
import sys
import time
# Default value for the --base-model option.
DEFAULT_BASE_MODEL = "meta-llama/Meta-Llama-3-8B"

# Default value for the --adapter-dir option: a fixed run directory located
# under the folder two levels above the directory that contains this file.
DEFAULT_ADAPTER = Path(__file__).resolve().parents[2].joinpath(
    "implementation",
    "runs",
    "run_step19_repair_train_001",
    "remote_artifacts",
    "adapter",
)
def status(message: str) -> None:
    """Print *message* as a progress line prefixed with the local HH:MM:SS time.

    The output is flushed right away so the line is visible before any
    long-running step that follows.
    """
    now = time.strftime("%H:%M:%S")
    line = f"[SHFT local chat {now}] {message}"
    print(line, flush=True)
def import_runtime():
    """Import the inference dependencies on demand.

    Returns:
        The tuple ``(torch, PeftModel, AutoModelForCausalLM, AutoTokenizer)``.

    Raises:
        SystemExit: If any of the imports fails. The message names the module
            reported by the ImportError, or the error text when no module
            name is attached.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as err:
        module_name = err.name if err.name else str(err)
        raise SystemExit(
            f"Missing local inference dependency: {module_name}\n"
            "Install/repair dependencies from impl_codex/self_healing_finetuning first."
        ) from err
    else:
        return torch, PeftModel, AutoModelForCausalLM, AutoTokenizer
def choose_dtype(torch, requested: str):
    """Translate the requested dtype name into an attribute of *torch*.

    Args:
        torch: The torch module (passed in because it is imported lazily).
        requested: "float32", "float16" or "bfloat16" for an explicit choice;
            any other value (such as "auto") selects automatically.

    Returns:
        The torch attribute named by an explicit request. Otherwise
        ``torch.float16`` when CUDA is available and ``torch.float32`` when
        it is not.
    """
    if requested in ("float32", "float16", "bfloat16"):
        # Each explicit choice is spelled exactly like the torch attribute.
        return getattr(torch, requested)
    return torch.float16 if torch.cuda.is_available() else torch.float32
def build_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Render the conversation into a single prompt string.

    When the tokenizer has a truthy ``chat_template`` attribute, rendering is
    delegated to ``tokenizer.apply_chat_template`` (untokenized, with the
    generation prompt appended). Otherwise each message becomes a
    ``"Role: content"`` line, followed by a trailing ``"Assistant:"`` line.
    """
    template = getattr(tokenizer, "chat_template", None)
    if not template:
        turns = [f"{msg['role'].capitalize()}: {msg['content']}" for msg in messages]
        return "\n".join([*turns, "Assistant:"])
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the chat session's command-line options.

    Args:
        argv: Argument strings to parse (without the program name).

    Returns:
        Namespace with ``base_model``, ``adapter_dir``, ``max_new_tokens``,
        ``temperature``, ``top_p``, ``dtype`` and ``system_prompt``.
    """
    default_system_prompt = (
        "You are Linvest21_FinGPT, a careful financial analysis assistant. "
        "Give concise, factual, numerate answers and state uncertainty when data is incomplete."
    )
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--base-model", {"default": DEFAULT_BASE_MODEL}),
        ("--adapter-dir", {"default": str(DEFAULT_ADAPTER)}),
        ("--max-new-tokens", {"type": int, "default": 80}),
        ("--temperature", {"type": float, "default": 0.2}),
        ("--top-p", {"type": float, "default": 0.9}),
        ("--dtype", {"choices": ["auto", "float32", "float16", "bfloat16"], "default": "auto"}),
        ("--system-prompt", {"default": default_system_prompt}),
    )
    cli = argparse.ArgumentParser(description="Interactive local chat for Linvest21 SHFT adapters.")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
    """Load the base model plus adapter and run an interactive chat loop.

    Args:
        argv: Command-line arguments without the program name. ``None`` means
            "use ``sys.argv[1:]``"; an explicit list, including an empty one,
            is parsed exactly as given.

    Returns:
        0 when the session ends through an exit command, EOF, or Ctrl-C at
        the input prompt.

    Raises:
        SystemExit: If the adapter directory or its
            ``adapter_model.safetensors`` file does not exist, or if the
            inference dependencies cannot be imported.
    """
    # Fall back to the process arguments only when nothing was passed at all;
    # `argv or sys.argv[1:]` would also discard an explicit empty list.
    args = parse_args(sys.argv[1:] if argv is None else argv)

    adapter_dir = Path(args.adapter_dir).resolve()
    if not adapter_dir.exists():
        raise SystemExit(f"Adapter directory does not exist: {adapter_dir}")
    if not (adapter_dir / "adapter_model.safetensors").exists():
        raise SystemExit(f"Adapter weights not found in: {adapter_dir}")

    # The heavy dependencies are imported only after the cheap path checks pass.
    torch, PeftModel, AutoModelForCausalLM, AutoTokenizer = import_runtime()
    dtype = choose_dtype(torch, args.dtype)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    status(f"base_model={args.base_model}")
    status(f"adapter_dir={adapter_dir}")
    status(f"device={device} dtype={str(dtype).replace('torch.', '')}")
    if device == "cpu":
        status("CUDA is not available in the current Python torch install; first response may be slow.")

    status("loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token_id is None:
        # No pad token configured: reuse the end-of-sequence token.
        tokenizer.pad_token = tokenizer.eos_token

    status("loading base model; this may download gated model weights the first time")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    status("loading Linvest21 trained adapter")
    model = PeftModel.from_pretrained(model, str(adapter_dir))
    model.eval()
    model.to(device)
    # Output length is bounded by max_new_tokens in the generate() call below;
    # presumably this clears a configured max_length so the two do not conflict.
    model.generation_config.max_length = None

    # The generation settings do not change between turns, so build them once.
    do_sample = args.temperature > 0
    generate_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling knobs are only meaningful when sampling is enabled, so a
        # non-positive temperature is never forwarded to generate().
        generate_kwargs["temperature"] = args.temperature
        generate_kwargs["top_p"] = args.top_p

    status("ready. Type /exit to quit, /clear to reset context.")
    messages: list[dict[str, str]] = [{"role": "system", "content": args.system_prompt}]
    while True:
        try:
            user_text = input("\nlinvest21> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the prompt line cleanly before leaving.
            print()
            return 0
        if not user_text:
            continue

        command = user_text.lower()
        if command in {"/exit", "exit", "quit", "/quit"}:
            return 0
        if command == "/clear":
            # Drop the whole history but keep the system prompt.
            messages = [{"role": "system", "content": args.system_prompt}]
            status("context cleared")
            continue

        messages.append({"role": "user", "content": user_text})
        prompt = build_prompt(tokenizer, messages)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_len = inputs["input_ids"].shape[-1]
        status("generating")
        with torch.no_grad():
            output = model.generate(**inputs, **generate_kwargs)
        # Decode only the tokens that follow the prompt.
        answer = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
        if not answer:
            answer = "[empty response]"
        print(f"\nLinvest21_FinGPT: {answer}")
        messages.append({"role": "assistant", "content": answer})
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())

Xet Storage Details

Size:
5.62 kB
·
Xet hash:
c85df5c1b53bbbb1a3d90914432fa0e9528b5c23f6b93568b883c1c2f6c78ada

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.