Spaces:
Runtime error
Runtime error
# -------------------------------------------------
# app.py – a Gradio demo that works on CPU only
# -------------------------------------------------
import os
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)

# ------------------------------------------------------------------
# 1️⃣ Choose your model
# ------------------------------------------------------------------
REPO_ID = "EleutherAI/gpt-neo-1.3B"  # <-- change this to any model below

# ------------------------------------------------------------------
# 2️⃣ Choose quantisation mode
#   None   → full precision (more RAM, no quantisation)
#   "4bit" → 4-bit inference via bitsandbytes (recommended for >1B)
#   "8bit" → 8-bit inference via bitsandbytes (good for 6-12B)
#
# NOTE(review): bitsandbytes 4-/8-bit kernels require a CUDA GPU —
# requesting them on a CPU-only machine raises at model-load time.
# ------------------------------------------------------------------
QUANT = "4bit"  # set to "8bit" if you pick a 6-12B model, else None
def load_model():
    """
    Build the text-generation pipeline for ``REPO_ID``.

    Returns:
        transformers.Pipeline: a ready-to-call ``text-generation`` pipeline.

    Quantisation (``QUANT`` = "4bit" / "8bit") is applied only when a CUDA
    device is available: bitsandbytes kernels do not run on CPU, so on a
    CPU-only host we fall back to full-precision float32 (fp16 is also
    unsupported or very slow for many CPU ops).
    """
    # ----- tokenizer -------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID, use_fast=True)

    has_cuda = torch.cuda.is_available()

    # ----- model loading options -------------------------------------
    model_kwargs = {
        # fp16 only makes sense on GPU; CPU inference is safest in fp32.
        "torch_dtype": torch.float16 if has_cuda else torch.float32,
        # Let accelerate place layers on GPU when present, else pin to CPU.
        "device_map": "auto" if has_cuda else "cpu",
    }
    if has_cuda and QUANT == "4bit":
        # 4-bit NF4 quantisation via bitsandbytes – GPU only.
        model_kwargs.update(
            {
                "load_in_4bit": True,
                "bnb_4bit_compute_dtype": torch.float16,
                "bnb_4bit_use_double_quant": True,
                "bnb_4bit_quant_type": "nf4",
            }
        )
    elif has_cuda and QUANT == "8bit":
        # 8-bit inference – less memory-efficient than 4-bit but
        # sometimes faster on older GPUs.
        model_kwargs["load_in_8bit"] = True
    elif QUANT in ("4bit", "8bit"):
        # Quantisation requested but impossible here: warn and continue
        # in full precision instead of crashing the Space at startup.
        print(
            f"QUANT={QUANT!r} requested but no CUDA device is available – "
            "loading the model in full precision instead."
        )

    # ----- actual model ----------------------------------------------
    model = AutoModelForCausalLM.from_pretrained(REPO_ID, **model_kwargs)

    # ----- pipeline --------------------------------------------------
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Generation stays on whatever device the model was placed on.
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
    )
    return generator


# ------------------------------------------------------------------
generator = load_model()
# ------------------------------------------------------------------
def _build_prompt(user_msg, history):
    """Flatten the conversation into one 'User:/Assistant:' prompt string."""
    turns = [f"User: {u}\nAssistant: {b}\n" for u, b in history]
    turns.append(f"User: {user_msg}\nAssistant:")
    return "".join(turns)


def _extract_answer(generated_text):
    """Slice the assistant's reply out of the full generated text."""
    # The pipeline returns prompt + continuation; keep only the text
    # after the last "Assistant:" marker.
    answer = generated_text.split("Assistant:")[-1].strip()
    # If the model starts hallucinating the next user turn, cut it off.
    return answer.split("\nUser:")[0].strip()


def chat(user_msg, history):
    """
    Gradio ChatInterface callback.

    Args:
        user_msg: the new user message.
        history:  list of (user, bot) tuples managed by ``gr.ChatInterface``.

    Returns:
        str: the assistant's reply.  ``gr.ChatInterface`` expects the bare
        response string; the previous ``return "", history`` is the pattern
        for a manual Blocks UI and corrupts the chat display here.
    """
    prompt = _build_prompt(user_msg, history)
    raw = generator(prompt)[0]["generated_text"]
    return _extract_answer(raw)
# ------------------------------------------------------------------
# Human-readable tag for the quantisation mode shown in the UI header.
_QUANT_LABEL = {"4bit": "(4‑bit)", "8bit": "(8‑bit)"}.get(QUANT, "(fp16)")

demo = gr.ChatInterface(
    fn=chat,
    title="🤗 CPU‑only LLM Demo",
    description=f"Model: **{REPO_ID}** {_QUANT_LABEL}",
    theme="default",
)

if __name__ == "__main__":
    # On a Hugging Face Space, Gradio picks the right host/port itself.
    demo.launch()