Advait3009 commited on
Commit
298895c
·
verified ·
1 Parent(s): d06533b

Create memory.py

Browse files
Files changed (1) hide show
  1. utils/memory.py +41 -0
utils/memory.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from typing import List, Dict, Union
3
+ from PIL import Image
4
+
5
+ class ChatMemory:
6
+ def __init__(self, max_context_length: int = 2048):
7
+ self.max_context_length = max_context_length
8
+ if "memory" not in st.session_state:
9
+ st.session_state.memory = {
10
+ 'history': [],
11
+ 'context': []
12
+ }
13
+
14
+ def update(self, user_input: Union[str, Image.Image], response: str):
15
+ """Store interaction with automatic context pruning"""
16
+ # Store user input
17
+ if isinstance(user_input, Image.Image):
18
+ st.session_state.memory['history'].append(('user', 'image', user_input))
19
+ else:
20
+ st.session_state.memory['history'].append(('user', 'text', user_input))
21
+
22
+ # Store assistant response
23
+ st.session_state.memory['history'].append(('assistant', 'text', response))
24
+
25
+ # Maintain context window
26
+ current_length = sum(len(item[2]) for item in st.session_state.memory['history'] if item[1] == 'text')
27
+ while current_length > self.max_context_length and len(st.session_state.memory['history']) > 2:
28
+ removed = st.session_state.memory['history'].pop(0)
29
+ if removed[1] == 'text':
30
+ current_length -= len(removed[2])
31
+
32
+ def get_context(self) -> str:
33
+ """Generate conversation context string"""
34
+ return "\n".join(
35
+ f"{role}: {content}"
36
+ for role, type_, content in st.session_state.memory['history']
37
+ if type_ == 'text'
38
+ )
39
+
40
+ def clear(self):
41
+ st.session_state.memory = {'history': [], 'context': []}