File size: 3,583 Bytes
9de80d8
 
 
 
b9a8c57
9de80d8
 
 
 
b9a8c57
 
 
 
 
9de80d8
 
 
 
b9a8c57
 
9de80d8
 
 
 
b9a8c57
9de80d8
 
 
 
b9a8c57
 
9de80d8
b9a8c57
9de80d8
 
 
 
 
 
 
b9a8c57
9de80d8
b9a8c57
9de80d8
 
 
b9a8c57
9de80d8
 
b9a8c57
9de80d8
 
b9a8c57
9de80d8
b9a8c57
 
 
 
9de80d8
b9a8c57
 
9de80d8
 
 
 
b9a8c57
9de80d8
 
 
b9a8c57
9de80d8
 
b9a8c57
9de80d8
 
 
 
b9a8c57
9de80d8
 
 
 
b9a8c57
9de80d8
 
 
 
 
 
 
b9a8c57
9de80d8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

class CASS:
    """Persona-driven chat wrapper around a Hugging Face causal language model.

    The class keeps a running transcript (``history``), a small key/value store
    of facts about the user (``user_memory``) and a coarse ``mood`` label. On
    every ``chat`` call these are rendered into one plain-text prompt, the
    model generates a continuation, and the continuation is stored as the next
    assistant turn.

    Attributes:
        device: Device label chosen at construction (explicit argument, else
            "cuda" when available, else "cpu").
        tokenizer: Tokenizer loaded for ``model_name``.
        model: Causal LM loaded for ``model_name`` in float16.
        history: List of ``{"role": ..., "content": ...}`` dicts, oldest first.
        introduced: Whether the canned introduction has been added to history.
        user_memory: Facts stored through ``remember``.
        mood: One of "friendly", "supportive" or "playful".
        persona: System-style text placed at the top of every prompt.
    """

    # Mood keywords are compared against whole words of the user message so
    # that e.g. "badge" does not count as "bad" and "lollipop" is not "lol".
    _SUPPORTIVE_WORDS = frozenset({"sad", "unhappy", "bad", "tired"})
    _PLAYFUL_WORDS = frozenset({"funny", "joke", "jokes", "lol"})

    # Speaker tags used when rendering the transcript. If the model continues
    # the transcript with further turns, the reply is cut at the first tag.
    _TURN_MARKERS = ("\nuser:", "\nassistant:")

    def __init__(self, model_name: str = "HPLT/gpt-13b-nordic-prerelease", device=None):
        """Load tokenizer and model and initialise the conversation state.

        Args:
            model_name: Hugging Face model identifier to load.
            device: Optional device to place the model on. When omitted, the
                model placement is left to ``device_map="auto"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on {self.device}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # Honour an explicitly requested device; otherwise keep automatic
            # placement exactly as before.
            device_map=device or "auto",
            torch_dtype=torch.float16,  # Use FP16 for memory efficiency
        )

        # Conversation memory
        self.history = []
        self.introduced = False
        self.user_memory = {}
        self.mood = "friendly"

        # Cass's persona
        self.persona = (
            "You are Cass, a friendly AI assistant and a cool friend. "
            "You respond casually, helpfully, and with a touch of humor. "
            "You adapt your tone depending on conversation mood, and sometimes use emojis or playful expressions."
        )

    def update_mood(self, user_message: str) -> None:
        """Set ``self.mood`` from keywords found in ``user_message``.

        The message is lower-cased and split into alphanumeric words; the mood
        becomes "supportive" if any supportive keyword is present, otherwise
        "playful" if any playful keyword is present, otherwise "friendly".
        """
        # Replace punctuation with spaces so "sad." and "joke?" still match.
        normalized = "".join(
            ch if ch.isalnum() else " " for ch in user_message.lower()
        )
        words = set(normalized.split())

        if words & self._SUPPORTIVE_WORDS:
            self.mood = "supportive"
        elif words & self._PLAYFUL_WORDS:
            self.mood = "playful"
        else:
            self.mood = "friendly"

    def _build_prompt(self) -> str:
        """Render persona, mood, user memory and history into the prompt text.

        The prompt ends with an open ``assistant:`` tag for the model to
        continue.
        """
        memory_str = " ".join(f"{k}: {v}" for k, v in self.user_memory.items())
        header = (
            f"{self.persona}\nCurrent mood: {self.mood}\nUser memory: {memory_str}\n\n"
        )
        turns = "".join(f"{msg['role']}: {msg['content']}\n" for msg in self.history)
        return f"{header}{turns}assistant:"

    @staticmethod
    def _clean_response(text: str) -> str:
        """Return ``text`` cut at the first speaker tag, with whitespace stripped."""
        for marker in CASS._TURN_MARKERS:
            cut = text.find(marker)
            if cut != -1:
                text = text[:cut]
        return text.strip()

    def chat(self, user_message: str, max_new_tokens: int = 120, temperature: float = 0.7) -> str:
        """Generate Cass's reply to ``user_message`` and record both turns.

        Args:
            user_message: Text typed by the user.
            max_new_tokens: Upper bound on newly generated tokens.
            temperature: Sampling temperature passed to ``generate``.

        Returns:
            The assistant reply, trimmed of surrounding whitespace and of any
            additional ``user:``/``assistant:`` turns the model appended.
        """
        self.update_mood(user_message)

        # On the first message, seed the transcript with a canned introduction.
        # It is only added to the history; it is not part of the return value.
        if not self.introduced:
            intro = (
                "Hi, my name's Cass! I'm your AI assistant and a cool friend. "
                "I love chatting, helping, and making you smile 😊"
            )
            self.history.append({"role": "assistant", "content": intro})
            self.introduced = True

        self.history.append({"role": "user", "content": user_message})
        prompt = self._build_prompt()

        # Put the inputs on the device the model actually lives on; this may
        # differ from ``self.device`` when placement was chosen automatically.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # The output sequence starts with the prompt tokens; decode only the
        # newly generated tail.
        prompt_len = inputs["input_ids"].shape[-1]
        raw = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        response = self._clean_response(raw)

        self.history.append({"role": "assistant", "content": response})
        return response

    def remember(self, key, value) -> None:
        """Store user info in memory under ``key`` and print a confirmation."""
        self.user_memory[key] = value
        print(f"Memory updated: {key} = {value}")

    def reset_history(self) -> None:
        """Clear conversation history and reset introduction and mood.

        ``user_memory`` is left untouched.
        """
        self.history = []
        self.introduced = False
        self.mood = "friendly"
        print("Conversation history cleared.")

# Demo conversation (runs at module level).
cass = CASS()

# The first chat() call also seeds the history with Cass's introduction;
# only the generated reply is printed.
reply = cass.chat("Hello!")
print(reply)

reply = cass.chat("Can you tell me a joke?")
print(reply)

# Store a fact so it appears in the "User memory" line of later prompts.
cass.remember("favorite color", "blue")

reply = cass.chat("Do you remember my favorite color?")
print(reply)

reply = cass.chat("I'm feeling a bit sad today.")
print(reply)