Buckets:
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import sys | |
| import time | |
# Default value for --base-model (the identifier handed to from_pretrained).
DEFAULT_BASE_MODEL = "meta-llama/Meta-Llama-3-8B"

# Default value for --adapter-dir, resolved relative to this file: start two
# directories above the folder that contains this script, then descend into
# the run's downloaded artifacts.
DEFAULT_ADAPTER = Path(__file__).resolve().parents[2].joinpath(
    "implementation",
    "runs",
    "run_step19_repair_train_001",
    "remote_artifacts",
    "adapter",
)
def status(message: str) -> None:
    """Print one timestamped progress line for the local chat session.

    The line has the form ``[SHFT local chat HH:MM:SS] <message>`` (local
    time) and stdout is flushed right away.
    """
    now = time.strftime("%H:%M:%S")
    line = "[SHFT local chat {}] {}".format(now, message)
    print(line, flush=True)
def import_runtime():
    """Import the inference libraries on demand and hand them back.

    Returns:
        The tuple ``(torch, PeftModel, AutoModelForCausalLM, AutoTokenizer)``.

    Raises:
        SystemExit: if any of the imports fails. The message names the
            module reported by the ImportError, or the error text when no
            module name is available.
    """
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as exc:
        culprit = exc.name if exc.name else str(exc)
        hint = (
            f"Missing local inference dependency: {culprit}\n"
            "Install/repair dependencies from impl_codex/self_healing_finetuning first."
        )
        raise SystemExit(hint) from exc
    else:
        return torch, PeftModel, AutoModelForCausalLM, AutoTokenizer
def choose_dtype(torch, requested: str):
    """Translate the ``--dtype`` option into a torch dtype.

    Args:
        torch: the imported ``torch`` module.
        requested: ``"float32"``, ``"float16"`` or ``"bfloat16"`` to force
            that precision; any other value (such as ``"auto"``) picks one
            automatically.

    Returns:
        The dtype attribute of ``torch`` with the requested name, or, in the
        automatic case, ``torch.float16`` when CUDA is available and
        ``torch.float32`` otherwise.
    """
    explicit_names = ("float32", "float16", "bfloat16")
    if requested in explicit_names:
        return getattr(torch, requested)
    return torch.float16 if torch.cuda.is_available() else torch.float32
def build_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Render the conversation into a single prompt string.

    Args:
        tokenizer: object that may expose ``chat_template`` and
            ``apply_chat_template``.
        messages: chat turns, each a dict with ``"role"`` and ``"content"``.

    Returns:
        The tokenizer's chat-template rendering (untokenized, with the
        generation prompt appended) when a template is set; otherwise a plain
        transcript of ``Role: content`` lines ending with ``Assistant:``.
    """
    if not getattr(tokenizer, "chat_template", None):
        # No template available: fall back to a simple line-per-turn layout.
        turns = [f"{turn['role'].capitalize()}: {turn['content']}" for turn in messages]
        return "\n".join([*turns, "Assistant:"])
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse the chat script's command-line options.

    Args:
        argv: argument strings, without the program name.

    Returns:
        Namespace with ``base_model``, ``adapter_dir``, ``max_new_tokens``,
        ``temperature``, ``top_p``, ``dtype`` and ``system_prompt``.
    """
    default_system_prompt = (
        "You are Linvest21_FinGPT, a careful financial analysis assistant. "
        "Give concise, factual, numerate answers and state uncertainty when data is incomplete."
    )
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--base-model", {"default": DEFAULT_BASE_MODEL}),
        ("--adapter-dir", {"default": str(DEFAULT_ADAPTER)}),
        ("--max-new-tokens", {"type": int, "default": 80}),
        ("--temperature", {"type": float, "default": 0.2}),
        ("--top-p", {"type": float, "default": 0.9}),
        (
            "--dtype",
            {"choices": ["auto", "float32", "float16", "bfloat16"], "default": "auto"},
        ),
        ("--system-prompt", {"default": default_system_prompt}),
    )
    cli = argparse.ArgumentParser(
        description="Interactive local chat for Linvest21 SHFT adapters."
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args(argv)
def _sampling_kwargs(temperature: float, top_p: float) -> dict[str, object]:
    """Return the sampling-related keyword arguments for ``model.generate``."""
    if temperature > 0:
        return {"do_sample": True, "temperature": temperature, "top_p": top_p}
    # Greedy decoding: do not forward temperature/top_p. They have no effect
    # without sampling, and a temperature <= 0 is not a valid sampling value.
    return {"do_sample": False}


def main(argv: list[str] | None = None) -> int:
    """Run the interactive chat loop against the base model plus adapter.

    Args:
        argv: command-line arguments without the program name. ``None`` (the
            default) means "use ``sys.argv[1:]``"; an explicit list, including
            an empty one, is parsed exactly as given.

    Returns:
        ``0`` when the user leaves the session (an exit command, EOF, or
        Ctrl-C at the prompt).

    Raises:
        SystemExit: if the adapter directory or its
            ``adapter_model.safetensors`` file does not exist (also raised by
            the argument parser and the runtime import helper).
    """
    # `argv or sys.argv[1:]` would treat an explicit [] as "not provided" and
    # silently parse the process arguments instead; only None falls back.
    args = parse_args(sys.argv[1:] if argv is None else argv)

    # Validate the adapter location before the slow library imports.
    adapter_dir = Path(args.adapter_dir).resolve()
    if not adapter_dir.exists():
        raise SystemExit(f"Adapter directory does not exist: {adapter_dir}")
    if not (adapter_dir / "adapter_model.safetensors").exists():
        raise SystemExit(f"Adapter weights not found in: {adapter_dir}")

    torch, PeftModel, AutoModelForCausalLM, AutoTokenizer = import_runtime()
    dtype = choose_dtype(torch, args.dtype)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    status(f"base_model={args.base_model}")
    status(f"adapter_dir={adapter_dir}")
    status(f"device={device} dtype={str(dtype).replace('torch.', '')}")
    if device == "cpu":
        status("CUDA is not available in the current Python torch install; first response may be slow.")

    status("loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    if tokenizer.pad_token_id is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    status("loading base model; this may download gated model weights the first time")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    status("loading Linvest21 trained adapter")
    model = PeftModel.from_pretrained(model, str(adapter_dir))
    model.eval()
    model.to(device)
    # Presumably cleared so that only max_new_tokens bounds the output length.
    model.generation_config.max_length = None

    # Loop-invariant: the decoding mode depends only on the CLI options.
    sampling = _sampling_kwargs(args.temperature, args.top_p)

    status("ready. Type /exit to quit, /clear to reset context.")
    messages: list[dict[str, str]] = [{"role": "system", "content": args.system_prompt}]
    while True:
        try:
            user_text = input("\nlinvest21> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell prompt starts cleanly.
            print()
            return 0
        if not user_text:
            continue

        command = user_text.lower()
        if command in {"/exit", "exit", "quit", "/quit"}:
            return 0
        if command == "/clear":
            messages = [{"role": "system", "content": args.system_prompt}]
            status("context cleared")
            continue

        messages.append({"role": "user", "content": user_text})
        prompt = build_prompt(tokenizer, messages)
        # NOTE(review): when the prompt was rendered by a chat template it may
        # already contain the BOS token, and tokenizing with the default
        # add_special_tokens=True can then duplicate it. Confirm against how
        # the adapter's training data was tokenized before changing this.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        input_len = inputs["input_ids"].shape[-1]

        status("generating")
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **sampling,
            )
        # Decode only the tokens produced after the prompt.
        answer = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
        if not answer:
            answer = "[empty response]"
        print(f"\nLinvest21_FinGPT: {answer}")
        messages.append({"role": "assistant", "content": answer})
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
Xet Storage Details
- Size:
- 5.62 kB
- Xet hash:
- c85df5c1b53bbbb1a3d90914432fa0e9528b5c23f6b93568b883c1c2f6c78ada
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.