import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from typing import Generator
import time

logging.basicConfig(level=logging.INFO)

@st.cache_resource
def load_model():
    """Load the tokenizer and model once per process (cached by Streamlit)."""
    # Note: st.cache_resource runs this body only once per server process, so the
    # session_state flag below is set only in the session that triggers the load.
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False

    model_name = "deepseek-ai/Janus-Pro-7B"
    
    try:
        with st.spinner("🔄 Loading model (first run only)..."):
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                padding_side='left'
            )
            tokenizer.pad_token = tokenizer.eos_token
            
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map='cpu'
            )
            
            model.eval()
            torch.set_num_threads(8)  # CPU inference; tune to the host's core count
            st.session_state.model_loaded = True
            
            return model, tokenizer
            
    except Exception as e:
        logging.exception("Model load failed")
        st.error(f"❌ Error loading model: {e}")
        st.info("Try refreshing the page or clearing the cache.")
        st.stop()

def stream_tokens(response: str, delay: float = 0.01) -> Generator[str, None, None]:
    """Stream tokens with controlled delay for smooth output"""
    buffer = ""
    for char in response:
        buffer += char
        if len(buffer) >= 3 or char in '.!?':  # Stream by chunks or punctuation
            yield buffer
            buffer = ""
            time.sleep(delay)
    if buffer:  # Yield remaining text
        yield buffer
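
# Quick illustrative use of stream_tokens (the input string is hypothetical):
#
#     for chunk in stream_tokens("Hello there. How are you?", delay=0.0):
#         print(chunk, end="", flush=True)
#
# With delay=0.0 the chunks arrive immediately; the default 0.01s delay paces
# the Streamlit rerender for a smoother typing effect.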

def generate_stream(prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> str:
    try:
        # Safety checks
        if not model or not tokenizer:
            raise ValueError("Model or tokenizer not initialized")
            
        # Format prompt with safety checks
        safe_prompt = prompt.strip().replace("<", "&lt;").replace(">", "&gt;")
        chat_prompt = f"""### Human: {safe_prompt}

### Assistant: I'll help you with that."""

        # Persistent placeholder, updated in place as tokens stream in
        message_placeholder = st.empty()
        
        with torch.inference_mode(), st.spinner("Thinking..."):
            inputs = tokenizer(
                chat_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            )

            # Step-by-step generation: one model.generate call per new token so
            # the partial text can be streamed to the UI (simple but slow on CPU)
            generated_text = ""
            generated_ids = []
            progress_bar = st.progress(0)

            for i in range(512):  # hard cap on new tokens
                try:
                    # Re-append previously generated ids to the prompt each step
                    input_ids = inputs["input_ids"]
                    if generated_ids:
                        input_ids = torch.cat(
                            [input_ids, torch.tensor([generated_ids], device=model.device)],
                            dim=1,
                        )

                    outputs = model.generate(
                        input_ids,
                        attention_mask=torch.ones_like(input_ids),
                        max_new_tokens=1,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.95,
                        top_k=50,
                        repetition_penalty=1.1,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                    
                    next_token = outputs[0][-1].item()
                    generated_ids.append(next_token)
                    
                    # Update progress (i + 1 so the bar moves on the first token)
                    progress_bar.progress(min(1.0, (i + 1) / 512))
                    
                    # Decode and stream current output
                    current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                    
                    # Stream only the newly decoded suffix
                    for chunk in stream_tokens(current_text[len(generated_text):]):
                        generated_text += chunk
                        message_placeholder.markdown(generated_text)
                    
                    # Check stopping conditions
                    if (next_token == tokenizer.eos_token_id or 
                        "### Human:" in current_text or 
                        len(generated_ids) >= 512):
                        break
                        
                except torch.cuda.OutOfMemoryError:
                    # Inert on this CPU setup; kept as a safeguard for GPU runs
                    torch.cuda.empty_cache()
                    st.warning("Memory limit reached, truncating response...")
                    break
                    
            progress_bar.empty()
            
            # Strip any echoed role markers; generated_ids exclude the prompt,
            # so these splits are defensive no-ops in the normal case
            response = generated_text.split("### Assistant:")[-1].split("### Human:")[0].strip()
            if len(response) < 10:  # Minimum response length
                raise ValueError("Generated response too short")
                
            return response
            
    except Exception as e:
        logging.exception("Generation failed")
        st.error(f"Generation error: {e}")
        return "I apologize, but I couldn't generate a response. Please try again."