# app.py
import time
import os

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)

MODEL_NAME = "Rapnss/VIA-01"  # your HF repo

# Configs you can tune
DEFAULT_MAX_NEW_TOKENS = 64  # keep low to meet latency targets
MAX_PROMPT_TOKENS = 512      # truncate long prompts
TEMPERATURE = 0.3
TOP_P = 0.9
DO_SAMPLE = False            # deterministic and usually faster than sampling
NUM_BEAMS = 1                # num_beams=1 (no beam search) is fastest
WARMUP_PROMPT = "Hello."     # used to warm the model after loading
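# (Optional) override the knobs above at deploy time; a minimal sketch using
# illustrative env-var names (VIA_* is not part of the original config):
# DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("VIA_MAX_NEW_TOKENS", DEFAULT_MAX_NEW_TOKENS))
# TEMPERATURE = float(os.environ.get("VIA_TEMPERATURE", TEMPERATURE))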
# Try to load tokenizer / model in quantized mode (4-bit) if bitsandbytes is available
print("Loading tokenizer & model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = None
device = "cpu"
try:
    # If CUDA is available and bitsandbytes is installed, load in 4-bit.
    # The bnb_4bit_* options belong in a BitsAndBytesConfig passed as
    # quantization_config; newer transformers releases deprecate them as
    # bare from_pretrained kwargs.
    if torch.cuda.is_available():
        device = "cuda"
        print("CUDA available - attempting 4-bit load with bitsandbytes...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,  # some user repos need it
        )
    else:
        raise RuntimeError("CUDA not available; load fallback")
except Exception as e:
    print("4-bit load failed or not available:", e)
    print("Falling back to fp16/cpu (best-effort).")
    # Fallback: fp16 on GPU, float32 on CPU
    if torch.cuda.is_available():
        device = "cuda"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        device = "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map={"": "cpu"},
            trust_remote_code=True,
        )
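# Many causal-LM tokenizers ship without a pad token; mapping it to EOS keeps
# generate() from warning. (Assumption: this model treats EOS as padding.)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token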
# Put model in eval mode & optionally compile
model.eval()

# Optional: torch.compile for small speedups (PyTorch 2.x only; may increase startup time)
try:
    if torch.__version__.startswith("2"):
        print("Attempting torch.compile(model) for runtime speedups...")
        model = torch.compile(model)
except Exception as e:
    print("torch.compile not used:", e)

print(f"Model loaded on {device}")
# Utility: fast tokenize + move tensors to the proper device
def prepare_inputs(prompt_text):
    # Truncate long prompts to bound total tokens during generation
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
        padding=False,
    )
    if device == "cuda":
        inputs = {k: v.cuda() for k, v in inputs.items()}
    return inputs
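# Note: with device_map="auto" a sharded model's embeddings may not sit on
# cuda:0, in which case v.to(model.device) is the safer move; plain .cuda()
# assumes the whole model fits on one GPU.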
# Optional stopping criterion: halt once the tail decodes to a double newline
class StopOnDoubleNewline(StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs):
        # Decode the last two tokens and stop on "\n\n" (customize as needed);
        # EOS itself is already handled by generate() via eos_token_id.
        if input_ids.shape[-1] >= 2:
            tail = tokenizer.decode(input_ids[0, -2:])
            if "\n\n" in tail:
                return True
        return False

stop_criteria = StoppingCriteriaList([StopOnDoubleNewline()])
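# To activate it, un-comment the stopping_criteria entry in gen_kwargs below.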
# Warm-up: run a single tiny generation so the model caches kernels
def warm_up_model():
    try:
        inputs = prepare_inputs(WARMUP_PROMPT)
        with torch.inference_mode():
            model.generate(
                **inputs,
                max_new_tokens=8,
                do_sample=False,
                use_cache=True,
            )
        print("Warmup complete.")
    except Exception as e:
        print("Warmup failed:", e)

# Warm up once at startup to reduce first-request latency
warm_up_model()
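# Under torch.compile the first forward pass also triggers graph capture, so
# this warmup absorbs that one-time cost before real users hit the endpoint.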
# The actual chat function used by Gradio
def chat_fn(prompt: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = TEMPERATURE):
    t0 = time.time()
    prompt = prompt.strip()
    if not prompt:
        return "Please enter a prompt."
    # Safety: clamp max_new_tokens to avoid huge generations
    max_new_tokens = int(max(1, min(max_new_tokens, 256)))
    inputs = prepare_inputs(prompt)
    # Generation arguments tuned for speed. temperature/top_p are sampling
    # parameters, so they are only passed when DO_SAMPLE is True (transformers
    # warns otherwise); the Temperature slider thus only takes effect if
    # DO_SAMPLE is flipped to True. early_stopping is omitted because it only
    # applies to beam search (num_beams > 1).
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=DO_SAMPLE,
        num_beams=NUM_BEAMS,
        eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        # stopping_criteria=stop_criteria,  # enable if you want custom stopping
    )
    if DO_SAMPLE:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(TOP_P)
    # Inference context to reduce overhead
    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    latency = time.time() - t0
    # Return response and latency for debugging
    return f"{response}\n\n---\nLatency: {latency:.2f}s (max_new_tokens={max_new_tokens}, device={device})"
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Rapnss VIA-01")
    with gr.Row():
        txt = gr.Textbox(lines=3, placeholder="Ask VIA-01 something...", label="Prompt")
    with gr.Row():
        max_tokens = gr.Slider(16, 256, value=DEFAULT_MAX_NEW_TOKENS, step=16, label="Max new tokens")
        temp = gr.Slider(0.0, 1.0, value=TEMPERATURE, step=0.05, label="Temperature")
    out = gr.Textbox(label="VIA-01 Response", lines=12)
    submit = gr.Button("Generate")
    submit.click(fn=chat_fn, inputs=[txt, max_tokens, temp], outputs=out)
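# (Optional) queue concurrent requests so a single GPU is not oversubscribed;
# a minimal sketch using Gradio's built-in queue:
# demo.queue(max_size=16)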
if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))