# Source: Hugging Face repo dill-dev/Momo-336M-sft, file momo_chat.py
# Last commit: "Update momo_chat.py" by dill-dev (8a9e408, verified)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ── 1. Load Model ────────────────────────────────────────────────
model_id = "dill-dev/Momo-336M-sft"

print("🌸 Loading Momo from Hugging Face...")

# SECURITY: trust_remote_code=True lets the loader execute Python code shipped
# in the model repository. Only keep it enabled for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,  # needed to load the repo's custom modeling_momo.py
    dtype=torch.float32,  # `torch_dtype` is deprecated; use `dtype`
    device_map="cpu",
)
model.eval()  # evaluation mode: the script only runs inference

print("✅ Momo is ready! Type 'exit' or 'quit' to stop.\n")
# ── 2. Prompt Format ─────────────────────────────────────────────
def format_prompt(instruction: str) -> str:
    """Wrap *instruction* in the instruction/response prompt template.

    The result is a fixed preamble, an ``### Instruction:`` section holding
    *instruction* verbatim, and a trailing ``### Response:`` header followed
    by a newline, leaving the response itself for the model to fill in.

    Args:
        instruction: The user's request, inserted unmodified.

    Returns:
        The complete prompt string.
    """
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:\n"
    )
# ── 3. Generate Function ─────────────────────────────────────────
# NOTE(review): generate() is called with a `rep_penalty` keyword and the model
# is loaded with trust_remote_code=True, so this presumably targets the model's
# own generate() implementation rather than the stock Hugging Face one —
# confirm against the repository's modeling code before renaming any keyword.
def chat(user_input, max_new_tokens=200, temperature=0.75,
         top_k=50, top_p=0.92, rep_penalty=1.1):
    """Generate a reply to *user_input* and return it as plain text.

    Args:
        user_input: The user's message; wrapped with ``format_prompt``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        rep_penalty: Forwarded to ``model.generate`` under the same name.

    Returns:
        The decoded continuation with special tokens removed, trimmed of
        surrounding whitespace. If the text contains ``"### Response:"``,
        only the part after its last occurrence is returned.
    """
    prompt = format_prompt(user_input)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

    # Compare against None explicitly: a pad id of 0 is a valid token id and
    # must not be replaced by the EOS id just because it is falsy.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            rep_penalty=rep_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )

    # Decode only what comes after the prompt. This assumes generate() returns
    # the prompt tokens followed by the continuation — TODO confirm.
    new_tokens = output_ids[0][input_ids.shape[1]:]
    full_output = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Keep the text after the last "### Response:" marker; when the marker is
    # absent, rpartition yields the whole string as its final element.
    return full_output.rpartition("### Response:")[2].strip()
# ── 4. Interactive Chat Loop ─────────────────────────────────────
# Read a line, answer it with chat(), repeat until the user types
# 'exit' / 'quit' or closes the input stream.
while True:
    try:
        user_input = input("🧑 You: ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: say goodbye instead of dumping a
        # traceback. The leading newline moves off the unfinished prompt line.
        print("\n🌸 Momo: Bye bye! See you soon.")
        break

    command = user_input.strip()
    if not command:
        continue
    # Compare the trimmed text so that " exit " also ends the session.
    if command.lower() in ("exit", "quit"):
        print("🌸 Momo: Bye bye! See you soon.")
        break

    # flush=True: the status line ends in "\r" rather than "\n", so without an
    # explicit flush it can stay in the buffer until generation is finished.
    print("🌸 Momo is thinking...", end="\r", flush=True)
    response = chat(user_input)
    # Blank out the status line before printing the reply over it.
    print(" " * 35, end="\r")
    print(f"🌸 Momo: {response}\n")