# Source: finance-slm-1m / chat.py
# Uploaded by: sathishphdai
# Commit: 1e7dc59 "Upload Finance-SLM v2" (verified)
#!/usr/bin/env python3
"""Interactive chat and demo inference for Finance-SLM."""
import torch
from tokenizers import Tokenizer
from config import cfg
from model import IndustrySLM
def load_model(checkpoint_name="best_model.pt"):
    """Load a trained checkpoint plus its tokenizer and prepare it for inference.

    Side effect: every key in the checkpoint's ``"config"`` dict that matches an
    existing attribute on the global ``cfg`` overwrites that attribute.

    Args:
        checkpoint_name: File name of the checkpoint inside ``cfg.checkpoint_dir``.

    Returns:
        Tuple ``(model, tokenizer, device)``; the model is on ``device`` in eval mode.

    Raises:
        FileNotFoundError: If the checkpoint or the tokenizer file does not exist.
    """
    from pathlib import Path

    # The device is resolved before the checkpoint config is applied, so the
    # current machine's setting wins over whatever the training run used.
    device = torch.device(cfg.device)
    ckpt_path = Path(cfg.checkpoint_dir) / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Restore saved hyperparameters. Presumably IndustrySLM() reads its
    # architecture from cfg, so this keeps shapes matching the weights - confirm.
    for key, val in ckpt.get("config", {}).items():
        if hasattr(cfg, key):
            setattr(cfg, key, val)
    model = IndustrySLM()
    # strict=False tolerates key mismatches, but ignoring them silently can leave
    # layers at their random initialisation, so surface any mismatch.
    incompatible = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {len(incompatible.missing_keys)} parameter(s) missing from "
              f"checkpoint, e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.unexpected_keys)} unexpected key(s) in "
              f"checkpoint, e.g. {incompatible.unexpected_keys[:5]}")
    model = model.to(device)
    model.eval()
    # Path() guards against directories restored from the checkpoint as plain str.
    tok_path = Path(cfg.tokenizer_dir) / cfg.tokenizer_filename
    if not tok_path.exists():
        raise FileNotFoundError(f"Tokenizer not found: {tok_path}")
    tokenizer = Tokenizer.from_file(str(tok_path))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device
def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Generate a continuation of ``prompt`` and return it as cleaned text.

    Args:
        model: Object whose ``generate(input_ids, max_new_tokens=..., temperature=...,
            top_k=..., top_p=...)`` returns prompt ids followed by new ids.
        tokenizer: Object providing ``encode``, ``decode`` and ``token_to_id``.
        device: Device on which the input id tensor is created.
        prompt: Text to continue.
        max_tokens: Generation budget; ``None`` uses ``cfg.max_new_tokens``.
            An explicit ``0`` is honoured rather than replaced by the default.
        temperature, top_k, top_p: Sampling parameters forwarded to ``model.generate``.

    Returns:
        The newly generated text (up to the first ``<eos>``), with ``<eos>``,
        ``<bos>`` and ``<pad>`` markers removed and surrounding whitespace stripped.

    Raises:
        ValueError: If the prompt encodes to no tokens to condition on.
    """
    if max_tokens is None:
        max_tokens = cfg.max_new_tokens
    # Look the EOS id up instead of hard-coding it; fall back to 3, the id this
    # script previously assumed, when the vocabulary has no "<eos>" entry.
    eos_id = tokenizer.token_to_id("<eos>")
    if eos_id is None:
        eos_id = 3
    ids = list(tokenizer.encode(prompt).ids)
    # The tokenizer may append <eos>; drop it so the model continues the prompt
    # instead of starting a new sequence.
    if ids and ids[-1] == eos_id:
        ids = ids[:-1]
    if not ids:
        raise ValueError("Prompt produced no tokens to condition on")
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    input_len = input_ids.shape[1]
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens,
                                    temperature=temperature, top_k=top_k, top_p=top_p)
    new_tokens = output_ids[0][input_len:].tolist()
    # Anything sampled after the first <eos> belongs to a new, unrelated sequence.
    if eos_id in new_tokens:
        new_tokens = new_tokens[:new_tokens.index(eos_id)]
    response = tokenizer.decode(new_tokens)
    for marker in ("<eos>", "<bos>", "<pad>"):
        response = response.replace(marker, "")
    return response.strip()
# Canned prompts that demo_generation() feeds to the model, one at a time.
DEMO_PROMPTS = [
    "The stock market is driven by",
    "Risk management in banking involves",
    "Blockchain technology enables financial institutions to",
    "Modern portfolio theory suggests that",
    "Central banks influence the economy by",
]
def demo_generation(model, tokenizer, device, prompts=None, max_tokens=256):
    """Print the model's continuation for each of a list of prompts.

    Args:
        model, tokenizer, device: As returned by ``load_model``.
        prompts: Iterable of prompt strings; ``None`` means ``DEMO_PROMPTS``.
        max_tokens: Generation budget per prompt (defaults to 256, as before).
    """
    if prompts is None:
        prompts = DEMO_PROMPTS
    print(f"\n{'='*60}\nDemo: {cfg.domain_name}-SLM Inference\n{'='*60}\n")
    for i, prompt in enumerate(prompts, 1):
        print(f"[{i}] Prompt: {prompt}")
        response = generate_response(model, tokenizer, device, prompt,
                                     max_tokens=max_tokens)
        print(f" Response: {response}\n")
def interactive_chat():
    """Run a console chat loop against the loaded model.

    Reads lines from stdin:

    - blank input is ignored;
    - ``quit`` ends the session;
    - ``demo`` runs the canned demo prompts;
    - anything else is sent to the model as a prompt.

    Ctrl-C, Ctrl-D and end of piped input also end the session cleanly. A
    generation failure (``RuntimeError``/``ValueError``) is reported and the
    loop continues instead of crashing.
    """
    print("Loading model...")
    model, tokenizer, device = load_model()
    print(f"\n{'='*60}\n{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit)\n{'='*60}\n")
    while True:
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOFError: Ctrl-D or exhausted piped stdin - previously a traceback.
            break
        if not user_input:
            continue
        command = user_input.lower()
        if command == "quit":
            break
        try:
            if command == "demo":
                demo_generation(model, tokenizer, device)
            else:
                response = generate_response(model, tokenizer, device, user_input)
                print(f"{cfg.domain_name}-SLM: {response}\n")
        except KeyboardInterrupt:
            break
        except (RuntimeError, ValueError) as err:
            # Keep the session alive when a single turn fails (e.g. a bad prompt
            # or a torch runtime error); report it and prompt again.
            print(f"[error] generation failed: {err}\n")
    print("\nGoodbye!")
if __name__ == "__main__":
    import sys

    # `python chat.py demo` runs the canned prompts once; anything else opens
    # the interactive loop.
    if sys.argv[1:2] != ["demo"]:
        interactive_chat()
    else:
        model, tokenizer, device = load_model()
        demo_generation(model, tokenizer, device)