# app.py: DeepSeek Janus chat interface (Streamlit)
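# Run locally with: streamlit run app.py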
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
# Configure page
st.set_page_config(
    page_title="DeepSeek Assistant",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Set up logging and style
logging.basicConfig(level=logging.INFO)
st.markdown("""
<style>
.stChat { padding: 20px; border-radius: 10px; }
.user-message { background-color: #e6f3ff; }
.assistant-message { background-color: #f0f2f6; }
.stButton button { background-color: #2E86C1; }
</style>
""", unsafe_allow_html=True)
st.title("🧠 DeepSeek AI Assistant")
if "model_loaded" not in st.session_state:
st.session_state.model_loaded = False
st.markdown("""
👈 Select 'Chat' from the sidebar to start chatting!
### Features:
- Real-time response generation
- Context-aware conversations
- Professional responses
- Memory efficient
### Tips:
- Be specific in your questions
- Use clear language
- Start with simple queries
""")
@st.cache_resource
def load_model():
    """Load the Janus-Pro model and tokenizer once and cache them across reruns."""
    model_name = "deepseek-ai/Janus-Pro-7B"
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            padding_side="left",
        )
        # Reuse EOS as the padding token, since the tokenizer defines none
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map="cpu",
        )
        model.eval()
        torch.set_num_threads(8)  # cap CPU threads used for inference
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()
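# NOTE: a 7B-parameter model in float32 needs on the order of 28 GB of RAM
# (7e9 parameters x 4 bytes); loading with torch_dtype=torch.bfloat16 would
# roughly halve that, if the host CPU supports it.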
def generate_response(prompt, model, tokenizer):
    try:
        # Janus-Pro specific prompt format
        chat_prompt = f"""### Human: {prompt}
### Assistant: Let me help you with that."""
        inputs = tokenizer(
            chat_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        )
        # Placeholder that is rewritten as tokens stream in
        message_placeholder = st.empty()
        full_response = ""
        with torch.inference_mode():
            generated_ids = []
            for _ in range(512):  # max new tokens
                # Re-run generate() on the prompt plus everything generated so
                # far, asking for exactly one new token (a simple streaming
                # loop; quadratic in sequence length but dependency-free).
                input_ids = inputs["input_ids"]
                if generated_ids:
                    input_ids = torch.cat(
                        [input_ids, torch.tensor([generated_ids], device=input_ids.device)],
                        dim=1,
                    )
                outputs = model.generate(
                    input_ids,
                    max_new_tokens=1,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.95,
                    top_k=50,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.eos_token_id,
                )
                next_token = outputs[0][-1].item()
                generated_ids.append(next_token)
                # Decode everything generated so far and refresh the placeholder
                full_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
                message_placeholder.markdown(full_response)
                # Stop at EOS or when the model starts a new "Human" turn
                if next_token == tokenizer.eos_token_id or "### Human:" in full_response:
                    break
        # Strip the prompt scaffolding from the reply
        response = full_response.split("### Assistant:")[-1].strip()
        response = response.split("### Human:")[0].strip()
        return response
    except Exception as e:
        st.error(f"Error: {e}")
        return None
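# NOTE: a lighter streaming alternative (a sketch, not wired in here) is
# transformers' TextIteratorStreamer, which avoids re-running generate() for
# every token: run generation once in a background thread and iterate the
# streamer for decoded text chunks.
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512)).start()
#   for chunk in streamer:
#       full_response += chunk
#       message_placeholder.markdown(full_response)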
def init_chat():
    # Messages live in session state; the model/tokenizer pair comes from the
    # @st.cache_resource cache, so it is loaded only once per process.
    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.model, st.session_state.tokenizer = load_model()
def main():
    st.title("🧠 DeepSeek Janus Chat Assistant")
    init_chat()
    with st.sidebar:
        st.markdown("### Chat Settings")
        if st.button("🗑️ Clear History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()
    # Replay the conversation so far
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("Ask me anything..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            # Fold the last three turns into the prompt for short-range context
            context = "\n".join(
                f"{m['role']}: {m['content']}"
                for m in st.session_state.messages[-3:]
            )
            response = generate_response(
                context,
                st.session_state.model,
                st.session_state.tokenizer,
            )
            if response:
                st.markdown(response)
                st.session_state.messages.append(
                    {"role": "assistant", "content": response}
                )

if __name__ == "__main__":
    main()