"""Utilities for the DeepSeek Janus chat interface (test/utils.py)."""
import logging
import time
from typing import Generator

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
@st.cache_resource
def load_model():
    """Load the Janus-Pro-7B model and tokenizer once and cache them.

    `st.cache_resource` keeps the loaded model alive across reruns, so the
    spinner below only appears on the first request.
    """
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False
    model_name = "deepseek-ai/Janus-Pro-7B"
    try:
        with st.spinner("🔄 Loading model (first run only)..."):
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side="left",  # left-pad so generation continues from the prompt end
            )
            tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="cpu",  # CPU-only deployment
            )
            model.eval()
            torch.set_num_threads(8)  # bound CPU thread usage
            st.session_state.model_loaded = True
            return model, tokenizer
    except Exception as e:
        st.error(f"❌ Error loading model: {e}")
        st.info("Try refreshing the page or clearing the cache.")
        st.stop()
def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Stream text in small chunks with a controlled delay for smooth output."""
    buffer = ""
    for char in response:
        buffer += char
        if len(buffer) >= 3 or char in ".!?":  # flush on chunk size or sentence-ending punctuation
            yield buffer
            buffer = ""
            time.sleep(delay)
    if buffer:  # yield any remaining text
        yield buffer
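
# Quick sanity check (hypothetical usage, not part of the original file):
# joining the chunks with delay=0.0 reassembles the input unchanged, e.g.
# "".join(stream_tokens("Hello there!", delay=0.0)) == "Hello there!".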
def generate_stream(prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> str:
    """Generate a response one token at a time, streaming it into the Streamlit UI."""
    try:
        # Safety checks
        if not model or not tokenizer:
            raise ValueError("Model or tokenizer not initialized")

        # Escape angle brackets so user input cannot inject HTML/markdown
        safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
        chat_prompt = f"""### Human: {safe_prompt}
### Assistant: I'll help you with that."""

        # Persistent placeholder that is overwritten as tokens arrive
        message_placeholder = st.empty()
        response_container = st.container()

        with torch.inference_mode(), st.spinner("Thinking..."):
            inputs = tokenizer(
                chat_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            )

            # Generate one token per call so the UI can stream partial output
            max_new_tokens = 512
            generated_text = ""
            generated_ids = []
            progress_bar = st.progress(0)
            for i in range(max_new_tokens):
                try:
                    # Prompt plus everything generated so far
                    input_ids = (
                        inputs["input_ids"]
                        if not generated_ids
                        else torch.cat(
                            [inputs["input_ids"], torch.tensor([generated_ids]).to(model.device)],
                            dim=1,
                        )
                    )
                    outputs = model.generate(
                        input_ids,
                        max_new_tokens=1,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.95,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=tokenizer.eos_token_id,
                        attention_mask=torch.ones_like(input_ids),
                    )
                    next_token = outputs[0][-1].item()
                    generated_ids.append(next_token)

                    # Update progress (i + 1 so the bar can reach 100%)
                    progress_bar.progress(min(1.0, (i + 1) / max_new_tokens))

                    # Decode the full generation and stream only the new suffix
                    current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                    for chunk in stream_tokens(current_text[len(generated_text):]):
                        generated_text += chunk
                        with response_container:
                            message_placeholder.markdown(generated_text)

                    # Stop on EOS, a new conversation turn, or the token cap
                    if (next_token == tokenizer.eos_token_id or
                            "### Human:" in current_text or
                            len(generated_ids) >= max_new_tokens):
                        break
                except RuntimeError:
                    # Covers out-of-memory on CPU as well as CUDA OOM
                    # (torch.cuda.OutOfMemoryError subclasses RuntimeError)
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    st.warning("Memory limit reached, truncating response...")
                    break
            progress_bar.empty()

        # Strip the prompt scaffolding and validate the response
        response = generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
        if len(response) < 10:  # reject degenerate, too-short generations
            raise ValueError("Generated response too short")
        return response
    except Exception as e:
        st.error(f"Generation error: {e}")
        return "I apologize, but I couldn't generate a response. Please try again."