File size: 7,711 Bytes
078bd3c
9ebaef7
 
 
 
5374b45
c380681
9ebaef7
 
 
 
 
5374b45
c380681
 
73c98fc
c380681
9ebaef7
c380681
7f1f6f2
c380681
9ebaef7
 
 
 
 
 
 
c380681
9ebaef7
c380681
9ebaef7
c380681
 
9ebaef7
 
 
c380681
9ebaef7
 
c380681
 
9ebaef7
 
c380681
9ebaef7
 
 
 
 
c380681
 
 
7f1f6f2
9ebaef7
c380681
9ebaef7
7f1f6f2
9ebaef7
 
 
c380681
5881ab0
9ebaef7
c380681
9ebaef7
c380681
 
9ebaef7
 
c380681
9ebaef7
 
 
c380681
9ebaef7
 
c380681
 
9ebaef7
 
c380681
9ebaef7
 
c380681
9ebaef7
c380681
9ebaef7
 
c380681
9ebaef7
 
 
 
c380681
 
9ebaef7
 
c380681
 
 
9ebaef7
c380681
 
 
 
7f1f6f2
c380681
9ebaef7
 
 
c380681
9ebaef7
8665c7a
9ebaef7
 
7f1f6f2
c380681
9ebaef7
 
7f1f6f2
c380681
9ebaef7
 
c380681
 
49dc795
c380681
 
 
 
 
 
49dc795
c380681
 
 
 
49dc795
7f1f6f2
d9a1250
c380681
 
 
e916305
c380681
 
 
7f1f6f2
c380681
9ebaef7
c380681
 
 
 
 
 
 
 
9ebaef7
c380681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ebaef7
c380681
 
49dc795
5881ab0
7f1f6f2
c380681
 
 
 
 
9ebaef7
c380681
 
9ebaef7
c380681
9ebaef7
 
 
c380681
9ebaef7
49dc795
9ebaef7
 
 
 
c380681
9ebaef7
 
 
c380681
9ebaef7
d9a1250
c380681
d9a1250
9ebaef7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import os
import traceback
import time
from huggingface_hub import snapshot_download
import gradio as gr

# Attempt to import llama_cpp, if failed, prompt in the UI
try:
    from llama_cpp import Llama
except Exception as e:
    # Defer the failure instead of crashing at import time: Llama stays None
    # and the original exception is kept so load_model_from_hub can surface
    # it in a readable RuntimeError when the user actually tries to load.
    Llama = None
    Llama_import_error = e

# ---------- Configuration Area ----------
# ★★★ Please change this to your model repository ★★★
MODEL_REPO: str = "Marcus719/Llama-3.2-3B-changedata-Lab2-GGUF"
# Specify to download only the q4_k_m file to prevent running out of disk space
GGUF_FILENAME: str = "unsloth.Q4_K_M.gguf" 
DEFAULT_N_CTX: int = 2048  # Context length
DEFAULT_MAX_TOKENS: int = 256 # Default generation length
DEFAULT_N_THREADS: int = 2 # Recommended 2 for free CPU tier
# ------------------------------

def log(msg: str) -> None:
    """Emit a flushed, timestamped console line tagged with [app]."""
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[app] {stamp} - {msg}", flush=True)

def load_model_from_hub(repo_id: str, filename: str, n_ctx: int = DEFAULT_N_CTX, n_threads: int = DEFAULT_N_THREADS):
    """Download a single GGUF file from the Hub and load it with llama.cpp.

    Args:
        repo_id: Hugging Face repository id containing the GGUF file.
        filename: Exact GGUF file name inside the repository.
        n_ctx: Context window size passed to Llama.
        n_threads: CPU thread count passed to Llama.

    Returns:
        (llm, gguf_path): the loaded Llama instance and the local file path.

    Raises:
        RuntimeError: if llama-cpp-python failed to import at module load.
        FileNotFoundError: if the file is not present after the download.
    """
    if Llama is None:
        raise RuntimeError(f"llama-cpp-python not installed or failed to load: {Llama_import_error}")

    # BUG FIX: this log line and the FileNotFoundError below previously
    # printed the literal "(unknown)" instead of the filename being fetched.
    log(f"Starting model download: {repo_id} / {filename} ...")

    # allow_patterns restricts the snapshot to the single GGUF file so the
    # free-tier disk is not filled with every quantization in the repo.
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],
        local_dir_use_symlinks=False  # symlinks can be unstable on Spaces storage
    )

    # Direct join first; snapshot_download usually preserves the repo layout.
    gguf_path = os.path.join(local_dir, filename)

    # Robustness: fall back to walking the snapshot tree in case the file
    # landed in a subdirectory.
    if not os.path.exists(gguf_path):
        for root, _dirs, files in os.walk(local_dir):
            if filename in files:
                gguf_path = os.path.join(root, filename)
                break
        if not os.path.exists(gguf_path):
            raise FileNotFoundError(f"Could not find {filename} in {local_dir}")

    log(f"Model path: {gguf_path}. Loading into memory...")

    # Initialize the model (verbose=False keeps llama.cpp banner noise out of logs)
    llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
    log("Llama model loaded successfully!")
    return llm, gguf_path

def init_model(state):
    """Callback for the "Load Model" button: ensure the model lives in state.

    Idempotent — if state["llm"] is already set, returns immediately.
    Errors are logged and recorded in state["status"] rather than raised,
    so the Gradio UI keeps working.

    Args:
        state: mutable session dict with keys "llm", "gguf_path", "status".

    Returns:
        The (possibly updated) state dict.
    """
    try:
        # Already loaded — nothing to do.
        if state.get("llm") is not None:
            return state

        log("Received load request...")
        # Download and load
        llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)

        # Update state
        state["llm"] = llm
        state["gguf_path"] = gguf_path
        # FIX: the "status" key is initialized in gr.State but was never
        # updated; keep it in sync so it can be surfaced in the UI.
        state["status"] = "Loaded"

        return state
    except Exception as exc:
        tb = traceback.format_exc()
        log(f"Initialization Error: {exc}\n{tb}")
        state["status"] = f"Error: {exc}"
        return state

def generate_response(prompt: str, max_tokens: int, state):
    """Callback for the "Generate" button: run one completion on the prompt.

    Lazily loads the model into state when it is missing. Returns a
    (response_text, state) pair; failures are reported as the response text
    instead of being raised, so the Gradio UI stays responsive.
    """
    try:
        # Guard clause: reject blank input before touching the model.
        if not prompt or not prompt.strip():
            return "Please enter an instruction.", state

        # Lazy loading: auto-load if Generate is clicked before Load Model.
        if state.get("llm") is None:
            try:
                log("Model not detected, attempting auto-load...")
                model, path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                state["llm"] = model
                state["gguf_path"] = path
            except Exception as e:
                return f"Model Load Failed: {e}", state

        llm = state.get("llm")

        log(f"Generating (Prompt Length={len(prompt)})...")

        # Assemble a Llama-3-style chat prompt by plain string concatenation.
        # (A strict implementation would use tokenizer.apply_chat_template;
        # Llama 3 usually understands this simpler form.)
        system_prompt = "You are a helpful AI assistant."
        full_prompt = (
            "<|start_header_id|>system<|end_header_id|>\n\n"
            f"{system_prompt}<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\n"
            f"{prompt}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        # Inference — stop at the end-of-turn marker, don't echo the prompt.
        output = llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],
            echo=False
        )

        text = output['choices'][0]['text']
        log("Generation complete.")
        return text, state
    except Exception as exc:
        tb = traceback.format_exc()
        log(f"Generation Error: {exc}\n{tb}")
        return f"Runtime Error: {exc}", state

def soft_clear(current_state):
    """Handler for the Clear button: blank the prompt, keep the model loaded."""
    cleared_text = ""
    return cleared_text, current_state

# ---------------- Gradio UI Construction ----------------
# Theme settings
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate")

# Custom CSS
custom_css = """.footer-text { font-size: 0.8em; color: gray; text-align: center; }"""

# FIX: `theme` and `custom_css` were defined above but never passed to
# gr.Blocks, so the Soft theme and the .footer-text styling were silently
# unused. Wire them in here.
with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:
    
    # Header
    # NOTE(review): the header says "1B" but MODEL_REPO points at a 3B repo —
    # confirm which model size is actually deployed.
    with gr.Row():
        gr.Markdown("# Llama 3.2 (1B) Fine-Tuned Chatbot")
        gr.Markdown(
            f"""
            **ID2223 Lab 2 Project** | Fine-tuned on **UltraChat-200k-Filtered(only use 100k)**.
            Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
            """
        )

    # Main layout: inputs/controls on the left, response on the right
    with gr.Row():
        # Left: Input and Controls
        with gr.Column(scale=4):
            with gr.Group():
                prompt_in = gr.Textbox(
                    lines=5, 
                    label="User Instruction (User Input)", 
                    placeholder="e.g., Explain Quantum Mechanics...",
                    elem_id="prompt-input"
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    max_tokens = gr.Slider(
                        minimum=16, 
                        maximum=1024, 
                        step=16, 
                        value=DEFAULT_MAX_TOKENS, 
                        label="Max Generation Length (Max Tokens)",
                        info="Longer generations will take more CPU time."
                    )

                with gr.Row():
                    init_btn = gr.Button("1. Load Model", variant="secondary")
                    gen_btn = gr.Button("2. Generate Response", variant="primary")

                clear_btn = gr.Button("Clear Chat", variant="stop")

        # Right: Output Display
        with gr.Column(scale=6):
            output_txt = gr.Textbox(
                label="Model Response (Response)", 
                lines=15, 
            )

    # Footer
    with gr.Row():
        gr.Markdown(
            "*Note: Inference runs on a free CPU, so speed may be slow. The model (approx. 2GB) must be downloaded on first run, please be patient.*",
            elem_classes=["footer-text"]
        )

    # Per-session state: the loaded Llama instance, its path, and a status tag
    state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})

    # Event binding
    init_btn.click(
        fn=init_model, 
        inputs=state, 
        outputs=[state],
        show_progress=True
    )
    
    gen_btn.click(
        fn=generate_response, 
        inputs=[prompt_in, max_tokens, state], 
        outputs=[output_txt, state],
        show_progress=True
    )
    
    # Clear wipes both the prompt and the previous response; model stays loaded.
    clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, state])
    clear_btn.click(lambda: "", outputs=[output_txt])

# Launch the application (0.0.0.0:7860 is the standard Spaces binding)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)