# ID2223_Lab2 / app.py
# Author: Marcus719 (Hugging Face Space, commit e916305 "Update app.py")
import os
import traceback
import time
from huggingface_hub import snapshot_download
import gradio as gr
# Attempt to import llama_cpp; if it fails, keep a sentinel so the failure can
# be reported in the UI on first use instead of crashing at import time.
try:
    from llama_cpp import Llama
except Exception as e:
    # Llama stays None so later code can detect the failed import; the original
    # exception is preserved for the error message raised in load_model_from_hub.
    Llama = None
    Llama_import_error = e
# ---------- Configuration Area ----------
# ★★★ Please change this to your model repository ★★★
MODEL_REPO = "Marcus719/Llama-3.2-3B-changedata-Lab2-GGUF"
# Specify to download only the q4_k_m file to prevent running out of disk space
GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
DEFAULT_N_CTX = 2048  # Context length (tokens) passed to llama.cpp
DEFAULT_MAX_TOKENS = 256  # Default generation length shown in the UI slider
DEFAULT_N_THREADS = 2  # Recommended 2 for the free CPU tier on Spaces
# ------------------------------
def log(msg: str):
    """Print a timestamped, immediately-flushed log line prefixed with [app]."""
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[app] {stamp} - {msg}", flush=True)
def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
    """Download a single GGUF file from the Hub and load it with llama.cpp.

    Args:
        repo_id: Hugging Face repository id containing the GGUF file.
        filename: Exact name of the GGUF file to download.
        n_ctx: Context window size passed to Llama.
        n_threads: Number of CPU threads passed to Llama.

    Returns:
        Tuple (llm, gguf_path): the loaded Llama instance and the local file path.

    Raises:
        RuntimeError: if llama-cpp-python failed to import.
        FileNotFoundError: if the GGUF file is not found after download.
    """
    if Llama is None:
        raise RuntimeError(f"llama-cpp-python not installed or failed to load: {Llama_import_error}")
    # BUG FIX: the original code logged the literal text "(unknown)" here and in
    # the FileNotFoundError below instead of interpolating `filename`.
    log(f"Starting model download: {repo_id} / {filename} ...")
    # Use snapshot_download to download a single file;
    # allow_patterns ensures only the GGUF file is downloaded (saves disk space).
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],
        local_dir_use_symlinks=False  # Disabling symlinks for stability in Spaces
    )
    # snapshot_download usually preserves directory structure; fall back to a
    # recursive search if the direct path does not exist (for robustness).
    gguf_path = os.path.join(local_dir, filename)
    if not os.path.exists(gguf_path):
        for root, dirs, files in os.walk(local_dir):
            if filename in files:
                gguf_path = os.path.join(root, filename)
                break
    if not os.path.exists(gguf_path):
        raise FileNotFoundError(f"Could not find {filename} in {local_dir}")
    log(f"Model path: {gguf_path}. Loading into memory...")
    # Initialize the model (verbose=False keeps llama.cpp's banner out of logs).
    llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
    log("Llama model loaded successfully!")
    return llm, gguf_path
def init_model(state):
    """Callback for the Load button.

    Downloads and loads the model once, caching the Llama instance and its
    local path in the session state dict. Errors are logged, never raised,
    so the UI callback always returns a valid state.
    """
    try:
        # Idempotent: a second click with a loaded model is a no-op.
        if state.get("llm") is not None:
            return state
        log("Received load request...")
        # Download and load
        model, model_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
        # Update state
        state["llm"] = model
        state["gguf_path"] = model_path
    except Exception as exc:
        log(f"Initialization Error: {exc}\n{traceback.format_exc()}")
    return state
def generate_response(prompt: str, max_tokens: int, state):
    """Callback for the Generate button.

    Runs the (lazily loaded) model on the user prompt and returns
    (response_text, state). Failures are returned as text so the UI
    never sees an exception.
    """
    try:
        # Guard: empty or whitespace-only instruction.
        if not prompt or not prompt.strip():
            return "Please enter an instruction.", state
        # Lazy loading: auto-load on first Generate click if the Load button
        # was never pressed.
        if state.get("llm") is None:
            try:
                log("Model not detected, attempting auto-load...")
                model, model_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                state["llm"] = model
                state["gguf_path"] = model_path
            except Exception as e:
                return f"Model Load Failed: {e}", state
        llm = state.get("llm")
        log(f"Generating (Prompt Length={len(prompt)})...")
        # Build a Llama 3 chat-format prompt via plain string concatenation
        # (strict formatting would use tokenizer.apply_chat_template, but
        # Llama 3 usually understands this simple System + User layout).
        system_prompt = "You are a helpful AI assistant."
        full_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # Inference; generation stops at the Llama 3 end-of-turn token.
        result = llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],
            echo=False
        )
        text = result['choices'][0]['text']
        log("Generation complete.")
        return text, state
    except Exception as exc:
        log(f"Generation Error: {exc}\n{traceback.format_exc()}")
        return f"Runtime Error: {exc}", state
def soft_clear(current_state):
    """Clear button handler: blank the prompt textbox, keep the model loaded."""
    cleared_prompt = ""
    return cleared_prompt, current_state
# ---------------- Gradio UI Construction ----------------
# Theme settings
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate")
# Custom CSS for the footer note
custom_css = """.footer-text { font-size: 0.8em; color: gray; text-align: center; }"""
# BUG FIX: `theme` and `custom_css` were defined but never passed to gr.Blocks,
# so the Soft theme and the .footer-text class (used by the footer Markdown
# below) never took effect.
with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:
    # Header
    with gr.Row():
        gr.Markdown("# Llama 3.2 (1B) Fine-Tuned Chatbot")
        gr.Markdown(
            f"""
            **ID2223 Lab 2 Project** | Fine-tuned on **UltraChat-200k-Filtered(only use 100k)**.
            Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
            """
        )
    # Main layout
    with gr.Row():
        # Left: Input and Controls
        with gr.Column(scale=4):
            with gr.Group():
                prompt_in = gr.Textbox(
                    lines=5,
                    label="User Instruction (User Input)",
                    placeholder="e.g., Explain Quantum Mechanics...",
                    elem_id="prompt-input"
                )
            with gr.Accordion("Advanced Parameters", open=False):
                max_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    step=16,
                    value=DEFAULT_MAX_TOKENS,
                    label="Max Generation Length (Max Tokens)",
                    info="Longer generations will take more CPU time."
                )
            with gr.Row():
                init_btn = gr.Button("1. Load Model", variant="secondary")
                gen_btn = gr.Button("2. Generate Response", variant="primary")
                clear_btn = gr.Button("Clear Chat", variant="stop")
        # Right: Output Display
        with gr.Column(scale=6):
            output_txt = gr.Textbox(
                label="Model Response (Response)",
                lines=15,
            )
    # Footer
    with gr.Row():
        gr.Markdown(
            "*Note: Inference runs on a free CPU, so speed may be slow. The model (approx. 2GB) must be downloaded on first run, please be patient.*",
            elem_classes=["footer-text"]
        )
    # Per-session state: holds the loaded Llama instance and its local path.
    # NOTE(review): gr.State is per-session, so each browser session loads its
    # own model copy — acceptable for a lab demo, heavy for multi-user use.
    state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})
    # Event binding
    init_btn.click(
        fn=init_model,
        inputs=state,
        outputs=[state],
        show_progress=True
    )
    gen_btn.click(
        fn=generate_response,
        inputs=[prompt_in, max_tokens, state],
        outputs=[output_txt, state],
        show_progress=True
    )
    clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, state])
    # Second handler on the same button: also blank the output textbox.
    clear_btn.click(lambda: "", outputs=[output_txt])
# Launch the application
if __name__ == "__main__":
    # Bind on all interfaces at port 7860, the standard port Hugging Face
    # Spaces routes external traffic to.
    demo.launch(server_name="0.0.0.0", server_port=7860)