# app.py — ZeroGPU-optimised Gradio app (HF Spaces) — refined
# (Copied from the Hugging Face Spaces file viewer; the Space was shown as "Sleeping" at the time.)
import importlib.util
import os
import tempfile
from datetime import datetime, timezone

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---- Small env tweak: faster hub downloads when available ----
# Opt in only when the hf_transfer package is actually installed: with the
# flag set and the package missing, huggingface_hub raises on download instead
# of falling back. setdefault keeps any value already configured for the Space.
# NOTE(review): huggingface_hub reads this variable when it is first imported,
# and transformers (imported above) may already have imported it -- if the
# speed-up matters, set the variable in the Space settings instead. Confirm.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# ---- ZeroGPU decorator ----
# On HF Spaces the real `spaces` package provides @spaces.GPU, which asks
# ZeroGPU to attach a GPU for each decorated call. Elsewhere a no-op stands in.
try:
    import spaces  # HF Spaces utility (provides @spaces.GPU)
except Exception:
    class _Noop:
        """Stand-in for the ``spaces`` module when it cannot be imported.

        ``GPU`` accepts every decorator spelling used with ZeroGPU --
        ``@spaces.GPU``, ``@spaces.GPU()`` and ``@spaces.GPU(duration=...)`` --
        and returns the decorated function unchanged.
        """

        def GPU(self, *args, **kwargs):
            # Bare form: @spaces.GPU hands us the function itself.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def deco(fn):
                return fn

            return deco

    spaces = _Noop()
# ---- Optional quantisation (GPU only) ----
# BitsAndBytesConfig ships with transformers itself, so importing it succeeds
# even when the bitsandbytes backend is not installed. Also require the
# backend; otherwise a 4-bit from_pretrained(...) would fail at load time.
try:
    from transformers import BitsAndBytesConfig
except Exception:
    HAS_BNB = False
else:
    HAS_BNB = importlib.util.find_spec("bitsandbytes") is not None
| # ---- Optional Flash-Attention 2 ---- | |
| _HAS_FLASH = False | |
| try: | |
| import flash_attn # noqa: F401 | |
| _HAS_FLASH = True | |
| except Exception: | |
| _HAS_FLASH = False | |
# ----------------------------
# Config
# ----------------------------
# Models offered in the UI dropdown; the first entry is the default selection.
DEFAULT_MODELS = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "neovalle/tinyllama-1.1B-h4rmony-trained",
]

# Prompts per generate() call. Kept small for latency on ZeroGPU; the GPU
# value is a little higher because these are ~1B-parameter models.
MICROBATCH_CPU = 2
MICROBATCH_GPU = 6

# Prompts longer than this many tokens are truncated before generation.
MAX_INPUT_TOKENS = 1_024
# Upper bound applied to the user's max_new_tokens request (extra guardrail).
MAX_NEW_TOKENS_HARD_CAP = 1_024
# Process-wide PyTorch tuning, chosen once at import time.
if not torch.cuda.is_available():
    # CPU-only host: use half the logical cores (at least one; assume 4 cores
    # when the count is unknown) for steadier, more predictable latency.
    try:
        torch.set_num_threads(max(1, (os.cpu_count() or 4) // 2))
    except Exception:
        pass
else:
    # GPU host: allow TF32 for matmuls and cuDNN (extra throughput on Ampere+).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        # Let PyTorch choose faster float32 matmul kernels where permitted.
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
| _MODEL_CACHE = {} # cache: model_id -> (tokenizer, model) | |
| # ---------------------------- | |
| # Helpers | |
| # ---------------------------- | |
| def _all_eos_ids(tok): | |
| """Collect a few likely EOS ids so generation can stop earlier.""" | |
| ids = set() | |
| if tok.eos_token_id is not None: | |
| ids.add(tok.eos_token_id) | |
| for t in ("<|im_end|>", "<|endoftext|>", "</s>"): | |
| try: | |
| tid = tok.convert_tokens_to_ids(t) | |
| if isinstance(tid, int) and tid >= 0: | |
| ids.add(tid) | |
| except Exception: | |
| pass | |
| return list(ids) if ids else None | |
def _load_model(model_id: str):
    """Load (and memoise) the tokenizer/model pair for ``model_id``.

    On GPU the weights use BF16 when the device supports it and FP16 otherwise.
    They are quantised to 4-bit NF4 (double-quant) when ``HAS_BNB`` is set, and
    FlashAttention-2 is requested when ``flash_attn`` imported. On CPU the model
    is loaded unquantised in FP32.

    The cache is keyed by ``model_id`` only, so the device/dtype in effect at
    the first load is reused for the rest of the process.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode.
    """
    if model_id in _MODEL_CACHE:
        return _MODEL_CACHE[model_id]

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    # Batched generation needs a pad token; reuse EOS when there is one.
    if tok.pad_token is None:
        if tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        else:
            tok.add_special_tokens({"pad_token": "<|pad|>"})

    # Left padding keeps every prompt flush against the tokens generate()
    # appends. Left truncation makes an over-long prompt lose its oldest text
    # rather than its end; right truncation would cut off the user turn and the
    # assistant header the model is supposed to continue from.
    try:
        tok.padding_side = "left"
        tok.truncation_side = "left"
    except Exception:
        pass

    use_gpu = torch.cuda.is_available()
    bf16_ok = bool(use_gpu and getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    gpu_dtype = torch.bfloat16 if bf16_ok else torch.float16
    dtype = gpu_dtype if use_gpu else torch.float32

    quant_cfg = None
    if use_gpu and HAS_BNB:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=gpu_dtype,
        )

    # FlashAttention-2 kernels are CUDA-only; requesting them for a CPU load
    # makes from_pretrained fail, so only ask for them on GPU.
    attn_impl = "flash_attention_2" if (_HAS_FLASH and use_gpu) else None

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,  # honours bf16_ok: FP16 on GPUs without BF16 support
        low_cpu_mem_usage=True,
        device_map="auto",
        quantization_config=quant_cfg,  # 4-bit on GPU if available; None otherwise
        # NOTE(review): trust_remote_code=True executes Python shipped inside
        # the model repo. Chat templates do not need it; keep it only for
        # model ids you trust (the UI offers the fixed DEFAULT_MODELS list).
        trust_remote_code=True,
        attn_implementation=attn_impl,  # None lets transformers pick its default
    ).eval()

    # Give a newly added pad token an embedding row. Only ever grow: many
    # checkpoints ship more embedding rows than len(tok), and shrinking would
    # drop rows and change the output vocabulary size.
    try:
        if len(tok) > model.get_input_embeddings().num_embeddings:
            model.resize_token_embeddings(len(tok))
    except Exception:
        pass  # best-effort: e.g. quantised weights may not be resizable

    # Prefer the KV cache during generation.
    try:
        model.generation_config.use_cache = True
    except Exception:
        pass

    _MODEL_CACHE[model_id] = (tok, model)
    return tok, model
| def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str: | |
| sys = (system_prompt or "").strip() | |
| usr = (user_prompt or "").strip() | |
| if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None): | |
| messages = [] | |
| if sys: | |
| messages.append({"role": "system", "content": sys}) | |
| messages.append({"role": "user", "content": usr}) | |
| return tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| prefix = f"<<SYS>>\n{sys}\n<</SYS>>\n\n" if sys else "" | |
| return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n" | |
def _generate_microbatch(tok, model, formatted_prompts, gen_kwargs):
    """Generate completions for one micro-batch of already-formatted prompts.

    Args:
        tok: tokenizer with a pad token. It is expected to pad on the left, as
            ``_load_model`` configures it.
        model: causal LM exposing ``.device`` and ``.generate``.
        formatted_prompts: non-empty list of prompt strings.
        gen_kwargs: extra keyword arguments forwarded to ``model.generate``.

    Returns:
        ``(texts, tokens_out)``: the decoded completions (special tokens
        skipped, whitespace stripped) and, per prompt, the number of generated
        tokens. The count excludes trailing padding, and therefore also the
        final EOS when pad and EOS share an id.
    """
    eos_ids = _all_eos_ids(tok)

    # Chat templates frequently emit the BOS token themselves; letting the
    # tokenizer add another would hand the model a doubled BOS.
    bos = getattr(tok, "bos_token", None)
    already_has_bos = bool(bos) and all(p.startswith(bos) for p in formatted_prompts)

    enc = tok(
        formatted_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_token_type_ids=False,
        add_special_tokens=not already_has_bos,
    ).to(model.device)

    outputs = model.generate(
        **enc,
        eos_token_id=eos_ids,
        pad_token_id=tok.pad_token_id,
        **gen_kwargs,
    )

    # generate() returns the padded prompt followed by the continuation, so the
    # new tokens start at the padded input width for EVERY row. Slicing at each
    # row's unpadded length would leak prompt tokens into shorter rows.
    prompt_width = enc["input_ids"].shape[1]
    pad_id = tok.pad_token_id

    texts, toks_out = [], []
    for row in outputs:
        gen_ids = row[prompt_width:]
        # clean_up_tokenization_spaces=False skips an extra whitespace pass.
        texts.append(tok.decode(gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False).strip())
        # Rows that finish early are right-filled with pad_token_id; don't count those.
        n_generated = gen_ids.numel() if pad_id is None else (gen_ids != pad_id).sum()
        toks_out.append(int(n_generated))
    return texts, toks_out
def generate_batch_df(
    model_id: str,
    system_prompt: str,
    prompts_multiline: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> pd.DataFrame:
    """Run every non-blank line of ``prompts_multiline`` through ``model_id``.

    Each line becomes one user turn paired with ``system_prompt``. Prompts are
    generated in micro-batches of ``MICROBATCH_GPU`` on CUDA and
    ``MICROBATCH_CPU`` otherwise.

    Args:
        model_id: Hub id passed to ``_load_model``.
        system_prompt: shared system message (may be empty).
        prompts_multiline: newline-separated user prompts; blank lines are skipped.
        max_new_tokens: clamped to ``[1, MAX_NEW_TOKENS_HARD_CAP]``.
        temperature: ``0`` selects greedy decoding; ``> 0`` enables sampling.
        top_p: nucleus-sampling threshold (sampling only).
        top_k: top-k cutoff (sampling only); ``0`` disables top-k filtering.
        repetition_penalty: forwarded to ``generate``.

    Returns:
        DataFrame with columns ``user_prompt``, ``response`` and
        ``tokens_out``, one row per prompt in input order. With no prompts it
        holds a single row of empty values.
    """
    prompts = [p.strip() for p in (prompts_multiline or "").splitlines() if p.strip()]
    if not prompts:
        # Nothing to do: don't load (possibly download) a model for an empty row.
        return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

    tok, model = _load_model(model_id)
    formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]

    # Adaptive micro-batch for latency: smaller on CPU, a bit larger on GPU.
    batch_size = min(len(formatted), MICROBATCH_GPU if torch.cuda.is_available() else MICROBATCH_CPU)

    # Clamp new tokens (defensive).
    max_new_tokens = max(1, min(int(max_new_tokens), MAX_NEW_TOKENS_HARD_CAP))

    # Greedy is fastest; only enable the sampling knobs when temperature > 0.
    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=float(temperature) if do_sample else None,
        top_p=float(top_p) if do_sample else None,
        # Pass 0 through explicitly: transformers treats top_k=0 as "no top-k
        # filtering", whereas None may be read as "unset" and fall back to the
        # model's generation_config default, depending on the version.
        top_k=max(0, int(top_k)) if do_sample else None,
        repetition_penalty=float(repetition_penalty),
        num_beams=1,
        return_dict_in_generate=False,
        use_cache=True,
    )

    all_texts, all_toks = [], []
    for start in range(0, len(formatted), batch_size):
        texts, toks = _generate_microbatch(tok, model, formatted[start:start + batch_size], gen_kwargs)
        all_texts.extend(texts)
        all_toks.extend(toks)

    return pd.DataFrame({"user_prompt": prompts, "response": all_texts, "tokens_out": all_toks})
def write_csv_path(df: pd.DataFrame, directory=None) -> str:
    """Write ``df`` to a new, uniquely named CSV file and return its path.

    The file is named ``Output_<UTC %Y%m%d-%H%M%S>_<random>.csv`` and is kept
    after this function returns, so it can be served for download.

    Args:
        df: table to write; the index is omitted.
        directory: target directory. Defaults to the system temp directory
            (``tempfile.gettempdir()``, normally ``/tmp`` on Linux).

    Returns:
        Path of the written CSV file.
    """
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    # delete=False: the file must outlive this call. Writing through the handle
    # inside ``with`` closes it (no leaked descriptor) and also works on
    # Windows, where an open NamedTemporaryFile cannot be reopened by name.
    with tempfile.NamedTemporaryFile(
        mode="w",
        prefix=f"Output_{ts}_",
        suffix=".csv",
        dir=directory,
        delete=False,
        encoding="utf-8",
        newline="",
    ) as fh:
        df.to_csv(fh, index=False)
        path = fh.name
    return path
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Multi-Prompt Chat") as demo:
    gr.Markdown(
        """
# Multi-Prompt Chat to test system prompt effects
Pick a small model, set a **system prompt**, and enter **multiple user prompts** (one per line).
Click **Generate** to get batched responses and a **downloadable CSV**.
"""
    )
    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=DEFAULT_MODELS,
                value=DEFAULT_MODELS[0],
                label="Model",
                info="ZeroGPU attaches an H200 dynamically. 4-bit is used automatically on GPU when available.",
            )
            system_prompt = gr.Textbox(
                label="System prompt",
                placeholder="e.g., You are an ecolinguistics-aware assistant...",
                lines=5,
            )
            prompts_multiline = gr.Textbox(
                label="User prompts (one per line)",
                placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips\nSummarise benefits of multilingual models",
                lines=10,
            )
            with gr.Accordion("Generation settings", open=False):
                max_new_tokens = gr.Slider(16, 1024, value=200, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="temperature (0 = greedy, fastest)")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p (used if temp > 0)")
                top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables; used if temp > 0)")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: results table and CSV download.
        with gr.Column(scale=1):
            out_df = gr.Dataframe(
                headers=["user_prompt", "response", "tokens_out"],
                datatype=["str", "str", "number"],
                label="Results",
                wrap=True,
                interactive=False,
                row_count=(0, "dynamic"),
                type="pandas",
            )
            csv_out = gr.File(label="CSV output", interactive=False, type="filepath")

    # -------- Callback: GPU-decorated for ZeroGPU --------
    # ZeroGPU only attaches a GPU to functions decorated with @spaces.GPU;
    # without the decorator every request runs on the CPU code paths.
    # NOTE(review): the model is loaded lazily inside this GPU call on first
    # use (via _load_model); ZeroGPU guidance favours loading at module level.
    # Confirm that the first-request latency is acceptable.
    @spaces.GPU()
    def _generate_cb(model_id, system_prompt, prompts_multiline,
                     max_new_tokens, temperature, top_p, top_k, repetition_penalty,
                     progress=gr.Progress(track_tqdm=True)):
        """Gradio handler: run the batch and return (results table, CSV path)."""
        progress(0.05, desc="Requesting ZeroGPU…")
        df = generate_batch_df(
            model_id=model_id,
            system_prompt=system_prompt,
            prompts_multiline=prompts_multiline,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
        )
        progress(0.95, desc="Preparing CSV…")
        csv_path = write_csv_path(df)
        progress(1.0, desc="Done")
        return df, csv_path

    run_btn.click(
        _generate_cb,
        inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_df, csv_out],
        api_name="generate_batch",
    )
def _main():
    """Start the Gradio server for the `demo` UI."""
    demo.launch()


if __name__ == "__main__":
    _main()