#!/usr/bin/env python3
import os
import threading
from typing import Any

# Importing spaces early is recommended for ZeroGPU runtime patching.
try:
    import spaces
except Exception:
    class _SpacesFallback:
        def GPU(*args, **kwargs):
            def _decorator(fn):
                return fn
            return _decorator

    spaces = _SpacesFallback()

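# When the `spaces` package is unavailable (e.g. local development), the
# fallback above turns `spaces.GPU(...)` into a no-op decorator factory, so
# the same file runs outside the ZeroGPU runtime.
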
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_FULL_MODEL = "NousResearch/nomos-1"
DEFAULT_MODEL_CANDIDATES = "cyankiwi/nomos-1-AWQ-8bit"
DEFAULT_TOKENIZER_ID = DEFAULT_FULL_MODEL

GPU_DURATION_SECONDS = int(os.getenv("GPU_DURATION_SECONDS", "180"))
MAX_GPU_DURATION_SECONDS = int(os.getenv("MAX_GPU_DURATION_SECONDS", "300"))
GPU_SIZE = os.getenv("GPU_SIZE", "large").strip().lower() or "large"
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "2048"))
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
PREFER_FULL = os.getenv("PREFER_FULL", "false").lower() == "true"
TOKENIZER_ID = os.getenv("TOKENIZER_ID", DEFAULT_TOKENIZER_ID).strip() or DEFAULT_TOKENIZER_ID
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "bfloat16").strip().lower()
MODEL_DEVICE_MAP = os.getenv("MODEL_DEVICE_MAP", "auto").strip() or "auto"

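# Everything above is environment-driven, so the Space can be retuned without
# code edits. A local smoke test might look like this (values illustrative,
# not requirements):
#
#   TORCH_DTYPE=float16 MAX_NEW_TOKENS_DEFAULT=128 \
#   MODEL_CANDIDATES="cyankiwi/nomos-1-AWQ-8bit" python app.py
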
_MODEL_LOCK = threading.Lock()
_MODEL: Any = None
_TOKENIZER: Any = None
_MODEL_ID: str | None = None
_LOAD_ERRORS: list[str] = []

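# The model is loaded lazily into the module-level slots above. _MODEL_LOCK
# plus the re-check inside _load_model_if_needed implements double-checked
# locking, so concurrent first requests trigger at most one load.
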
def _ordered_candidates() -> list[str]:
    configured = os.getenv("MODEL_CANDIDATES", DEFAULT_MODEL_CANDIDATES)
    candidates = [m.strip() for m in configured.split(",") if m.strip()]
    if PREFER_FULL and DEFAULT_FULL_MODEL not in candidates:
        candidates = [DEFAULT_FULL_MODEL] + candidates
    return candidates

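# Example: with MODEL_CANDIDATES="org/model-a,org/model-b" (illustrative ids)
# and PREFER_FULL=true, the load order becomes
# ["NousResearch/nomos-1", "org/model-a", "org/model-b"].
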
def _torch_dtype() -> torch.dtype | str:
    if TORCH_DTYPE in {"", "auto"}:
        return "auto"
    if TORCH_DTYPE in {"bfloat16", "bf16"}:
        return torch.bfloat16
    if TORCH_DTYPE in {"float16", "fp16", "half"}:
        return torch.float16
    if TORCH_DTYPE in {"float32", "fp32"}:
        return torch.float32
    # Unrecognized values fall back to "auto" rather than raising.
    return "auto"

def _package_versions() -> str:
    pieces = [
        f"torch={torch.__version__}",
        f"transformers={transformers.__version__}",
    ]
    try:
        import compressed_tensors

        pieces.append(f"compressed-tensors={compressed_tensors.__version__}")
    except Exception as exc:  # pragma: no cover - environment specific
        pieces.append(f"compressed-tensors=unavailable({type(exc).__name__})")
    return ", ".join(pieces)

def _cuda_status() -> str:
    if not torch.cuda.is_available():
        return "CUDA unavailable"
    try:
        idx = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(idx)
        total_gb = props.total_memory / (1024**3)
        return f"{props.name} ({total_gb:.1f} GB)"
    except Exception as exc:  # pragma: no cover - environment specific
        return f"CUDA available (details unavailable: {type(exc).__name__})"

def _load_model_if_needed() -> tuple[str | None, str]:
    global _MODEL, _TOKENIZER, _MODEL_ID
    if _MODEL is not None and _TOKENIZER is not None and _MODEL_ID is not None:
        return _MODEL_ID, "model already loaded"
    with _MODEL_LOCK:
        # Re-check under the lock: another thread may have finished loading.
        if _MODEL is not None and _TOKENIZER is not None and _MODEL_ID is not None:
            return _MODEL_ID, "model already loaded"
        errors: list[str] = []
        for candidate in _ordered_candidates():
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    TOKENIZER_ID,
                    trust_remote_code=TRUST_REMOTE_CODE,
                )
                if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
                    tokenizer.pad_token = tokenizer.eos_token
                dtype = _torch_dtype()
                model_kwargs: dict[str, Any] = {
                    "trust_remote_code": TRUST_REMOTE_CODE,
                    "low_cpu_mem_usage": True,
                    "device_map": MODEL_DEVICE_MAP,
                }
                if dtype != "auto":
                    model_kwargs["torch_dtype"] = dtype
                model = AutoModelForCausalLM.from_pretrained(candidate, **model_kwargs)
                model.eval()
                _TOKENIZER = tokenizer
                _MODEL = model
                _MODEL_ID = candidate
                _LOAD_ERRORS.clear()
                return candidate, "loaded"
            except Exception as exc:
                errors.append(f"{candidate}: {type(exc).__name__}: {exc}")
        _LOAD_ERRORS[:] = errors
        return None, "load failed"

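# Every failed candidate is recorded verbatim in _LOAD_ERRORS; _status_text
# surfaces the last three, so load failures are visible in the UI without
# digging through Space logs.
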
def _status_text() -> str:
    candidates = ", ".join(_ordered_candidates())
    loaded = _MODEL_ID or "none"
    base = (
        f"Loaded model: `{loaded}`\n\n"
        f"Tokenizer: `{TOKENIZER_ID}`\n\n"
        f"Torch dtype: `{TORCH_DTYPE}` | Device map: `{MODEL_DEVICE_MAP}`\n\n"
        f"GPU size: `{GPU_SIZE}` | Duration default: `{GPU_DURATION_SECONDS}s`\n\n"
        f"Max input tokens: `{MAX_INPUT_TOKENS}`\n\n"
        f"Candidates: `{candidates}`\n\n"
        f"Runtime: `{_cuda_status()}`\n\n"
        f"Packages: `{_package_versions()}`"
    )
    if _LOAD_ERRORS:
        err = "\n".join(f"- {e}" for e in _LOAD_ERRORS[-3:])
        return base + "\n\nRecent load errors:\n" + err
    return base

def _duration_for_generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
) -> int:
    # Must mirror generate()'s signature: ZeroGPU calls a callable `duration`
    # with the same arguments to size the GPU reservation per request.
    del prompt, temperature, top_p, top_k, do_sample
    try:
        requested_new = int(max_new_tokens)
    except Exception:
        requested_new = MAX_NEW_TOKENS_DEFAULT
    est = max(GPU_DURATION_SECONDS, 60 + int(0.8 * max(32, requested_new)))
    return min(MAX_GPU_DURATION_SECONDS, est)

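# Worked example with the defaults (floor GPU_DURATION_SECONDS=180, cap
# MAX_GPU_DURATION_SECONDS=300): max_new_tokens=256 requests
# 60 + int(0.8 * 256) = 264s; max_new_tokens=64 would give 111s, so the
# 180s floor applies instead.
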
def _gpu_decorator():
    try:
        return spaces.GPU(duration=_duration_for_generate, size=GPU_SIZE)
    except TypeError:
        # Older releases of the spaces package may not accept a `size` kwarg;
        # fall back to a duration-only reservation.
        return spaces.GPU(duration=_duration_for_generate)

# The ZeroGPU decorator must wrap the function that actually touches the GPU;
# the original defined _gpu_decorator but never applied it, so generation
# would never receive a GPU allocation on a ZeroGPU Space.
@_gpu_decorator()
def generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
) -> tuple[str, str]:
    prompt = (prompt or "").strip()
    if not prompt:
        return "Provide a prompt.", _status_text()
    model_id, _ = _load_model_if_needed()
    if model_id is None:
        return "Model load failed. Check status and Space logs.", _status_text()
    tokenizer = _TOKENIZER
    model = _MODEL
    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    model_inputs = tokenizer(chat_text, return_tensors="pt")
    try:
        device = next(model.parameters()).device
    except Exception:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for k, v in list(model_inputs.items()):
        if torch.is_tensor(v):
            model_inputs[k] = v.to(device)
    input_ids = model_inputs.get("input_ids")
    # Left-truncate over-long prompts so at most MAX_INPUT_TOKENS remain,
    # keeping the most recent tokens and trimming attention_mask in step.
    if torch.is_tensor(input_ids) and input_ids.ndim == 2 and input_ids.shape[-1] > MAX_INPUT_TOKENS:
        trim = input_ids.shape[-1] - MAX_INPUT_TOKENS
        for k, v in list(model_inputs.items()):
            if torch.is_tensor(v) and v.ndim == 2 and v.shape[-1] == input_ids.shape[-1]:
                model_inputs[k] = v[:, trim:]
        input_ids = model_inputs["input_ids"]
    generation_cfg = getattr(model, "generation_config", None)
    eos_token_id = getattr(generation_cfg, "eos_token_id", None)
    pad_token_id = getattr(generation_cfg, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id or 0
    gen_kwargs: dict[str, Any] = {
        **model_inputs,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": bool(do_sample),
        "pad_token_id": pad_token_id,
    }
    if eos_token_id is not None:
        gen_kwargs["eos_token_id"] = eos_token_id
    if do_sample:
        gen_kwargs.update(
            {
                "temperature": float(temperature),
                "top_p": float(top_p),
                "top_k": int(top_k),
            }
        )
    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)
    # Decode only the newly generated tokens; if that comes back empty, fall
    # back to decoding the full sequence.
    prompt_len = input_ids.shape[-1] if torch.is_tensor(input_ids) else 0
    generated_ids = output_ids[0][prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    if not text:
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return text, _status_text()

with gr.Blocks(title="Nomos ZeroGPU Inference") as demo:
    gr.Markdown(
        "# Nomos Remote Inference (ZeroGPU)\n"
        "This app tries model candidates in order and keeps the first that loads."
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                lines=10,
                placeholder="Ask for a concise proof or solution sketch...",
            )
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=1,
                    label="Max new tokens",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.6,
                    step=0.01,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.05,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    label="Top-p",
                )
            do_sample = gr.Checkbox(value=True, label="Sample")
            run_btn = gr.Button("Generate")
        with gr.Column(scale=2):
            output = gr.Textbox(label="Output", lines=18)
            status = gr.Markdown(value=_status_text())

    run_btn.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p, top_k, do_sample],
        outputs=[output, status],
        api_name="generate",
    )
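    # api_name="generate" also exposes this handler on the Gradio API, so the
    # Space can be driven programmatically (see the client sketch at the
    # bottom of the file).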
    gr.Examples(
        examples=[
            ["Solve: Find all integers n such that n^2 + n + 1 is prime."],
            ["Give a proof sketch that there are infinitely many primes."],
        ],
        inputs=prompt,
    )

demo.queue(max_size=32)

if __name__ == "__main__":
    demo.launch()
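# Minimal client-side sketch, assuming the Space is public and gradio_client
# is installed; the Space id below is hypothetical:
#
#   from gradio_client import Client
#
#   client = Client("your-username/nomos-zerogpu")  # hypothetical Space id
#   text, status = client.predict(
#       "Give a proof sketch that there are infinitely many primes.",
#       256,    # max_new_tokens
#       0.6,    # temperature
#       0.95,   # top_p
#       20,     # top_k
#       True,   # do_sample
#       api_name="/generate",
#   )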