import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from typing import Generator, Optional
import time

logging.basicConfig(level=logging.INFO)


@st.cache_resource
def load_model():
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False

    model_name = "deepseek-ai/Janus-Pro-7B"
    try:
        with st.spinner("🔄 Loading model (first run only)..."):
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side="left",
            )
            tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="cpu",
            )
            model.eval()
            torch.set_num_threads(8)

            st.session_state.model_loaded = True
            return model, tokenizer
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.info("Try refreshing the page or clearing the cache.")
        st.stop()


def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Stream text in small chunks with a controlled delay for smooth output."""
    buffer = ""
    for char in response:
        buffer += char
        if len(buffer) >= 3 or char in ".!?":  # Flush on chunk size or punctuation
            yield buffer
            buffer = ""
            time.sleep(delay)
    if buffer:  # Yield any remaining text
        yield buffer


def generate_stream(
    prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer
) -> Optional[str]:
    try:
        # Safety checks
        if not model or not tokenizer:
            raise ValueError("Model or tokenizer not initialized")

        # Escape angle brackets so user input can't inject markup into the markdown output
        safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
        chat_prompt = f"""### Human: {safe_prompt}

### Assistant: I'll help you with that."""

        # Persistent placeholder that is rewritten as tokens stream in
        message_placeholder = st.empty()
        response_container = st.container()

        with torch.inference_mode(), st.spinner("Thinking..."):
            inputs = tokenizer(
                chat_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
            )

            # Stream generation with progress tracking
            generated_text = ""
            generated_ids = []
            progress_bar = st.progress(0)
            max_tokens = 512

            for i in range(max_tokens):
                try:
                    # Rebuild the full input (prompt plus tokens generated so far).
                    # Note: this recomputes the whole prefix on every step, which is
                    # simple but slow compared to reusing the KV cache.
                    if generated_ids:
                        input_ids = torch.cat(
                            [
                                inputs["input_ids"],
                                torch.tensor([generated_ids]).to(model.device),
                            ],
                            dim=1,
                        )
                    else:
                        input_ids = inputs["input_ids"]

                    outputs = model.generate(
                        input_ids,
                        max_new_tokens=1,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.95,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=tokenizer.eos_token_id,
                        attention_mask=torch.ones_like(input_ids),
                    )
                    next_token = outputs[0][-1].item()
                    generated_ids.append(next_token)

                    # Update progress
                    progress_bar.progress(min(1.0, (i + 1) / max_tokens))

                    # Decode everything generated so far and stream only the new part
                    current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                    for chunk in stream_tokens(current_text[len(generated_text):]):
                        generated_text += chunk
                        with response_container:
                            message_placeholder.markdown(generated_text)

                    # Stopping conditions: EOS token, a new turn marker, or the token cap
                    if (
                        next_token == tokenizer.eos_token_id
                        or "### Human:" in current_text
                        or len(generated_ids) >= max_tokens
                    ):
                        break
                except torch.cuda.OutOfMemoryError:  # only reachable if moved to GPU
                    torch.cuda.empty_cache()
                    st.warning("Memory limit reached, truncating response...")
                    break

            progress_bar.empty()

            # Clean and validate the response
            response = (
                generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
            )
            if len(response) < 10:  # Minimum response length
                raise ValueError("Generated response too short")
            return response
    except Exception as e:
        st.error(f"Generation error: {str(e)}")
        return "I apologize, but I couldn't generate a response. Please try again."
```