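"""Gradio demo Space for TimeOmni-1: chat with a time-series reasoning model on
benchmark tasks from the TimeOmni-1 testbed, with streaming generation and
inline series plots."""
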
import os
import json
import re
import time
from typing import List, Tuple
from threading import Thread, Lock

is_hf_space = os.getenv("SPACE_ID") is not None
if is_hf_space:
    import spaces
else:
    # Off-Space fallback: a no-op stand-in so @spaces.GPU(...) still works locally.
    class _NoopSpaces:
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn
            return decorator

    spaces = _NoopSpaces()

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
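
# Runtime configuration, overridable via environment variables.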
repo_id = os.getenv("MODEL_REPO_ID", "anton-hugging/TimeOmni-1-7B")
max_new_tokens_limit = min(max(int(os.getenv("MAX_NEW_TOKENS_LIMIT", "4096")), 1), 4096)
device_map = os.getenv("DEVICE_MAP", "auto")
attn_implementation = os.getenv("ATTN_IMPLEMENTATION", "sdpa")
cache_dir = os.getenv("MODEL_CACHE_DIR", "./model_weights")
torch_dtype_name = os.getenv("TORCH_DTYPE", "bfloat16").strip().lower()
default_timeout = 300.0

print("=" * 60)
print(f"Configured model repository: {repo_id}")
print(f"Environment: {'HF Space' if is_hf_space else 'Local'}")
print(f"Default timeout: {default_timeout}s")
print(f"Device map: {device_map}")
print(f"Attention implementation: {attn_implementation}")
print(f"Max new tokens limit: {max_new_tokens_limit}")
print("=" * 60)

matplotlib.use("Agg")

local_model_path = os.path.expanduser(os.getenv("MODEL_PATH", "")).strip()
model_source = repo_id
effective_cache_dir = cache_dir
if local_model_path and os.path.exists(local_model_path):
    model_source = local_model_path
    effective_cache_dir = None
    print(f"Using local model folder at: {model_source}")
else:
    print(f"Using Hugging Face model repo: {repo_id}")
    print(f"Model cache path: {cache_dir}")
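

# Map a TORCH_DTYPE env value (e.g. "bf16", "fp16") to a torch.dtype,
# defaulting to bfloat16 for unrecognized names.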
def resolve_torch_dtype(dtype_name: str) -> torch.dtype:
    dtype_map = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    return dtype_map.get(dtype_name, torch.bfloat16)
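

# Stops generation when the user presses "Stop" or a wall-clock timeout elapses.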
class RuntimeStopCriteria(StoppingCriteria):
    def __init__(self, stop_state, start_time: float, timeout_threshold: float):
        self.stop_state = stop_state
        self.start_time = start_time
        self.timeout_threshold = timeout_threshold

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if isinstance(self.stop_state, dict) and self.stop_state.get("stop"):
            return True
        elapsed = time.time() - self.start_time
        return elapsed > self.timeout_threshold
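

# The model and tokenizer are loaded lazily on the first request; the lock
# makes initialization safe under concurrent requests.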
tokenizer = None
model = None
model_init_lock = Lock()


def load_model_if_needed():
    global tokenizer, model
    if model is not None and tokenizer is not None:
        return
    with model_init_lock:
        if model is not None and tokenizer is not None:
            return
        print("Loading model...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_source,
            cache_dir=effective_cache_dir,
            trust_remote_code=True,
        )
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_source,
            cache_dir=effective_cache_dir,
            dtype=resolve_torch_dtype(torch_dtype_name),
            device_map=device_map,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
            trust_remote_code=True,
        )
        model.eval()
        print("=" * 60)
        print("Model loaded successfully!")
        print(f"Model source: {model_source}")
        print("=" * 60)


def _first_key(row, keys):
    for key in keys:
        if key in row and row[key]:
            return row[key]
    return None
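

# Load up to `max_tasks` (task, system prompt, problem) triples from the
# TimeOmni testbed dataset, falling back to generic placeholders on failure.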
def load_task_prompts(max_tasks: int = 4) -> List[Tuple[str, str, str]]:
    dataset_id = "anton-hugging/timeomni-1-testbed"
    data_file = "id_test.json"
    task_keys = ["task_type"]
    system_keys = ["system"]
    problem_keys = ["problem"]
    tasks: List[Tuple[str, str, str]] = []
    seen = set()
    dataset_cache_dir = "./dataset_cache"  # local to this function; distinct from the model cache_dir
    local_dataset_path = os.path.expanduser(os.getenv("DATASET_FILE", ""))
    try:
        if local_dataset_path and os.path.exists(local_dataset_path):
            dataset_path = local_dataset_path
        else:
            dataset_path = hf_hub_download(
                repo_id=dataset_id,
                filename=data_file,
                repo_type="dataset",
                local_dir=dataset_cache_dir,
            )
        with open(dataset_path, "r", encoding="utf-8") as f:
            ds = json.load(f)
    except Exception:
        ds = None
    if ds is not None:
        for row in ds:
            task = _first_key(row, task_keys)
            system_prompt = _first_key(row, system_keys)
            problem = _first_key(row, problem_keys)
            if task and task not in seen:
                if not system_prompt:
                    system_prompt = f"Task: {task}\nFollow the dataset setting for your response."
                tasks.append((task, system_prompt, problem or ""))
                seen.add(task)
                if len(tasks) >= max_tasks:
                    break
    while len(tasks) < max_tasks:
        idx = len(tasks) + 1
        tasks.append((f"Task {idx}", "Follow the task instructions for your response.", ""))
    return tasks
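

# Extract bracketed numeric lists like "[1.0, 2.5, 3]" from free text; each
# bracket group that parses cleanly becomes one series.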
def extract_series(text: str) -> List[List[float]]:
    matches = re.findall(r"\[([^\[\]]+)\]", text)
    series: List[List[float]] = []
    for m in matches:
        parts = re.split(r"[,\s]+", m.strip())
        nums: List[float] = []
        for p in parts:
            if not p:
                continue
            try:
                nums.append(float(p))
            except ValueError:
                nums = []
                break
        if nums:
            series.append(nums)
    return series
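

# Render the extracted series on a single matplotlib figure (None when empty).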
def build_plot(series: List[List[float]]):
    if not series:
        return None
    fig, ax = plt.subplots(figsize=(6, 3))
    for idx, s in enumerate(series):
        ax.plot(range(len(s)), s, label=f"S{idx + 1}")
    ax.set_xlabel("t")
    ax.set_ylabel("value")
    ax.grid(True, alpha=0.3)
    if len(series) > 1:
        ax.legend()
    fig.tight_layout()
    return fig
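

# Normalize chat history into Gradio "messages" format: accepts legacy
# (user, assistant) tuples as well as role/content dicts.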
def tuples_to_messages(history):
    if not history or not isinstance(history, list):
        return []
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            messages.append({"role": "user", "content": str(user_msg or "")})
            messages.append({"role": "assistant", "content": str(assistant_msg or "")})
        elif isinstance(item, dict) and "role" in item and "content" in item:
            messages.append({"role": item["role"], "content": str(item.get("content", ""))})
    return messages


def format_task_label(task_label: str) -> str:
    return task_label.replace("_", " ").title()
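

# Streaming generation core: build the chat prompt, launch model.generate on a
# background thread, and yield the accumulated text as tokens stream in.
# NOTE (assumption): on a ZeroGPU Space the GPU-bound entry point is normally
# wrapped with spaces.GPU; the _NoopSpaces fallback above makes this decorator
# a no-op when running locally.
@spaces.GPU()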
def respond_core(
    system_message,
    message,
    max_tokens,
    temperature,
    top_p,
    stop_state,
    plot,
):
    load_model_if_needed()
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]
    try:
        start_time = time.time()
        print(f"[respond_core] Starting generation, max_tokens={max_tokens}")
        timeout_threshold = float(os.getenv("GENERATION_TIMEOUT", str(default_timeout)))
        safe_max_tokens = min(max(int(max_tokens), 1), max_new_tokens_limit)
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt_text, return_tensors="pt")
        input_device = next(model.parameters()).device
        inputs = {k: v.to(input_device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        stopping_criteria = StoppingCriteriaList(
            [RuntimeStopCriteria(stop_state, start_time, timeout_threshold)]
        )
        generation_kwargs = {
            **inputs,
            "max_new_tokens": safe_max_tokens,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "do_sample": True,
            "streamer": streamer,
            "stopping_criteria": stopping_criteria,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
        }
        generation_thread = Thread(target=model.generate, kwargs=generation_kwargs)
        generation_thread.daemon = True
        generation_thread.start()
        response = ""
        chunk_count = 0
        last_yield_time = start_time
        for token in streamer:
            current_time = time.time()
            elapsed = current_time - start_time
            if isinstance(stop_state, dict) and stop_state.get("stop"):
                print("[respond_core] Generation stopped by user")
                break
            if elapsed > timeout_threshold:
                print(f"[respond_core] Timeout threshold reached ({elapsed:.2f}s), stopping generation")
                if response:
                    yield response, plot
                break
            if token:
                response += token
                chunk_count += 1
                if chunk_count == 1:
                    first_token_time = current_time - start_time
                    print(f"[respond_core] First token received in {first_token_time:.2f}s, length={len(response)}")
                    yield response, plot
                    last_yield_time = current_time
                elif current_time - last_yield_time > 5.0:
                    # Throttle UI updates to at most one every five seconds.
                    if response:
                        yield response, plot
                        last_yield_time = current_time
        generation_thread.join(timeout=1.0)
        if response:
            # Emit the final accumulated text; the throttled loop above may not
            # have yielded the tail of the response yet.
            yield response, plot
        print(f"[respond_core] Generation completed, total tokens={chunk_count}, response length={len(response)}")
    except Exception as e:
        import traceback

        print(f"[respond_core] Error: {e}")
        print(f"[respond_core] Traceback: {traceback.format_exc()}")
        yield "", plot
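

# Build one chat tab: system prompt (fixed or editable), plot, chat window,
# sampling controls, and the send/stop event wiring.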
def build_card(task_label: str, system_prompt: str, example_question: str = "", allow_edit_system: bool = False):
    with gr.Tab(task_label):
        if allow_edit_system:
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                value=system_prompt,
                lines=5,
                placeholder="Enter your custom system prompt...",
            )
            current_system_prompt = gr.State(system_prompt)
            system_prompt_input.change(
                lambda x: x,
                inputs=[system_prompt_input],
                outputs=[current_system_prompt],
            )
        else:
            with gr.Accordion("Default system prompt", open=True):
                gr.Markdown(f"```\n{system_prompt}\n```")
            current_system_prompt = gr.State(system_prompt)
        plot = gr.Plot(label="Time Series Plot")
        # tuples_to_messages emits role/content dicts, so the Chatbot must use
        # the "messages" format.
        chatbot = gr.Chatbot(type="messages")
        chatbot_state = gr.State([])
        stop_state = gr.State({"stop": False})
        msg = gr.Textbox(
            label="Message",
            lines=3,
            placeholder="Type a message...",
            value=example_question or "",
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=max_new_tokens_limit,
            value=max_new_tokens_limit,
            step=1,
            label="Max new tokens",
        )
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.001, maximum=1.0, value=0.001, step=0.05, label="Top-p")
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")

        def respond(message, chatbot_state, max_tokens, temperature, top_p, stop_state, system_prompt_state):
            start_time = time.time()
            print(f"[respond] Request received, message length={len(message) if message else 0}")
            if not message or not message.strip():
                current_chat = tuples_to_messages(chatbot_state) if chatbot_state else []
                if not isinstance(current_chat, list):
                    current_chat = []
                # In a generator a bare `return value` is discarded, so yield
                # the unchanged state before stopping.
                yield current_chat, chatbot_state, None, stop_state
                return
            current_prompt = system_prompt_state or system_prompt
            if not isinstance(stop_state, dict):
                stop_state = {"stop": False}
            else:
                stop_state = stop_state.copy()
                stop_state["stop"] = False
            series = extract_series(message)
            plot = build_plot(series)
            plot_time = time.time() - start_time
            print(f"[respond] Plot generated in {plot_time:.2f}s")
            current_chat = tuples_to_messages(chatbot_state) if chatbot_state else []
            if not isinstance(current_chat, list):
                current_chat = []
            yield current_chat, chatbot_state, plot, stop_state
            initial_response = [(message, "Generating answer...")]
            initial_chat = tuples_to_messages(initial_response)
            yield initial_chat, chatbot_state, plot, stop_state
            stream = respond_core(
                current_prompt,
                message,
                max_tokens,
                temperature,
                top_p,
                stop_state,
                plot,
            )
            first_token_time = None
            has_yielded_content = False
            for response, current_plot in stream:
                if response.strip() and response != "Generating answer...":
                    if first_token_time is None:
                        first_token_time = time.time() - start_time
                        print(f"[respond] First token received in {first_token_time:.2f}s")
                    updated = [(message, response)]
                    chat_value = tuples_to_messages(updated)
                    if not isinstance(chat_value, list):
                        chat_value = []
                    yield chat_value, updated, current_plot, stop_state
                    has_yielded_content = True
            if not has_yielded_content:
                elapsed = time.time() - start_time
                timeout_threshold = float(os.getenv("GENERATION_TIMEOUT", str(default_timeout)))
                if elapsed > timeout_threshold:
                    error_msg = (
                        f"Generation timeout ({elapsed:.1f}s). Please try reducing input length or "
                        f"max_tokens (<= {max_new_tokens_limit}), or increase the GENERATION_TIMEOUT "
                        f"environment variable (current: {timeout_threshold}s)."
                    )
                else:
                    error_msg = "Generation failed, please check logs or retry."
                error_response = [(message, error_msg)]
                error_chat = tuples_to_messages(error_response)
                yield error_chat, chatbot_state, plot, stop_state
                print(f"[respond] No content generated, elapsed={elapsed:.2f}s")
            total_time = time.time() - start_time
            print(f"[respond] Request completed in {total_time:.2f}s")

        send_event = send.click(
            respond,
            inputs=[msg, chatbot_state, max_tokens, temperature, top_p, stop_state, current_system_prompt],
            outputs=[chatbot, chatbot_state, plot, stop_state],
        )
        msg_event = msg.submit(
            respond,
            inputs=[msg, chatbot_state, max_tokens, temperature, top_p, stop_state, current_system_prompt],
            outputs=[chatbot, chatbot_state, plot, stop_state],
        )
        stop.click(lambda: {"stop": True}, None, stop_state, cancels=[send_event, msg_event])
        send_event.then(lambda: "", None, msg)
        msg_event.then(lambda: "", None, msg)
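

# --- App layout: one tab per benchmark task, plus a free-form custom tab ---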
custom_css = """
h1 {
    font-size: 2.5em !important;
    font-weight: bold !important;
    margin-bottom: 0.5em !important;
}
.tab-nav button {
    font-size: 2em !important;
    font-weight: bold !important;
}
"""

# Custom CSS belongs on the Blocks constructor; launch() has no css argument.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# 🐏 [TimeOmni-1](https://huggingface.co/anton-hugging/TimeOmni-1-7B)")
    tasks = load_task_prompts()
    task_map = {label: (prompt, example) for label, prompt, example in tasks}
    task_order = [
        "scenario_understanding",
        "causality_discovery",
        "event_aware_forecasting",
        "decision_making",
    ]
    emojis = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣"]
    with gr.Tabs():
        for idx, label in enumerate(task_order):
            if label in task_map:
                tab_label = f"{emojis[idx]} {format_task_label(label)}"
                system_prompt, example_question = task_map[label]
                build_card(tab_label, system_prompt, example_question, allow_edit_system=False)
        build_card(
            f"{emojis[4]} Your Custom Question",
            "You are a helpful assistant. Output Format:\n<think>Your step-by-step reasoning process that justifies your answer</think>\n<answer>Your final answer</answer>",
            allow_edit_system=True,
        )
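

# Single-slot queue: process one request at a time so the lone GPU worker is
# never oversubscribed.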
if __name__ == "__main__":
    demo.queue(
        default_concurrency_limit=1,
        max_size=1,
        api_open=False,
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=1,
        show_error=True,
        share=False,
    )