import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

# Configure page
st.set_page_config(
    page_title="DeepSeek Assistant",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Set up logging and style
logging.basicConfig(level=logging.INFO)

st.markdown("""
<style>
    .stChat { padding: 20px; border-radius: 10px; }
    .user-message { background-color: #e6f3ff; }
    .assistant-message { background-color: #f0f2f6; }
    .stButton button { background-color: #2E86C1; }
</style>
""", unsafe_allow_html=True)

st.title("🧠 DeepSeek AI Assistant")

if "model_loaded" not in st.session_state:
    st.session_state.model_loaded = False

st.markdown("""
👇 Type a message in the chat box below to start chatting!

### Features:
- Real-time response generation
- Context-aware conversations
- Professional responses
- Memory efficient

### Tips:
- Be specific in your questions
- Use clear language
- Start with simple queries
""")

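# st.cache_resource loads the model once per server process and reuses it
# across reruns and sessions, instead of reloading on every interaction.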
@st.cache_resource
def load_model():
    model_name = "deepseek-ai/Janus-Pro-7B"
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side='left'
        )
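        # Causal LM tokenizers often ship without a pad token; reusing EOS
        # keeps padding well-defined during generation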
        tokenizer.pad_token = tokenizer.eos_token
        
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map='cpu'
        )
        
        model.eval()
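        # Cap CPU threads used for inference; 8 is an arbitrary default,
        # tune to the host's core count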
        torch.set_num_threads(8)
        return model, tokenizer
        
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.stop()

def generate_response(prompt, model, tokenizer):
    try:
        # Wrap the prompt in a simple "### Human:/### Assistant:" turn template
        chat_prompt = f"""### Human: {prompt}

### Assistant: Let me help you with that."""
        
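        # Tokenize the prompt, truncating anything beyond a 2048-token budget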
        inputs = tokenizer(
            chat_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )
        
        # Create placeholder for streaming output
        message_placeholder = st.empty()
        full_response = ""
        
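        # NOTE: calling model.generate() once per new token re-encodes the
        # whole sequence on every step, so this loop is O(n^2) in output
        # length. For production streaming, transformers' TextIteratorStreamer
        # (run generate() in a background thread and iterate the streamer here)
        # is the idiomatic alternative; the explicit loop is kept for clarity.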
        with torch.inference_mode():
            generated_ids = []
            for _ in range(512):  # Max new tokens
                # Re-append every token generated so far to the original prompt
                input_ids = inputs["input_ids"]
                if generated_ids:
                    input_ids = torch.cat(
                        [input_ids, torch.tensor([generated_ids]).to(model.device)],
                        dim=1,
                    )
                # Generate exactly one token so it can be streamed to the UI;
                # pass an explicit attention mask to avoid padding ambiguity
                outputs = model.generate(
                    input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    max_new_tokens=1,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.95,
                    top_k=50,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.eos_token_id
                )

                next_token = outputs[0][-1].item()
                generated_ids.append(next_token)
                
                # Decode and display current state
                current_output = tokenizer.decode(generated_ids, skip_special_tokens=True)
                full_response = current_output
                message_placeholder.markdown(full_response)
                
                # Check for end of generation
                if next_token == tokenizer.eos_token_id or "### Human:" in full_response:
                    break
            
            # Strip any turn markers the model may have echoed; the splits
            # are no-ops when a marker is absent
            response = full_response.split("### Assistant:")[-1].strip()
            response = response.split("### Human:")[0].strip()
            return response
            
    except Exception as e:
        st.error(f"Error: {str(e)}")
        return None

def init_chat():
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Load the model independently of the chat history so that clearing
    # messages (or a partially initialized session) never drops the model
    if "model" not in st.session_state:
        st.session_state.model, st.session_state.tokenizer = load_model()
        st.session_state.model_loaded = True

def main():
    # The page title and welcome copy are rendered at module level above,
    # so drawing a second title here would duplicate the header
    init_chat()
    
    with st.sidebar:
        st.markdown("### Chat Settings")
        if st.button("🗑️ Clear History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
    
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    if prompt := st.chat_input("Ask me anything..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        
        with st.chat_message("assistant"):
            context = "\n".join([
                f"{m['role']}: {m['content']}" 
                for m in st.session_state.messages[-3:]
            ])
            
            response = generate_response(
                context,
                st.session_state.model,
                st.session_state.tokenizer
            )
            
            if response:
                st.markdown(response)
                st.session_state.messages.append(
                    {"role": "assistant", "content": response}
                )

if __name__ == "__main__":
    main()