```python
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from typing import Generator, Optional
import time

logging.basicConfig(level=logging.INFO)

@st.cache_resource  # Cache across reruns so the model truly loads on the first run only
def load_model():
    model_name = "deepseek-ai/Janus-Pro-7B"
    try:
        with st.spinner("🔄 Loading model (first run only)..."):
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side='left'
            )
            tokenizer.pad_token = tokenizer.eos_token  # No pad token by default
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map='cpu'
            )
            model.eval()
            torch.set_num_threads(8)  # Cap CPU threads for the Space's hardware
            return model, tokenizer
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.info("Try refreshing the page or clearing the cache.")
        st.stop()

def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Stream tokens with a controlled delay for smooth output."""
    buffer = ""
    for char in response:
        buffer += char
        if len(buffer) >= 3 or char in '.!?':  # Flush in small chunks or at punctuation
            yield buffer
            buffer = ""
            time.sleep(delay)
    if buffer:  # Yield any remaining text
        yield buffer

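# Note on generate_stream below: it streams by calling model.generate() with
# max_new_tokens=1 in a loop, re-encoding the whole prefix on every step. That
# keeps the code simple but is slow on CPU; see the TextIteratorStreamer sketch
# after this listing for a cache-reusing alternative.
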
def generate_stream(prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> Optional[str]:
    try:
        # Safety checks
        if not model or not tokenizer:
            raise ValueError("Model or tokenizer not initialized")

        # Escape angle brackets so user input cannot inject markup into the rendered output
        safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
        chat_prompt = f"""### Human: {safe_prompt}
### Assistant: I'll help you with that."""

        # Create a persistent placeholder for the streamed response
        message_placeholder = st.empty()
        response_container = st.container()

        with torch.inference_mode(), st.spinner("Thinking..."):
            inputs = tokenizer(
                chat_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            )

            # Stream generation with progress tracking
            generated_text = ""
            generated_ids = []
            progress_bar = st.progress(0)
            for i in range(512):  # Max new tokens
                try:
                    # Prompt plus everything generated so far, re-encoded each step
                    input_ids = (
                        inputs["input_ids"]
                        if not generated_ids
                        else torch.cat(
                            [inputs["input_ids"],
                             torch.tensor([generated_ids]).to(model.device)],
                            dim=1,
                        )
                    )
                    outputs = model.generate(
                        input_ids,
                        max_new_tokens=1,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.95,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=tokenizer.eos_token_id,
                        attention_mask=torch.ones_like(input_ids)
                    )
                    next_token = outputs[0][-1].item()
                    generated_ids.append(next_token)

                    # Update progress
                    progress_bar.progress(min(1.0, i / 512))

                    # Decode the full sequence so far and stream only the new suffix
                    current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                    for chunk in stream_tokens(current_text[len(generated_text):]):
                        generated_text += chunk
                        with response_container:
                            message_placeholder.markdown(generated_text)

                    # Stop at EOS, a new conversation turn, or the token budget
                    if (next_token == tokenizer.eos_token_id or
                            "### Human:" in current_text or
                            len(generated_ids) >= 512):
                        break
                except torch.cuda.OutOfMemoryError:
                    # Only reachable if the model is later moved to GPU; a no-op on CPU
                    torch.cuda.empty_cache()
                    st.warning("Memory limit reached, truncating response...")
                    break
            progress_bar.empty()

        # Clean and validate the response
        response = generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
        if len(response) < 10:  # Minimum response length
            raise ValueError("Generated response too short")
        return response

    except Exception as e:
        st.error(f"Generation error: {str(e)}")
        return "I apologize, but I couldn't generate a response. Please try again."
```
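
The Space's page code is not shown above, so the following is a minimal sketch of how the two functions could be wired together; the chat UI below is an illustrative assumption, not the app's original code:

```python
# Hypothetical UI wiring; the original Space's page code is not shown above.
model, tokenizer = load_model()

st.title("Chat")
if prompt := st.chat_input("Ask me anything"):
    with st.chat_message("user"):
        st.markdown(prompt)
    # generate_stream() renders the streamed reply into its own placeholder
    reply = generate_stream(prompt, model, tokenizer)
```

As for the streaming loop itself: `transformers` ships a `TextIteratorStreamer` that runs `generate()` once on a worker thread and yields decoded text as it is produced, reusing the KV cache instead of re-encoding the prompt for every token. A sketch of that alternative, with the name `generate_stream_fast` being a hypothetical stand-in:

```python
# A sketch, not the app's original code: stream via TextIteratorStreamer.
from threading import Thread

from transformers import TextIteratorStreamer


def generate_stream_fast(prompt, model, tokenizer, max_new_tokens=512):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation on a worker thread; the streamer yields text on this one.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    placeholder = st.empty()
    text = ""
    for chunk in streamer:  # Decoded text pieces arrive as generation proceeds
        text += chunk
        placeholder.markdown(text)
    thread.join()
    return text
```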