Mohansai2004 committed
Commit cc807a2 · Parent: fe98a76

feat: implement DeepSeek Janus chat interface


- Add streaming response generation
- Implement chat UI with real-time updates
- Add multi-page structure
- Update dependencies
- Improve error handling
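
As a rough orientation, the pieces added in this commit fit together as in the sketch below (a minimal sketch only: `load_model` and `generate_stream` come from the new `utils.py`; the real entry point is `pages/01_chat.py` shown in the diff).

```python
# Minimal sketch of the new chat flow (illustrative only).
import streamlit as st
from utils import load_model, generate_stream  # added in this commit

model, tokenizer = load_model()  # cached once via @st.cache_resource
if prompt := st.chat_input("Ask me anything..."):
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        # generate_stream renders tokens into the page as they are produced
        reply = generate_stream(prompt, model, tokenizer)
```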

Files changed (5)
  1. README.md +10 -10
  2. app.py +21 -11
  3. pages/01_chat.py +79 -0
  4. requirements.txt +29 -3
  5. utils.py +130 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: DeepSeek R1 Chat
+ title: Janus Pro Chat
  emoji: 🧠
  colorFrom: blue
  colorTo: purple
@@ -7,21 +7,21 @@ sdk: streamlit
  sdk_version: 1.41.1
  app_file: app.py
  pinned: false
- short_description: Advanced Chat using DeepSeek-R1-Distill-8B
+ short_description: Advanced Chat using Janus-Pro-7B
  ---

- # DeepSeek R1 Chat Assistant
+ # Janus Pro Chat Assistant

- Powerful chat interface powered by DeepSeek-R1-Distill-Llama-8B model.
+ Professional chat interface powered by DeepSeek Janus-Pro-7B model.

  ## Features
- - Advanced language understanding
- - Context-aware responses
- - Efficient 8B parameter model
+ - Professional-grade responses
+ - Real-time streaming
+ - Context-aware chat
  - Local CPU inference
  - Memory optimized

  ## Usage
- - Type your message and press Enter
- - Clear chat history using sidebar button
- - Best for complex conversations
+ - Get detailed, accurate responses
+ - Perfect for professional and technical discussions
+ - Maintains conversation context
app.py CHANGED
@@ -5,7 +5,7 @@ import logging

  # Configure page
  st.set_page_config(
-     page_title="DeepSeek R1 Chat",
+     page_title="DeepSeek Assistant",
      page_icon="🧠",
      layout="wide",
      initial_sidebar_state="expanded"
@@ -23,9 +23,17 @@ st.markdown("""
  </style>
  """, unsafe_allow_html=True)

+ st.title("🧠 DeepSeek AI Assistant")
+ st.markdown("""
+ Welcome! Choose a feature from the sidebar to get started.
+
+ - 💭 Chat Interface: Have a conversation with the AI
+ - More features coming soon...
+ """)
+
  @st.cache_resource
  def load_model():
-     model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
+     model_name = "deepseek-ai/Janus-Pro-7B"

      try:
          tokenizer = AutoTokenizer.from_pretrained(
@@ -53,10 +61,10 @@ def load_model():

  def generate_response(prompt, model, tokenizer):
      try:
-         chat_prompt = f"""user
- {prompt}
- assistant
- I'll help you with that."""
+         # Janus-Pro specific prompt format
+         chat_prompt = f"""### Human: {prompt}
+
+ ### Assistant: Let me help you with that."""

          inputs = tokenizer(
              chat_prompt,
@@ -72,7 +80,7 @@ I'll help you with that."""

          with torch.inference_mode():
              generated_ids = []
-             for i in range(512):  # Max new tokens
+             for _ in range(512):  # Max new tokens
                  # Generate next token
                  outputs = model.generate(
                      inputs["input_ids"] if not generated_ids else torch.cat([inputs["input_ids"], torch.tensor([generated_ids]).to(model.device)], dim=1),
@@ -80,6 +88,7 @@ I'll help you with that."""
                      temperature=0.7,
                      do_sample=True,
                      top_p=0.95,
+                     top_k=50,  # Added for better quality
                      repetition_penalty=1.1,
                      pad_token_id=tokenizer.eos_token_id
                  )
@@ -93,12 +102,13 @@ I'll help you with that."""
                  message_placeholder.markdown(full_response)

                  # Check for end of generation
-                 if next_token == tokenizer.eos_token_id:
+                 if next_token == tokenizer.eos_token_id or "### Human:" in full_response:
                      break

-         # Clean up response
-         response = full_response.split("assistant")[-1].strip()
-         return response.split("user")[0].strip()
+         # Clean up response for Janus format
+         response = full_response.split("### Assistant:")[-1].strip()
+         response = response.split("### Human:")[0].strip()
+         return response

      except Exception as e:
          st.error(f"Error: {str(e)}")
pages/01_chat.py ADDED
@@ -0,0 +1,79 @@
+ import streamlit as st
+ from utils import load_model, generate_stream
+ import time
+
+ def init_chat():
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+         st.session_state.model, st.session_state.tokenizer = load_model()
+         st.session_state.generating = False
+
+ # Chat interface with error handling
+ try:
+     st.title("💭 Chat Interface")
+     init_chat()
+
+     # Sidebar controls with session management
+     with st.sidebar:
+         st.markdown("### Chat Controls")
+         cols = st.columns(2)
+         with cols[0]:
+             if st.button("🗑️ Clear Chat", use_container_width=True):
+                 st.session_state.messages = []
+                 st.rerun()
+         with cols[1]:
+             if st.button("🔄 Reset Model", use_container_width=True):
+                 st.cache_resource.clear()
+                 st.rerun()
+
+     # Chat history with proper formatting
+     chat_container = st.container()
+     with chat_container:
+         for msg in st.session_state.messages:
+             with st.chat_message(msg["role"]):
+                 st.markdown(msg["content"])
+
+     # Input handling with safeguards
+     if prompt := st.chat_input(
+         "Ask me anything...",
+         disabled=st.session_state.get("generating", False)
+     ):
+         if not st.session_state.generating:
+             st.session_state.generating = True
+
+             # Show user message
+             st.session_state.messages.append({"role": "user", "content": prompt})
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+
+             # Generate and show response
+             with st.chat_message("assistant"):
+                 try:
+                     # Get recent context
+                     context = "\n".join([
+                         f"{m['role']}: {m['content']}"
+                         for m in st.session_state.messages[-3:]
+                     ])
+
+                     response = generate_stream(
+                         context,
+                         st.session_state.model,
+                         st.session_state.tokenizer
+                     )
+
+                     if response:
+                         st.session_state.messages.append({
+                             "role": "assistant",
+                             "content": response
+                         })
+
+                 except Exception as e:
+                     st.error("Failed to generate response. Please try again.")
+                     st.error(f"Error details: {str(e)}")
+
+                 finally:
+                     st.session_state.generating = False
+
+ except Exception as e:
+     st.error(f"Application error: {str(e)}")
+     st.button("🔄 Restart App")
requirements.txt CHANGED
@@ -1,7 +1,33 @@
  # Core dependencies
  streamlit>=1.41.1
+ watchdog>=3.0.0
+
+ # Machine Learning and Models
  torch>=2.0.0
- transformers>=4.31.0
- accelerate>=0.21.0
+ transformers>=4.35.0
+ accelerate>=0.25.0
  sentencepiece>=0.1.99
- einops>=0.6.1
+ einops>=0.7.0
+ bitsandbytes>=0.41.1
+ safetensors>=0.4.0
+
+ # Performance optimization
+ numpy>=1.24.0
+ scipy>=1.11.0
+ psutil>=5.9.0
+ typing-extensions>=4.8.0
+
+ # Text processing
+ tiktoken>=0.5.1
+ regex>=2023.10.3
+ tokenizers>=0.15.0
+
+ # UI enhancements
+ streamlit-chat>=0.1.1
+ streamlit-option-menu>=0.3.2
+ streamlit-extras>=0.3.4
+ markdown>=3.5.1
+
+ # Monitoring and logging
+ tqdm>=4.66.1
+ rich>=13.7.0
utils.py ADDED
@@ -0,0 +1,130 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import logging
+ from typing import Generator, Optional
+ import time
+
+ logging.basicConfig(level=logging.INFO)
+
+ @st.cache_resource
+ def load_model():
+     model_name = "deepseek-ai/Janus-Pro-7B"
+
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             padding_side='left'
+         )
+         tokenizer.pad_token = tokenizer.eos_token
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float32,
+             low_cpu_mem_usage=True,
+             trust_remote_code=True,
+             device_map='cpu'
+         )
+
+         model.eval()
+         torch.set_num_threads(8)
+         return model, tokenizer
+
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         st.stop()
+
+ def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
+     """Stream tokens with controlled delay for smooth output"""
+     buffer = ""
+     for char in response:
+         buffer += char
+         if len(buffer) >= 3 or char in '.!?':  # Stream by chunks or punctuation
+             yield buffer
+             buffer = ""
+             time.sleep(delay)
+     if buffer:  # Yield remaining text
+         yield buffer
+
+ def generate_stream(prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> Optional[str]:
+     try:
+         # Format prompt with safety checks
+         safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
+         chat_prompt = f"""### Human: {safe_prompt}
+
+ ### Assistant: I'll help you with that."""
+
+         # Create persistent placeholder
+         message_placeholder = st.empty()
+         response_container = st.container()
+
+         with torch.inference_mode(), st.spinner("Thinking..."):
+             inputs = tokenizer(
+                 chat_prompt,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 max_length=2048
+             )
+
+             # Stream generation with progress tracking
+             generated_text = ""
+             generated_ids = []
+             progress_bar = st.progress(0)
+
+             for i in range(512):  # Max tokens
+                 try:
+                     outputs = model.generate(
+                         inputs["input_ids"] if not generated_ids else torch.cat([inputs["input_ids"], torch.tensor([generated_ids]).to(model.device)], dim=1),
+                         max_new_tokens=1,
+                         temperature=0.7,
+                         do_sample=True,
+                         top_p=0.95,
+                         top_k=50,
+                         repetition_penalty=1.1,
+                         pad_token_id=tokenizer.eos_token_id,
+                         attention_mask=torch.ones_like(inputs["input_ids"] if not generated_ids else torch.cat([inputs["input_ids"], torch.tensor([generated_ids]).to(model.device)], dim=1))
+                     )
+
+                     next_token = outputs[0][-1].item()
+                     generated_ids.append(next_token)
+
+                     # Update progress
+                     progress = min(1.0, i / 512)
+                     progress_bar.progress(progress)
+
+                     # Decode and stream current output
+                     current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+                     # Stream tokens smoothly
+                     for chunk in stream_tokens(current_text[len(generated_text):]):
+                         generated_text += chunk
+                         with response_container:
+                             message_placeholder.markdown(generated_text)
+
+                     # Check stopping conditions
+                     if (next_token == tokenizer.eos_token_id or
+                         "### Human:" in current_text or
+                         len(generated_ids) >= 512):
+                         break
+
+                 except torch.cuda.OutOfMemoryError:
+                     torch.cuda.empty_cache()
+                     st.warning("Memory limit reached, truncating response...")
+                     break
+
+             progress_bar.empty()
+
+         # Clean and validate response
+         response = generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
+         if len(response) < 10:  # Minimum response length
+             raise ValueError("Generated response too short")
+
+         return response
+
+     except Exception as e:
+         logger = logging.getLogger(__name__)
+         logger.error(f"Generation error: {str(e)}")
+         st.error("Something went wrong. Please try again with a different prompt.")
+         return None
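
Note that `generate_stream` above calls `model.generate` once per token, re-encoding the whole sequence on every step. A leaner streaming variant, not part of this commit and sketched here only as a reference, could use transformers' built-in `TextIteratorStreamer`:

```python
# Sketch only (not in this commit): streaming with transformers' TextIteratorStreamer.
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer

def generate_stream_v2(prompt, model, tokenizer):
    inputs = tokenizer(f"### Human: {prompt}\n\n### Assistant:", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() runs in a background thread and pushes decoded text into the streamer.
    thread = Thread(target=model.generate,
                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512,
                                do_sample=True, temperature=0.7, top_p=0.95))
    thread.start()
    placeholder, text = st.empty(), ""
    for chunk in streamer:  # yields decoded text as it is generated
        text += chunk
        placeholder.markdown(text)
    thread.join()
    return text.split("### Human:")[0].strip()
```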