# TimeOmni-1 Gradio demo (app.py)
# Source: anton-hugging's Hugging Face Space, revision 2737998 ("update-app")
import os
import json
import re
from typing import List, Tuple
from threading import Thread, Lock
# Detect whether we run inside a Hugging Face Space; the real `spaces`
# package (ZeroGPU decorator) only exists there.
is_hf_space = os.getenv("SPACE_ID") is not None
if is_hf_space:
    import spaces
else:
    class _NoopSpaces:
        """Stand-in for the `spaces` module so `@spaces.GPU` works locally."""

        @staticmethod
        def GPU(*args, **kwargs):
            # Support both decorator forms used by the real library:
            #   @spaces.GPU                -> GPU(fn) must return fn itself
            #   @spaces.GPU(duration=60)   -> GPU(...) must return a decorator
            # The previous stub only handled the called form, so a bare
            # `@spaces.GPU` (as used on respond_core below) replaced the
            # decorated function with the one-argument `decorator`, breaking
            # every call to it.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _NoopSpaces()
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
# --- Runtime configuration (all overridable via environment variables) ---
repo_id = os.getenv("MODEL_REPO_ID", "anton-hugging/TimeOmni-1-7B")
# Clamp the UI's max-new-tokens ceiling into [1, 4096].
max_new_tokens_limit = min(max(int(os.getenv("MAX_NEW_TOKENS_LIMIT", "4096")), 1), 4096)
device_map = os.getenv("DEVICE_MAP", "auto")
attn_implementation = os.getenv("ATTN_IMPLEMENTATION", "sdpa")
cache_dir = os.getenv("MODEL_CACHE_DIR", "./model_weights")
torch_dtype_name = os.getenv("TORCH_DTYPE", "bfloat16").strip().lower()
# Default generation hard-stop in seconds; overridable per request via
# the GENERATION_TIMEOUT environment variable (see respond_core).
default_timeout = 300.0
print("=" * 60)
print(f"Configured model repository: {repo_id}")
print(f"Environment: {'HF Space' if is_hf_space else 'Local'}")
print(f"Default timeout: {default_timeout}s")
print(f"Device map: {device_map}")
print(f"Attention implementation: {attn_implementation}")
print(f"Max new tokens limit: {max_new_tokens_limit}")
print("=" * 60)
# Headless matplotlib backend: figures are rendered off-screen for Gradio.
matplotlib.use("Agg")
# An existing local model folder (MODEL_PATH) takes precedence over the Hub repo.
local_model_path = os.path.expanduser(os.getenv("MODEL_PATH", "")).strip()
model_source = repo_id
effective_cache_dir = cache_dir
if local_model_path and os.path.exists(local_model_path):
    model_source = local_model_path
    effective_cache_dir = None  # local folders load directly, no HF cache needed
    print(f"Using local model folder at: {model_source}")
else:
    print(f"Using Hugging Face model repo: {repo_id}")
    print(f"Model cache path: {cache_dir}")
def resolve_torch_dtype(dtype_name: str) -> torch.dtype:
    """Map a dtype name (e.g. "fp32", "float16", "bf16") to a torch dtype.

    Unrecognized names fall back to torch.bfloat16.
    """
    if dtype_name in ("float32", "fp32"):
        return torch.float32
    if dtype_name in ("float16", "fp16"):
        return torch.float16
    # "bfloat16", "bf16", and anything unknown all resolve here.
    return torch.bfloat16
class RuntimeStopCriteria(StoppingCriteria):
    """Stops generation when the user requested a stop or a timeout elapsed."""

    def __init__(self, stop_state, start_time: float, timeout_threshold: float):
        # stop_state is a shared dict; setting {"stop": True} aborts generation.
        self.stop_state = stop_state
        self.start_time = start_time
        self.timeout_threshold = timeout_threshold

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        import time

        stop_requested = isinstance(self.stop_state, dict) and self.stop_state.get("stop")
        if stop_requested:
            return True
        return (time.time() - self.start_time) > self.timeout_threshold
# Lazily-initialized globals, guarded by model_init_lock (see load_model_if_needed).
tokenizer = None
model = None
model_init_lock = Lock()
def load_model_if_needed():
    """Load the tokenizer and model exactly once, thread-safely.

    Uses double-checked locking on the module globals so concurrent first
    requests do not trigger duplicate (very expensive) loads.
    """
    global tokenizer, model
    # Fast path: already initialized, avoid taking the lock.
    if model is not None and tokenizer is not None:
        return
    with model_init_lock:
        # Re-check under the lock: another thread may have finished loading.
        if model is not None and tokenizer is not None:
            return
        print("Loading model...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_source,
            cache_dir=effective_cache_dir,
            trust_remote_code=True,
        )
        # Some checkpoints ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_source,
            cache_dir=effective_cache_dir,
            dtype=resolve_torch_dtype(torch_dtype_name),
            device_map=device_map,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
            trust_remote_code=True,
        )
        model.eval()
        print("=" * 60)
        print("Model loaded successfully!")
        print(f"Model source: {model_source}")
        print("=" * 60)
def _first_key(row, keys):
for key in keys:
if key in row and row[key]:
return row[key]
return None
def load_task_prompts(max_tasks: int = 4) -> List[Tuple[str, str, str]]:
    """Load up to *max_tasks* (task, system_prompt, example_problem) triples.

    A local JSON file named by the DATASET_FILE env var takes precedence;
    otherwise the testbed dataset is downloaded from the HF Hub. On any
    failure, generic placeholder tasks are returned so the UI still renders.
    """
    dataset_id = "anton-hugging/timeomni-1-testbed"
    data_file = "id_test.json"
    dataset_cache_dir = "./dataset_cache"
    override_path = os.path.expanduser(os.getenv("DATASET_FILE", ""))

    rows = None
    try:
        if override_path and os.path.exists(override_path):
            json_path = override_path
        else:
            json_path = hf_hub_download(
                repo_id=dataset_id,
                filename=data_file,
                repo_type="dataset",
                local_dir=dataset_cache_dir,
            )
        with open(json_path, "r", encoding="utf-8") as handle:
            rows = json.load(handle)
    except Exception:
        rows = None  # best-effort: fall through to placeholder tasks

    tasks: List[Tuple[str, str, str]] = []
    seen_tasks = set()
    for row in rows or []:
        task = _first_key(row, ["task_type"])
        system_prompt = _first_key(row, ["system"])
        problem = _first_key(row, ["problem"])
        if not task or task in seen_tasks:
            continue
        if not system_prompt:
            system_prompt = f"Task: {task}\nFollow the dataset setting for your response."
        tasks.append((task, system_prompt, problem or ""))
        seen_tasks.add(task)
        if len(tasks) >= max_tasks:
            break

    # Pad so callers always receive exactly max_tasks entries.
    while len(tasks) < max_tasks:
        placeholder_idx = len(tasks) + 1
        tasks.append((f"Task {placeholder_idx}", "Follow the task instructions for your response.", ""))
    return tasks
def extract_series(text: str) -> List[List[float]]:
    """Extract numeric series from bracketed groups such as "[1, 2 3]".

    Values may be separated by commas and/or whitespace. A group containing
    any non-numeric token is dropped entirely.
    """
    series: List[List[float]] = []
    for group in re.findall(r"\[([^\[\]]+)\]", text):
        tokens = [tok for tok in re.split(r"[,\s]+", group.strip()) if tok]
        try:
            values = [float(tok) for tok in tokens]
        except ValueError:
            continue  # non-numeric content: skip the whole group
        if values:
            series.append(values)
    return series
def build_plot(series: List[List[float]]):
    """Render the extracted series as one matplotlib line chart.

    Returns the figure, or None when there is nothing to plot.
    """
    if not series:
        return None
    figure, axes = plt.subplots(figsize=(6, 3))
    for series_no, values in enumerate(series, start=1):
        axes.plot(range(len(values)), values, label=f"S{series_no}")
    axes.set_xlabel("t")
    axes.set_ylabel("value")
    axes.grid(True, alpha=0.3)
    # Only show a legend when multiple series would otherwise be ambiguous.
    if len(series) > 1:
        axes.legend()
    figure.tight_layout()
    return figure
def tuples_to_messages(history):
    """Normalize chat history into a flat list of {"role", "content"} dicts.

    Accepts (user, assistant) pairs and already-normalized dicts; any other
    entry shape is silently skipped. Non-list input yields an empty list.
    """
    if not history or not isinstance(history, list):
        return []
    normalized = []
    for entry in history:
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            user_text, assistant_text = entry
            normalized.append({"role": "user", "content": str(user_text or "")})
            normalized.append({"role": "assistant", "content": str(assistant_text or "")})
        elif isinstance(entry, dict) and "role" in entry and "content" in entry:
            normalized.append({"role": entry["role"], "content": str(entry.get("content", ""))})
    return normalized
def format_task_label(task_label: str) -> str:
    """Turn a snake_case label into Title Case, e.g. "a_b" -> "A B"."""
    return " ".join(task_label.split("_")).title()
@spaces.GPU
def respond_core(
    system_message,
    message,
    max_tokens,
    temperature,
    top_p,
    stop_state,
    plot,
):
    """Stream a model response for one (system, user) message pair.

    Generator yielding (partial_response_text, plot) tuples; *plot* is passed
    through unchanged so the caller can keep the figure visible while text
    streams. *stop_state* is a shared dict — {"stop": True} aborts generation.
    On any error, yields ("", plot) once so the caller can report failure.
    """
    import time
    # Lazy-load so the (expensive) model pull happens on first request only.
    load_model_if_needed()
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message}
    ]
    try:
        start_time = time.time()
        print(f"[respond_core] Starting generation, max_tokens={max_tokens}")
        timeout_threshold = float(os.getenv("GENERATION_TIMEOUT", str(default_timeout)))
        # Clamp the requested token budget into [1, max_new_tokens_limit].
        safe_max_tokens = min(max(int(max_tokens), 1), max_new_tokens_limit)
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(prompt_text, return_tensors="pt")
        # Move inputs to wherever the (possibly sharded) model's parameters live.
        input_device = next(model.parameters()).device
        inputs = {k: v.to(input_device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        # Also enforces stop/timeout *inside* model.generate, not just in this loop.
        stopping_criteria = StoppingCriteriaList(
            [RuntimeStopCriteria(stop_state, start_time, timeout_threshold)]
        )
        generation_kwargs = {
            **inputs,
            "max_new_tokens": safe_max_tokens,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "do_sample": True,
            "streamer": streamer,
            "stopping_criteria": stopping_criteria,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "use_cache": True,
        }
        # Run generate on a daemon thread; tokens arrive through the streamer.
        generation_thread = Thread(target=model.generate, kwargs=generation_kwargs)
        generation_thread.daemon = True
        generation_thread.start()
        response = ""
        chunk_count = 0
        last_yield_time = start_time
        for token in streamer:
            current_time = time.time()
            elapsed = current_time - start_time
            if isinstance(stop_state, dict) and stop_state.get("stop"):
                print(f"[respond_core] Generation stopped by user")
                break
            if elapsed > timeout_threshold:
                print(f"[respond_core] Timeout threshold reached ({elapsed:.2f}s), stopping generation")
                # Flush what we have before abandoning the stream.
                if response:
                    yield response, plot
                break
            if token:
                response += token
                chunk_count += 1
                if chunk_count == 1:
                    first_token_time = current_time - start_time
                    print(f"[respond_core] First token received in {first_token_time:.2f}s, length={len(response)}")
                yield response, plot
                last_yield_time = current_time
            elif current_time - last_yield_time > 5.0:
                # Heartbeat on empty tokens: re-yield every 5s to keep the UI alive.
                if response:
                    yield response, plot
                last_yield_time = current_time
        # Short join: generate should be finishing; daemon thread won't block exit.
        generation_thread.join(timeout=1.0)
        print(f"[respond_core] Generation completed, total tokens={chunk_count}, response length={len(response)}")
    except Exception as e:
        import traceback
        print(f"[respond_core] Error: {e}")
        print(f"[respond_core] Traceback: {traceback.format_exc()}")
        # Empty response signals failure; caller renders its own error message.
        yield "", plot
def build_card(task_label: str, system_prompt: str, example_question: str = "", allow_edit_system: bool = False):
    """Build one chat tab: system prompt display/editor, plot, chatbot, controls.

    task_label: tab title. system_prompt: default system prompt for the task.
    example_question: pre-filled content of the message textbox.
    allow_edit_system: when True the system prompt is editable; otherwise it
    is shown read-only in an accordion.
    """
    with gr.Tab(task_label):
        if allow_edit_system:
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                value=system_prompt,
                lines=5,
                placeholder="Enter your custom system prompt...",
            )
            current_system_prompt = gr.State(system_prompt)
            # Mirror textbox edits into the State consumed by the handlers.
            system_prompt_input.change(
                lambda x: x,
                inputs=[system_prompt_input],
                outputs=[current_system_prompt],
            )
        else:
            with gr.Accordion("Default system prompt", open=True):
                gr.Markdown(f"```\n{system_prompt}\n```")
            current_system_prompt = gr.State(system_prompt)
        plot = gr.Plot(label="Time Series Plot")
        chatbot = gr.Chatbot()
        chatbot_state = gr.State([])  # history held as (user, assistant) tuples
        stop_state = gr.State({"stop": False})  # shared stop flag for generation
        msg = gr.Textbox(
            label="Message",
            lines=3,
            placeholder="Type a message...",
            value=example_question or "",
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=max_new_tokens_limit,
            value=max_new_tokens_limit,
            step=1,
            label="Max new tokens",
        )
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.001, maximum=1.0, value=0.001, step=0.05, label="Top-p")
        send = gr.Button("Send")
        stop = gr.Button("Stop Generation")

        def respond(message, chatbot_state, max_tokens, temperature, top_p, stop_state, system_prompt_state):
            """Event handler: plot the input series, then stream the model reply."""
            import time
            start_time = time.time()
            print(f"[respond] Request received, message length={len(message) if message else 0}")
            if not message or not message.strip():
                # Empty input: return current state unchanged and clear the plot.
                current_chat = tuples_to_messages(chatbot_state) if chatbot_state else []
                if not isinstance(current_chat, list):
                    current_chat = []
                return current_chat, chatbot_state, None, stop_state
            current_prompt = system_prompt_state or system_prompt
            # Reset the stop flag on a fresh copy so a previous Stop click
            # cannot abort this new request.
            if not isinstance(stop_state, dict):
                stop_state = {"stop": False}
            else:
                stop_state = stop_state.copy()
                stop_state["stop"] = False
            series = extract_series(message)
            plot = build_plot(series)
            plot_time = time.time() - start_time
            print(f"[respond] Plot generated in {plot_time:.2f}s")
            current_chat = tuples_to_messages(chatbot_state) if chatbot_state else []
            if not isinstance(current_chat, list):
                current_chat = []
            # First yield: show the plot immediately, before generation starts.
            yield current_chat, chatbot_state, plot, stop_state
            # Second yield: placeholder bubble while the model warms up.
            initial_response = [(message, "Generating answer...")]
            initial_chat = tuples_to_messages(initial_response)
            yield initial_chat, chatbot_state, plot, stop_state
            stream = respond_core(
                current_prompt,
                message,
                max_tokens,
                temperature,
                top_p,
                stop_state,
                plot,
            )
            first_token_time = None
            has_yielded_content = False
            last_response = ""
            for response, current_plot in stream:
                if response:
                    last_response = response
                if response.strip() and response != "Generating answer...":
                    if first_token_time is None:
                        first_token_time = time.time() - start_time
                        print(f"[respond] First token received in {first_token_time:.2f}s")
                    updated = [(message, response)]
                    chat_value = tuples_to_messages(updated)
                    if not isinstance(chat_value, list):
                        chat_value = []
                    yield chat_value, updated, current_plot, stop_state
                    has_yielded_content = True
            if not has_yielded_content:
                # Nothing streamed back: show a timeout or generic failure message.
                elapsed = time.time() - start_time
                is_hf_space = os.getenv("SPACE_ID") is not None  # NOTE(review): shadows the module global and is unused here
                default_timeout = 300.0  # NOTE(review): shadows the module global of the same name/value
                timeout_threshold = float(os.getenv("GENERATION_TIMEOUT", str(default_timeout)))
                if elapsed > timeout_threshold:
                    error_msg = f"Generation timeout ({elapsed:.1f}s). Please try reducing input length or max_tokens (<= {max_new_tokens_limit}), or increase GENERATION_TIMEOUT environment variable (current: {timeout_threshold}s)."
                else:
                    error_msg = "Generation failed, please check logs or retry."
                error_response = [(message, error_msg)]
                error_chat = tuples_to_messages(error_response)
                yield error_chat, chatbot_state, plot, stop_state
                print(f"[respond] No content generated, elapsed={elapsed:.2f}s")
            total_time = time.time() - start_time
            print(f"[respond] Request completed in {total_time:.2f}s")

        # Wire both the Send button and textbox Enter to the same handler.
        send_event = send.click(
            respond,
            inputs=[msg, chatbot_state, max_tokens, temperature, top_p, stop_state, current_system_prompt],
            outputs=[chatbot, chatbot_state, plot, stop_state],
        )
        msg_event = msg.submit(
            respond,
            inputs=[msg, chatbot_state, max_tokens, temperature, top_p, stop_state, current_system_prompt],
            outputs=[chatbot, chatbot_state, plot, stop_state],
        )
        # Stop flips the shared flag and cancels any in-flight events.
        stop.click(lambda: {"stop": True}, None, stop_state, cancels=[send_event, msg_event])
        # Clear the textbox after a message is sent.
        send_event.then(lambda: "", None, msg)
        msg_event.then(lambda: "", None, msg)
# --- UI assembly: one tab per known task plus a free-form custom tab ---
with gr.Blocks() as demo:
    gr.Markdown("# 🐏 [TimeOmni-1](https://huggingface.co/anton-hugging/TimeOmni-1-7B)")
    tasks = load_task_prompts()
    # Map task label -> (system prompt, example question) for ordered lookup below.
    task_map = {label: (prompt, example) for label, prompt, example in tasks}
    task_order = [
        "scenario_understanding",
        "causality_discovery",
        "event_aware_forecasting",
        "decision_making",
    ]
    emojis = ["1️⃣", "2️⃣", "3️⃣", "4️⃣", "5️⃣"]
    with gr.Tabs():
        for idx, label in enumerate(task_order):
            # Skip tasks the dataset didn't provide (placeholders use other labels).
            if label in task_map:
                tab_label = f"{emojis[idx]} {format_task_label(label)}"
                system_prompt, example_question = task_map[label]
                build_card(tab_label, system_prompt, example_question, allow_edit_system=False)
        # Final, always-present tab with an editable system prompt.
        build_card(
            f"{emojis[4]} Your Custom Question",
            "You are a helpful assistant. Output Format:\n<think>Your step-by-step reasoning process that justifies your answer</think>\n<answer>Your final answer</answer>",
            allow_edit_system=True,
        )
if __name__ == "__main__":
    # Single-worker queue: the model can only serve one request at a time.
    demo.queue(
        default_concurrency_limit=1,
        max_size=1,
        api_open=False,
    )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=1,
        show_error=True,
        share=False,
        # Enlarged heading and tab labels for the Space front page.
        css="""
h1 {
font-size: 2.5em !important;
font-weight: bold !important;
margin-bottom: 0.5em !important;
}
.tab-nav button {
font-size: 2em !important;
font-weight: bold !important;
}
""",
    )