# app.py — ZeroGPU-optimised Gradio app (HF Spaces) — refined
import os
import tempfile
from datetime import datetime, timezone

import gradio as gr
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---- Small env tweak: faster hub downloads when available ----
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

# ---- ZeroGPU decorator ----
try:
    import spaces  # HF Spaces utility (provides @spaces.GPU())
except Exception:
    class _Noop:
        """Fallback so @spaces.GPU() is a harmless no-op outside HF Spaces."""

        def GPU(self, *args, **kwargs):
            def deco(fn):
                return fn
            return deco

    spaces = _Noop()

# ---- Optional quantisation (GPU only) ----
try:
    from transformers import BitsAndBytesConfig
    HAS_BNB = True
except Exception:
    HAS_BNB = False

# ---- Optional Flash-Attention 2 ----
_HAS_FLASH = False
try:
    import flash_attn  # noqa: F401
    _HAS_FLASH = True
except Exception:
    _HAS_FLASH = False

# ----------------------------
# Config
# ----------------------------
DEFAULT_MODELS = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "neovalle/tinyllama-1.1B-h4rmony-trained",
]

# Keep batches reasonable on ZeroGPU for low latency
MICROBATCH_CPU = 2
MICROBATCH_GPU = 6  # H200 can handle a bit more than 4 for tiny models

# Cap encoder length to avoid wasting time on very long inputs
MAX_INPUT_TOKENS = 1024
MAX_NEW_TOKENS_HARD_CAP = 1024  # extra guardrail

# Speed on GPU (TF32 gives extra throughput on Ampere+)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # hint PyTorch to pick faster kernels when legal
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
else:
    # On CPU, reducing threads sometimes helps stability/predictability
    try:
        torch.set_num_threads(max(1, (os.cpu_count() or 4) // 2))
    except Exception:
        pass

_MODEL_CACHE = {}  # cache: model_id -> (tokenizer, model)


# ----------------------------
# Helpers
# ----------------------------
def _all_eos_ids(tok):
    """Collect a few likely EOS token ids so generation can stop earlier.

    Returns a list of ids, or None if nothing usable was found (letting
    `generate()` fall back to the model's own config).
    """
    ids = set()
    if tok.eos_token_id is not None:
        ids.add(tok.eos_token_id)
    # FIX: the original also probed "", which on many tokenizers resolves to
    # unk_token_id — treating UNK as EOS would truncate generation at any
    # unknown token. Only probe real end-markers and skip the unk id.
    for t in ("<|im_end|>", "<|endoftext|>"):
        try:
            tid = tok.convert_tokens_to_ids(t)
            if isinstance(tid, int) and tid >= 0 and tid != tok.unk_token_id:
                ids.add(tid)
        except Exception:
            pass
    return list(ids) if ids else None


def _load_model(model_id: str):
    """Load & cache model/tokenizer. On GPU, prefer 4-bit NF4 with BF16 compute.

    Returns (tokenizer, model); results are memoised in _MODEL_CACHE.
    """
    if model_id in _MODEL_CACHE:
        return _MODEL_CACHE[model_id]

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

    # Tokenizer hygiene: guarantee a pad token for batched generation.
    if tok.pad_token is None:
        if tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        else:
            tok.add_special_tokens({"pad_token": "<|pad|>"})

    # Left padding plays nicer with causal models and kv-cache in batched gen
    try:
        tok.padding_side = "left"
    except Exception:
        pass

    use_gpu = torch.cuda.is_available()
    bf16_ok = bool(use_gpu and getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    dtype = torch.bfloat16 if bf16_ok else (torch.float16 if use_gpu else torch.float32)

    quant_cfg = None
    if use_gpu and HAS_BNB:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=(torch.bfloat16 if bf16_ok else torch.float16),
        )

    # Choose attention impl only if flash-attn is there
    attn_impl = "flash_attention_2" if _HAS_FLASH else None

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # FIX: the original computed `dtype` (fp16 fallback for GPUs without
        # BF16 support) but then hard-coded bf16/fp32 here, leaving the
        # variable dead. Use the computed dtype.
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
        quantization_config=quant_cfg,   # 4-bit on GPU if available; None on CPU
        trust_remote_code=True,          # helps for chat templates (e.g., Qwen)
        attn_implementation=attn_impl,   # only used if flash-attn installed
    ).eval()

    # Resize if we added a pad token
    try:
        if model.get_input_embeddings().num_embeddings != len(tok):
            model.resize_token_embeddings(len(tok))
    except Exception:
        pass

    # Prefer KV cache
    try:
        model.generation_config.use_cache = True
    except Exception:
        pass

    _MODEL_CACHE[model_id] = (tok, model)
    return tok, model


def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
    """Render a (system, user) pair into a single prompt string.

    Uses the tokenizer's chat template when available; otherwise falls back
    to explicit role tags.
    """
    sys = (system_prompt or "").strip()
    usr = (user_prompt or "").strip()
    if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
        messages = []
        if sys:
            messages.append({"role": "system", "content": sys})
        messages.append({"role": "user", "content": usr})
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    # FIX: the fallback markers in the source were garbled to bare "<>"
    # (angle-bracket tags stripped during extraction). Reconstructed with
    # TinyLlama/Zephyr-style role tags, which match the original marker
    # positions — NOTE(review): confirm against the intended template; in
    # practice the chat-template branch above handles the default models.
    prefix = f"<|system|>\n{sys}\n</s>\n\n" if sys else ""
    return f"{prefix}<|user|>\n{usr}\n</s>\n<|assistant|>\n"


@torch.inference_mode()
def _generate_microbatch(tok, model, formatted_prompts, gen_kwargs):
    """Generate for a list of formatted prompts. Returns (texts, tokens_out)."""
    device = model.device
    eos_ids = _all_eos_ids(tok)
    enc = tok(
        formatted_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_token_type_ids=False,
    ).to(device)

    # FIX: with LEFT padding every row shares the same padded input length,
    # and generation starts right after it. The original sliced each row at
    # attention_mask.sum(), which (for shorter prompts in a batch) left the
    # tail of the prompt inside the "response".
    input_len = enc["input_ids"].shape[1]

    outputs = model.generate(
        **enc,
        eos_token_id=eos_ids,
        pad_token_id=tok.pad_token_id,
        **gen_kwargs,
    )

    texts, toks_out = [], []
    # Slightly faster decode (avoid extra whitespace cleanup)
    for i in range(outputs.size(0)):
        gen_ids = outputs[i, input_len:]
        text = tok.decode(
            gen_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ).strip()
        texts.append(text)
        # Count non-pad tokens: finished rows are right-padded by generate().
        if tok.pad_token_id is not None:
            toks_out.append(int((gen_ids != tok.pad_token_id).sum().item()))
        else:
            toks_out.append(int(gen_ids.numel()))
    return texts, toks_out


def generate_batch_df(
    model_id: str,
    system_prompt: str,
    prompts_multiline: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
) -> pd.DataFrame:
    """Run batched generation over one prompt per input line.

    Returns a DataFrame with columns: user_prompt, response, tokens_out.
    """
    tok, model = _load_model(model_id)

    # Split user inputs; drop blank lines.
    prompts = [p.strip() for p in (prompts_multiline or "").splitlines() if p.strip()]
    if not prompts:
        return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

    formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]

    # Adaptive micro-batch for latency: smaller on CPU, a bit larger on GPU
    B = min(len(formatted), (MICROBATCH_GPU if torch.cuda.is_available() else MICROBATCH_CPU))

    # Clamp new tokens (defensive)
    max_new_tokens = int(max(1, min(int(max_new_tokens), MAX_NEW_TOKENS_HARD_CAP)))

    # Greedy is fastest; only enable sampling knobs if temperature > 0
    do_sample = bool(temperature > 0.0)
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        num_beams=1,
        return_dict_in_generate=False,
        use_cache=True,
    )
    # FIX: the original passed temperature/top_p/top_k as explicit None in
    # greedy mode; recent transformers versions warn on (or reject) explicit
    # None sampling knobs. Only set them when sampling is active.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
        if int(top_k) > 0:
            gen_kwargs["top_k"] = int(top_k)

    all_texts, all_toks = [], []
    for i in range(0, len(formatted), B):
        batch_prompts = formatted[i : i + B]
        texts, toks = _generate_microbatch(tok, model, batch_prompts, gen_kwargs)
        all_texts.extend(texts)
        all_toks.extend(toks)

    return pd.DataFrame(
        {"user_prompt": prompts, "response": all_texts, "tokens_out": all_toks}
    )


def write_csv_path(df: pd.DataFrame) -> str:
    """Write *df* to a uniquely-named CSV file and return its path."""
    # FIX: datetime.utcnow() is deprecated — use an aware UTC timestamp.
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    # delete=False so Gradio can serve the file after this function returns;
    # FIX: use the platform temp dir instead of hard-coded "/tmp", and close
    # the handle (the original leaked the open file descriptor).
    tmp = tempfile.NamedTemporaryFile(
        prefix=f"Output_{ts}_", suffix=".csv", delete=False, dir=tempfile.gettempdir()
    )
    tmp.close()
    df.to_csv(tmp.name, index=False)
    return tmp.name


# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="Multi-Prompt Chat") as demo:
    gr.Markdown(
        """ # Multi-Prompt Chat to test system prompt effects Pick a small model, set a **system prompt**, and enter **multiple user prompts** (one per line). Click **Generate** to get batched responses and a **downloadable CSV**. """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=DEFAULT_MODELS,
                value=DEFAULT_MODELS[0],
                label="Model",
                info="ZeroGPU attaches an H200 dynamically. 4-bit is used automatically on GPU when available.",
            )
            system_prompt = gr.Textbox(
                label="System prompt",
                placeholder="e.g., You are an ecolinguistics-aware assistant...",
                lines=5,
            )
            prompts_multiline = gr.Textbox(
                label="User prompts (one per line)",
                placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips\nSummarise benefits of multilingual models",
                lines=10,
            )
            with gr.Accordion("Generation settings", open=False):
                max_new_tokens = gr.Slider(16, 1024, value=200, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="temperature (0 = greedy, fastest)")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p (used if temp > 0)")
                top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables; used if temp > 0)")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            out_df = gr.Dataframe(
                headers=["user_prompt", "response", "tokens_out"],
                datatype=["str", "str", "number"],
                label="Results",
                wrap=True,
                interactive=False,
                row_count=(0, "dynamic"),
                type="pandas",
            )
            csv_out = gr.File(label="CSV output", interactive=False, type="filepath")

    # -------- Callback: GPU-decorated for ZeroGPU --------
    @spaces.GPU()  # <— This tells ZeroGPU to attach a GPU for this request
    def _generate_cb(model_id, system_prompt, prompts_multiline, max_new_tokens,
                     temperature, top_p, top_k, repetition_penalty,
                     progress=gr.Progress(track_tqdm=True)):
        """UI callback: run the batch and prepare the CSV download."""
        progress(0.05, desc="Requesting ZeroGPU…")
        df = generate_batch_df(
            model_id=model_id,
            system_prompt=system_prompt,
            prompts_multiline=prompts_multiline,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
        )
        progress(0.95, desc="Preparing CSV…")
        csv_path = write_csv_path(df)
        progress(1.0, desc="Done")
        return df, csv_path

    run_btn.click(
        _generate_cb,
        inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens,
                temperature, top_p, top_k, repetition_penalty],
        outputs=[out_df, csv_out],
        api_name="generate_batch",
    )

if __name__ == "__main__":
    demo.launch()