|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
class CASS:
    """Cass: a persona-driven chat assistant wrapping a causal language model.

    Keeps a running conversation history, a small key/value "user memory",
    and a mood flag inferred from each user message; all three are folded
    into the prompt on every turn.
    """

    def __init__(self, model_name="HPLT/gpt-13b-nordic-prerelease", device=None):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model identifier to load.
            device: Explicit device string ("cuda" / "cpu"). When None,
                CUDA is used if available, otherwise CPU.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Bug fix: fp16 was forced unconditionally, but half-precision
        # causal-LM inference is unsupported/extremely slow on CPU.
        # Pick the dtype to match the actual device.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
        )

        self.history = []        # list of {"role": ..., "content": ...} turns
        self.introduced = False  # whether the canned intro has been emitted
        self.user_memory = {}    # facts stored via remember()
        self.mood = "friendly"

        self.persona = (
            "You are Cass, a friendly AI assistant and a cool friend. "
            "You respond casually, helpfully, and with a touch of humor. "
            "You adapt your tone depending on conversation mood, and sometimes use emojis or playful expressions."
        )

    def update_mood(self, user_message):
        """Set self.mood from simple keyword matching on the user message."""
        msg = user_message.lower()
        if any(word in msg for word in ["sad", "unhappy", "bad", "tired"]):
            self.mood = "supportive"
        elif any(word in msg for word in ["funny", "joke", "lol"]):
            self.mood = "playful"
        else:
            self.mood = "friendly"

    @staticmethod
    def _trim_response(text):
        """Truncate generated text at the first hallucinated role marker.

        The prompt is a flat role-tagged transcript with no stop sequence,
        so the model often continues past its own turn and invents the
        next "user:" / "assistant:" line. Cut at the first such marker
        and strip surrounding whitespace.
        """
        for marker in ("\nuser:", "\nassistant:"):
            idx = text.find(marker)
            if idx != -1:
                text = text[:idx]
        return text.strip()

    def chat(self, user_message, max_new_tokens=120, temperature=0.7):
        """Generate and return Cass's reply to user_message.

        Updates mood, appends both the user turn and the generated
        assistant turn to self.history.

        Args:
            user_message: The user's message for this turn.
            max_new_tokens: Generation length cap.
            temperature: Sampling temperature.

        Returns:
            The assistant's reply as a string.
        """
        self.update_mood(user_message)

        # Emit a one-time canned introduction as the first assistant turn.
        if not self.introduced:
            intro = (
                "Hi, my name's Cass! I'm your AI assistant and a cool friend. "
                "I love chatting, helping, and making you smile 😊"
            )
            self.history.append({"role": "assistant", "content": intro})
            self.introduced = True

        self.history.append({"role": "user", "content": user_message})

        # Build a flat transcript prompt: persona + mood + memory + turns.
        memory_str = " ".join([f"{k}: {v}" for k, v in self.user_memory.items()])
        prompt = f"{self.persona}\nCurrent mood: {self.mood}\nUser memory: {memory_str}\n\n"
        for msg in self.history:
            prompt += f"{msg['role']}: {msg['content']}\n"
        prompt += "assistant:"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens (skip the prompt) and
        # drop any hallucinated follow-on turns.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self._trim_response(
            self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        )

        self.history.append({"role": "assistant", "content": response})
        return response

    def remember(self, key, value):
        """Store user info in memory."""
        self.user_memory[key] = value
        print(f"Memory updated: {key} = {value}")

    def reset_history(self):
        """Clear conversation history and reset introduction."""
        self.history = []
        self.introduced = False
        self.mood = "friendly"
        print("Conversation history cleared.")
|
|
|
|
|
|
|
|
# Demo: guard the script entry point so importing this module does not
# trigger a multi-GB model download/load as a side effect.
if __name__ == "__main__":
    cass = CASS()

    print(cass.chat("Hello!"))
    print(cass.chat("Can you tell me a joke?"))

    cass.remember("favorite color", "blue")
    print(cass.chat("Do you remember my favorite color?"))
    print(cass.chat("I'm feeling a bit sad today."))
|
|