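"""Streamlit chat app that loads deepseek-ai/Janus-Pro-7B on CPU and streams
generated text into the page chunk by chunk.

(This module docstring is an added summary of the code below, not part of the
original file.)
"""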
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from typing import Generator, Optional
import time
logging.basicConfig(level=logging.INFO)
@st.cache_resource
def load_model():
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False

    model_name = "deepseek-ai/Janus-Pro-7B"
    try:
        with st.spinner("🔄 Loading model (first run only)..."):
            # Left-padding so new tokens are appended at the end of the prompt
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side='left'
            )
            tokenizer.pad_token = tokenizer.eos_token

            # Full-precision, CPU-only load
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map='cpu'
            )
            model.eval()
            torch.set_num_threads(8)

            st.session_state.model_loaded = True
            return model, tokenizer
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.info("Try refreshing the page or clearing the cache.")
        st.stop()
def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Stream tokens with a controlled delay for smooth output."""
    buffer = ""
    for char in response:
        buffer += char
        if len(buffer) >= 3 or char in '.!?':  # Flush in small chunks or at punctuation
            yield buffer
            buffer = ""
            time.sleep(delay)
    if buffer:  # Yield any remaining text
        yield buffer
def generate_stream(prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> Optional[str]:
    try:
        # Safety checks
        if not model or not tokenizer:
            raise ValueError("Model or tokenizer not initialized")

        # Escape angle brackets so user input can't inject HTML via st.markdown
        safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
        chat_prompt = f"""### Human: {safe_prompt}
### Assistant: I'll help you with that."""

        # Create persistent placeholder
        message_placeholder = st.empty()
        response_container = st.container()
        with torch.inference_mode(), st.spinner("Thinking..."):
            inputs = tokenizer(
                chat_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            )

            # Stream generation with progress tracking
            generated_text = ""
            generated_ids = []
            progress_bar = st.progress(0)
            for i in range(512):  # Max tokens
                try:
                    # Re-run generation on the prompt plus all tokens generated so far
                    current_ids = (
                        inputs["input_ids"] if not generated_ids
                        else torch.cat(
                            [inputs["input_ids"],
                             torch.tensor([generated_ids]).to(model.device)],
                            dim=1
                        )
                    )
                    outputs = model.generate(
                        current_ids,
                        max_new_tokens=1,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.95,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=tokenizer.eos_token_id,
                        attention_mask=torch.ones_like(current_ids)
                    )
                    next_token = outputs[0][-1].item()
                    generated_ids.append(next_token)

                    # Update progress
                    progress_bar.progress(min(1.0, i / 512))

                    # Decode everything generated so far, then stream only the new suffix
                    current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                    for chunk in stream_tokens(current_text[len(generated_text):]):
                        generated_text += chunk
                        with response_container:
                            message_placeholder.markdown(generated_text)

                    # Stop on EOS, the start of a new "### Human:" turn, or the token cap
                    if (next_token == tokenizer.eos_token_id or
                            "### Human:" in current_text or
                            len(generated_ids) >= 512):
                        break
                except torch.cuda.OutOfMemoryError:
                    torch.cuda.empty_cache()
                    st.warning("Memory limit reached, truncating response...")
                    break

            progress_bar.empty()
        # Clean and validate response
        response = generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
        if len(response) < 10:  # Minimum response length
            raise ValueError("Generated response too short")
        return response
    except Exception as e:
        st.error(f"Generation error: {str(e)}")
        return "I apologize, but I couldn't generate a response. Please try again."
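

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal example of how these helpers might be wired into a Streamlit chat
# UI. The page title, the "messages" session-state key, and the chat widgets
# are assumptions; Streamlit re-executes this top-level code on every rerun.
st.title("Janus-Pro-7B Chat")
model, tokenizer = load_model()

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay prior turns so the conversation persists across reruns
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

if user_input := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)
    with st.chat_message("assistant"):
        # generate_stream renders its own streaming placeholder inside this block
        reply = generate_stream(user_input, model, tokenizer)
    if reply:
        st.session_state.messages.append({"role": "assistant", "content": reply})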