# NOTE: "Spaces: Sleeping" status banner was captured along with the page when
# this file was saved from the Hugging Face Spaces UI; it is not program source.
import os
import traceback
import time
from huggingface_hub import snapshot_download
import gradio as gr

# Attempt to import llama_cpp; if it fails, record the error so the UI can
# report it instead of crashing at import time (llama-cpp-python is a compiled
# wheel and may be missing on some Space images).
try:
    from llama_cpp import Llama
except Exception as e:
    # Llama stays None; load_model_from_hub() re-raises with this saved error.
    Llama = None
    Llama_import_error = e
# ---------- Configuration Area ----------
# ★★★ Please change this to your model repository ★★★
MODEL_REPO = "Marcus719/Llama-3.2-3B-changedata-Lab2-GGUF"
# Specify to download only the q4_k_m file to prevent running out of disk space
GGUF_FILENAME = "unsloth.Q4_K_M.gguf"
DEFAULT_N_CTX = 2048       # Context window (tokens) passed to llama.cpp
DEFAULT_MAX_TOKENS = 256   # Default generation length
DEFAULT_N_THREADS = 2      # Recommended 2 for free CPU tier
# ------------------------------
def log(msg: str) -> None:
    """Print *msg* to stdout tagged ``[app]`` with a timestamp, flushing immediately.

    Flushing matters on Spaces: container logs are line-buffered and would
    otherwise lag behind long-running downloads.
    """
    stamp = time.strftime('%Y-%m-%d %H:%M:%S')
    print(f"[app] {stamp} - {msg}", flush=True)
def load_model_from_hub(repo_id: str, filename: str, n_ctx=DEFAULT_N_CTX, n_threads=DEFAULT_N_THREADS):
    """Download a single GGUF file from the Hub and load it with llama.cpp.

    Args:
        repo_id: Hugging Face repository id holding the GGUF file.
        filename: Exact GGUF file name to fetch (only this file is downloaded).
        n_ctx: Context window size passed to ``Llama``.
        n_threads: CPU threads used for inference.

    Returns:
        ``(llm, gguf_path)`` — the loaded ``Llama`` instance and local file path.

    Raises:
        RuntimeError: if llama-cpp-python failed to import.
        FileNotFoundError: if the file is missing after the snapshot download.
    """
    if Llama is None:
        raise RuntimeError(f"llama-cpp-python not installed or failed to load: {Llama_import_error}")
    # BUG FIX: the log and error messages below printed the literal text
    # "(unknown)" instead of the requested filename; they now interpolate it.
    log(f"Starting model download: {repo_id} / {filename} ...")
    # allow_patterns limits the snapshot to the single GGUF file so the free
    # Space does not run out of disk space.
    local_dir = snapshot_download(
        repo_id=repo_id,
        allow_patterns=[filename],
        local_dir_use_symlinks=False,  # symlinks are flaky on Spaces storage
    )
    # snapshot_download usually preserves the repo layout, so try the direct path first.
    gguf_path = os.path.join(local_dir, filename)
    if not os.path.exists(gguf_path):
        # Fall back to walking the snapshot in case the file sits in a subdirectory.
        for root, _dirs, files in os.walk(local_dir):
            if filename in files:
                gguf_path = os.path.join(root, filename)
                break
    if not os.path.exists(gguf_path):
        raise FileNotFoundError(f"Could not find {filename} in {local_dir}")
    log(f"Model path: {gguf_path}. Loading into memory...")
    llm = Llama(model_path=gguf_path, n_ctx=n_ctx, n_threads=n_threads, verbose=False)
    log("Llama model loaded successfully!")
    return llm, gguf_path
def init_model(state):
    """Callback for the "1. Load Model" button.

    Downloads and loads the GGUF model on first use and stores it in the
    shared Gradio state dict; a repeated click is a no-op.

    Improvement: failures were previously only logged, leaving the state's
    ``status`` field (initialised to "Not initialized") forever stale. Both
    outcomes are now recorded in ``state["status"]`` so the UI/state reflects
    reality, while errors are still swallowed to keep the app alive.
    """
    try:
        if state.get("llm") is not None:
            # Model already resident — nothing to do.
            return state
        log("Received load request...")
        llm, gguf_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
        state["llm"] = llm
        state["gguf_path"] = gguf_path
        state["status"] = "Loaded"
        return state
    except Exception as exc:
        # Keep the UI responsive; surface the failure via logs and state.
        tb = traceback.format_exc()
        log(f"Initialization Error: {exc}\n{tb}")
        state["status"] = f"Load failed: {exc}"
        return state
def generate_response(prompt: str, max_tokens: int, state):
    """Callback for the "2. Generate Response" button.

    Wraps *prompt* in the Llama 3 chat template, runs CPU inference, and
    returns ``(response_text, state)``. If the model was never explicitly
    loaded, it is fetched lazily on the first generation request.
    """
    try:
        if not prompt or prompt.strip() == "":
            return "Please enter an instruction.", state

        # Lazy loading: pull the model on demand when "Load Model" was skipped.
        if state.get("llm") is None:
            try:
                log("Model not detected, attempting auto-load...")
                model, model_path = load_model_from_hub(MODEL_REPO, GGUF_FILENAME)
                state["llm"] = model
                state["gguf_path"] = model_path
            except Exception as e:
                return f"Model Load Failed: {e}", state

        llm = state.get("llm")
        log(f"Generating (Prompt Length={len(prompt)})...")

        # Hand-rolled Llama 3 chat template (system turn + user turn).
        # tokenizer.apply_chat_template would be stricter, but plain text
        # concatenation is portable and Llama 3 usually copes fine.
        system_prompt = "You are a helpful AI assistant."
        full_prompt = (
            f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        completion = llm(
            full_prompt,
            max_tokens=max_tokens,
            stop=["<|eot_id|>"],  # halt at the end-of-turn marker
            echo=False,
        )
        log("Generation complete.")
        return completion['choices'][0]['text'], state
    except Exception as exc:
        tb = traceback.format_exc()
        log(f"Generation Error: {exc}\n{tb}")
        return f"Runtime Error: {exc}", state
def soft_clear(current_state):
    """Clear button handler: blank the prompt textbox but keep the loaded model."""
    cleared_text = ""
    return cleared_text, current_state
# ---------------- Gradio UI Construction ----------------
# Theme settings
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
)
# Custom CSS for the footer note (applied via elem_classes below)
custom_css = """.footer-text { font-size: 0.8em; color: gray; text-align: center; }"""

# BUG FIX: `theme` and `custom_css` were built but never passed to gr.Blocks,
# so the Soft theme and the .footer-text rule had no effect. They are now
# wired in via the `theme=` and `css=` arguments.
with gr.Blocks(title="Llama 3.2 Lab2 Project", theme=theme, css=custom_css) as demo:
    # Header
    with gr.Row():
        gr.Markdown("# Llama 3.2 (1B) Fine-Tuned Chatbot")
        gr.Markdown(
            f"""
**ID2223 Lab 2 Project** | Fine-tuned on **UltraChat-200k-Filtered(only use 100k)**.
Running on CPU (GGUF 4-bit) | Model: `{MODEL_REPO}`
"""
        )
    # Main layout
    with gr.Row():
        # Left: input and controls
        with gr.Column(scale=4):
            with gr.Group():
                prompt_in = gr.Textbox(
                    lines=5,
                    label="User Instruction (User Input)",
                    placeholder="e.g., Explain Quantum Mechanics...",
                    elem_id="prompt-input",
                )
            with gr.Accordion("Advanced Parameters", open=False):
                max_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    step=16,
                    value=DEFAULT_MAX_TOKENS,
                    label="Max Generation Length (Max Tokens)",
                    info="Longer generations will take more CPU time.",
                )
            with gr.Row():
                init_btn = gr.Button("1. Load Model", variant="secondary")
                gen_btn = gr.Button("2. Generate Response", variant="primary")
                clear_btn = gr.Button("Clear Chat", variant="stop")
        # Right: output display
        with gr.Column(scale=6):
            output_txt = gr.Textbox(
                label="Model Response (Response)",
                lines=15,
            )
    # Footer
    with gr.Row():
        gr.Markdown(
            "*Note: Inference runs on a free CPU, so speed may be slow. The model (approx. 2GB) must be downloaded on first run, please be patient.*",
            elem_classes=["footer-text"],
        )
    # Per-session state: holds the loaded Llama instance between callbacks.
    state = gr.State({"llm": None, "gguf_path": None, "status": "Not initialized"})
    # Event binding
    init_btn.click(
        fn=init_model,
        inputs=[state],
        outputs=[state],
        show_progress=True,
    )
    gen_btn.click(
        fn=generate_response,
        inputs=[prompt_in, max_tokens, state],
        outputs=[output_txt, state],
        show_progress=True,
    )
    # Two handlers on one click: clear the prompt (keeping the model) and
    # blank the response box.
    clear_btn.click(fn=soft_clear, inputs=[state], outputs=[prompt_in, state])
    clear_btn.click(lambda: "", outputs=[output_txt])
# Launch the application
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the address Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)