broadfield-dev committed on
Commit
0ec3e05
·
verified ·
1 Parent(s): 63af298

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -30
app.py CHANGED
@@ -3,6 +3,9 @@
3
  Overthinker - Local 4B Quantized Edition (Nemotron 3 Nano 4B)
4
  Uses a local 4B model (NVIDIA Nemotron 3 Nano 4B) loaded in 4-bit quantization if supported,
5
  otherwise falls back to BF16 (which fits easily on 24GB GPUs).
 
 
 
6
  """
7
 
8
  import os
@@ -12,14 +15,14 @@ import uuid
12
  import sqlite3
13
  import torch
14
  from pathlib import Path
15
- from typing import Optional, Dict, List, Any
16
 
17
  from gradio import Server
18
  from fastapi import HTTPException
19
  from starlette.responses import HTMLResponse, PlainTextResponse, JSONResponse
20
  from datasets import Dataset, concatenate_datasets, load_dataset
21
  import pandas as pd
22
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
23
  from bag import (
24
  BASE_URL,
25
  LLMS_TXT,
@@ -29,7 +32,7 @@ from bag import (
29
  VIDEO_PAGE_HTML,
30
  README_MD
31
  )
32
- os.system("pip install torch && git clone https://github.com/state-spaces/mamba.git && cd mamba && python setup.py install")
33
  # ---------------------------------------------------------------------------
34
  # Application Setup
35
  # ---------------------------------------------------------------------------
@@ -39,14 +42,23 @@ DATA_DIR = Path("data")
39
  DATA_DIR.mkdir(exist_ok=True)
40
 
41
  # ---------- Local Model Configuration ----------
42
- # Using NVIDIA Nemotron 3 Nano 4B (BF16) - a compact Mamba2-Transformer hybrid SLM
43
- # 4-bit quantization via BitsAndBytes may not support Mamba layers fully;
44
- # we attempt it first, then fall back to BF16 (model is ~8GB, fits on A10G/T4)
45
- MODEL_NAME = "nvidia/NVIDIA-Nemotron-3-Nano-4B-FP8"
46
 
47
  print("[Overthinker] Attempting to load Nemotron 3 Nano 4B with 4-bit quantization...")
48
 
49
- # Try 4-bit first; if incompatibility with Mamba layers, fallback to BF16
 
 
 
 
 
 
 
 
 
 
 
 
50
  bnb_config = BitsAndBytesConfig(
51
  load_in_4bit=True,
52
  bnb_4bit_use_double_quant=True,
@@ -54,22 +66,30 @@ bnb_config = BitsAndBytesConfig(
54
  bnb_4bit_compute_dtype=torch.bfloat16
55
  )
56
 
 
57
  try:
58
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)
 
 
59
  model = AutoModelForCausalLM.from_pretrained(
60
  MODEL_NAME,
61
- torch_dtype=torch.bfloat16,
 
 
62
  trust_remote_code=True,
63
- device_map="auto"
64
  )
65
  print(f"[Overthinker] Model loaded in 4-bit quantization on device: {model.device}")
66
  loaded_quantized = True
67
  except Exception as e:
68
  print(f"[Overthinker] 4-bit quantization failed: {e}")
69
  print("[Overthinker] Falling back to BF16 (no quantization) - model is only ~8GB.")
70
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)
 
 
71
  model = AutoModelForCausalLM.from_pretrained(
72
  MODEL_NAME,
 
73
  device_map="auto",
74
  trust_remote_code=True,
75
  torch_dtype=torch.bfloat16
@@ -109,7 +129,7 @@ def init_session(session_id: str):
109
  type TEXT NOT NULL,
110
  label TEXT NOT NULL,
111
  description TEXT DEFAULT '',
112
- emoji TEXT DEFAULT '\U0001f539',
113
  tips TEXT DEFAULT '[]',
114
  order_index INTEGER DEFAULT 0,
115
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -118,7 +138,7 @@ def init_session(session_id: str):
118
  root_id = str(uuid.uuid4())
119
  conn.execute(
120
  "INSERT INTO nodes (id, parent_id, type, label, description, emoji) VALUES (?, ?, ?, ?, ?, ?)",
121
- (root_id, None, "root", "What decision do you want to explore?", "", "\U0001f333")
122
  )
123
  conn.commit()
124
  conn.close()
@@ -162,7 +182,7 @@ def get_children_db(session_id: str, parent_id: str) -> List[Dict]:
162
  return result
163
 
164
  def add_node_db(session_id: str, parent_id: str, node_type: str, label: str,
165
- description: str = "", emoji: str = "\U0001f539",
166
  tips: list = None, order_index: int = 0) -> Dict:
167
  node_id = str(uuid.uuid4())
168
  tips_json = json.dumps(tips or [])
@@ -219,7 +239,7 @@ def build_path_string(session_id: str, node_id: str) -> str:
219
  parts.append(f"[INPUT] {label}")
220
  elif t == "outcome":
221
  parts.append(f"[OUTCOME] {label}")
222
- return " \u2192 ".join(parts)
223
 
224
  def get_root_node(session_id: str) -> Optional[Dict]:
225
  db_path = get_db_path(session_id)
@@ -378,7 +398,7 @@ def parse_json_response(text: str) -> Optional[dict]:
378
  return None
379
 
380
  # ---------------------------------------------------------------------------
381
- # Routes (All POST, no GET except for serving index)
382
  # ---------------------------------------------------------------------------
383
 
384
  @app.get("/")
@@ -412,7 +432,7 @@ async def create_tree(request: dict):
412
  raise HTTPException(status_code=500, detail="Failed to generate root node. Please check model availability.")
413
  label = parsed.get('label', f'Overthinking: {decision[:40]}')
414
  description = parsed.get('description', f'You are overthinking: {decision}')
415
- emoji = parsed.get('emoji', '\U0001f333')
416
  tips = parsed.get('tips', ['Start by exploring options.'])
417
  update_root_db(session_id, label, description)
418
  db_path = get_db_path(session_id)
@@ -472,7 +492,7 @@ async def get_children(request: dict):
472
  for i, child in enumerate(children_data):
473
  label = child.get('label', 'Unknown')
474
  description = child.get('description', '')
475
- emoji = child.get('emoji', '\U0001f539')
476
  tips = child.get('tips', [f'Consider this {next_type}.'])
477
  existing = get_children_db(session_id, node_id)
478
  existing_labels = [c['label'] for c in existing]
@@ -513,7 +533,7 @@ async def add_options(request: dict):
513
  for i, child in enumerate(children_data):
514
  label = child.get('label', 'Unknown')
515
  description = child.get('description', '')
516
- emoji = child.get('emoji', '\U0001f539')
517
  tips = child.get('tips', [f'Additional {next_type}.'])
518
  existing = get_children_db(session_id, node_id)
519
  existing_labels = [c['label'] for c in existing]
@@ -590,15 +610,15 @@ async def export_path_md(request: dict):
590
  if not session_id or not node_id:
591
  raise HTTPException(status_code=400, detail="Missing session_id or node_id")
592
  path = get_path_db(session_id, node_id)
593
- md = '# \U0001f9e0 Overthinker \u2014 Decision Path\n\n'
594
  for i, node in enumerate(path):
595
  indent = ' ' * i
596
- emoji = {'root': '\U0001f333', 'input': '\U0001f9e0', 'outcome': '\U0001f4ca'}.get(node.get('type', ''), '\U0001f4cc')
597
  md += f'{indent}{emoji} **{node.get("label", "")}**\n'
598
  if node.get('description'):
599
  md += f'{indent} > {node.get("description", "")}\n'
600
  if node.get('tips') and len(node['tips']) > 0:
601
- md += f'{indent} > \U0001f4a1 {node["tips"][0]}\n'
602
  md += '\n'
603
  return PlainTextResponse(content=md, status_code=200)
604
 
@@ -630,17 +650,18 @@ async def get_video():
630
  # Launch
631
  # ---------------------------------------------------------------------------
632
  if __name__ == "__main__":
633
- print(f"\U0001f9e0 Overthinker \u2014 Local 4B Quantized Edition on port {PORT}")
634
- print(f"\U0001f916 Model: {MODEL_NAME}")
 
635
  if loaded_quantized:
636
- print("\U0001f4be Quantization: 4-bit NF4 (BitsAndBytes)")
637
  else:
638
- print("\U0001f4be Quantization: None (BF16 fallback)")
639
- print(f"\U0001f310 Open http://localhost:{PORT} in your browser")
640
  if not HF_TOKEN or not HF_DATASET_REPO:
641
- print("\u26a0\ufe0f No HF_TOKEN or HF_DATASET_REPO set. Upload will fail.")
642
  app.launch(
643
  server_port=PORT,
644
  show_error=True,
645
  share=False
646
- )
 
3
  Overthinker - Local 4B Quantized Edition (Nemotron 3 Nano 4B)
4
  Uses a local 4B model (NVIDIA Nemotron 3 Nano 4B) loaded in 4-bit quantization if supported,
5
  otherwise falls back to BF16 (which fits easily on 24GB GPUs).
6
+
7
+ Handles mamba-ssm dependency gracefully by disabling use_mamba_kernels in config
8
+ to use transformers' native PyTorch fallback implementation when mamba-ssm is not available.
9
  """
10
 
11
  import os
 
15
  import sqlite3
16
  import torch
17
  from pathlib import Path
18
+ from typing import Optional, Dict, List
19
 
20
  from gradio import Server
21
  from fastapi import HTTPException
22
  from starlette.responses import HTMLResponse, PlainTextResponse, JSONResponse
23
  from datasets import Dataset, concatenate_datasets, load_dataset
24
  import pandas as pd
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig, AutoConfig
26
  from bag import (
27
  BASE_URL,
28
  LLMS_TXT,
 
32
  VIDEO_PAGE_HTML,
33
  README_MD
34
  )
35
+
36
  # ---------------------------------------------------------------------------
37
  # Application Setup
38
  # ---------------------------------------------------------------------------
 
42
  DATA_DIR.mkdir(exist_ok=True)
43
 
44
  # ---------- Local Model Configuration ----------
45
+ MODEL_NAME = "nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16"
 
 
 
46
 
47
  print("[Overthinker] Attempting to load Nemotron 3 Nano 4B with 4-bit quantization...")
48
 
49
+ # Load config and disable mamba kernels to avoid mamba-ssm dependency
50
+ print("[Overthinker] Loading model config...")
51
+ config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
52
+
53
+ # Disable mamba kernels to use transformers' native PyTorch fallback
54
+ # This avoids needing mamba-ssm and causal-conv1d packages
55
+ if hasattr(config, 'use_mamba_kernels'):
56
+ config.use_mamba_kernels = False
57
+ print("[Overthinker] Disabled use_mamba_kernels - using PyTorch fallback for Mamba layers")
58
+ else:
59
+ print("[Overthinker] Warning: Config does not have use_mamba_kernels attribute")
60
+
61
+ # Try 4-bit first; if incompatibility, fallback to BF16
62
  bnb_config = BitsAndBytesConfig(
63
  load_in_4bit=True,
64
  bnb_4bit_use_double_quant=True,
 
66
  bnb_4bit_compute_dtype=torch.bfloat16
67
  )
68
 
69
+ loaded_quantized = False
70
  try:
71
+ print("[Overthinker] Loading tokenizer...")
72
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
73
+ print("[Overthinker] Loading model with 4-bit quantization...")
74
  model = AutoModelForCausalLM.from_pretrained(
75
  MODEL_NAME,
76
+ config=config,
77
+ quantization_config=bnb_config,
78
+ device_map="auto",
79
  trust_remote_code=True,
80
+ torch_dtype=torch.bfloat16
81
  )
82
  print(f"[Overthinker] Model loaded in 4-bit quantization on device: {model.device}")
83
  loaded_quantized = True
84
  except Exception as e:
85
  print(f"[Overthinker] 4-bit quantization failed: {e}")
86
  print("[Overthinker] Falling back to BF16 (no quantization) - model is only ~8GB.")
87
+ if hasattr(config, 'use_mamba_kernels'):
88
+ config.use_mamba_kernels = False
89
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
90
  model = AutoModelForCausalLM.from_pretrained(
91
  MODEL_NAME,
92
+ config=config,
93
  device_map="auto",
94
  trust_remote_code=True,
95
  torch_dtype=torch.bfloat16
 
129
  type TEXT NOT NULL,
130
  label TEXT NOT NULL,
131
  description TEXT DEFAULT '',
132
+ emoji TEXT DEFAULT '🔹',
133
  tips TEXT DEFAULT '[]',
134
  order_index INTEGER DEFAULT 0,
135
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
 
138
  root_id = str(uuid.uuid4())
139
  conn.execute(
140
  "INSERT INTO nodes (id, parent_id, type, label, description, emoji) VALUES (?, ?, ?, ?, ?, ?)",
141
+ (root_id, None, "root", "What decision do you want to explore?", "", "🌳")
142
  )
143
  conn.commit()
144
  conn.close()
 
182
  return result
183
 
184
  def add_node_db(session_id: str, parent_id: str, node_type: str, label: str,
185
+ description: str = "", emoji: str = "🔹",
186
  tips: list = None, order_index: int = 0) -> Dict:
187
  node_id = str(uuid.uuid4())
188
  tips_json = json.dumps(tips or [])
 
239
  parts.append(f"[INPUT] {label}")
240
  elif t == "outcome":
241
  parts.append(f"[OUTCOME] {label}")
242
+ return " → ".join(parts)
243
 
244
  def get_root_node(session_id: str) -> Optional[Dict]:
245
  db_path = get_db_path(session_id)
 
398
  return None
399
 
400
  # ---------------------------------------------------------------------------
401
+ # Routes
402
  # ---------------------------------------------------------------------------
403
 
404
  @app.get("/")
 
432
  raise HTTPException(status_code=500, detail="Failed to generate root node. Please check model availability.")
433
  label = parsed.get('label', f'Overthinking: {decision[:40]}')
434
  description = parsed.get('description', f'You are overthinking: {decision}')
435
+ emoji = parsed.get('emoji', '🌳')
436
  tips = parsed.get('tips', ['Start by exploring options.'])
437
  update_root_db(session_id, label, description)
438
  db_path = get_db_path(session_id)
 
492
  for i, child in enumerate(children_data):
493
  label = child.get('label', 'Unknown')
494
  description = child.get('description', '')
495
+ emoji = child.get('emoji', '🔹')
496
  tips = child.get('tips', [f'Consider this {next_type}.'])
497
  existing = get_children_db(session_id, node_id)
498
  existing_labels = [c['label'] for c in existing]
 
533
  for i, child in enumerate(children_data):
534
  label = child.get('label', 'Unknown')
535
  description = child.get('description', '')
536
+ emoji = child.get('emoji', '🔹')
537
  tips = child.get('tips', [f'Additional {next_type}.'])
538
  existing = get_children_db(session_id, node_id)
539
  existing_labels = [c['label'] for c in existing]
 
610
  if not session_id or not node_id:
611
  raise HTTPException(status_code=400, detail="Missing session_id or node_id")
612
  path = get_path_db(session_id, node_id)
613
+ md = '# 🧠 Overthinker — Decision Path\n\n'
614
  for i, node in enumerate(path):
615
  indent = ' ' * i
616
+ emoji = {'root': '🌳', 'input': '🧠', 'outcome': '📊'}.get(node.get('type', ''), '📌')
617
  md += f'{indent}{emoji} **{node.get("label", "")}**\n'
618
  if node.get('description'):
619
  md += f'{indent} > {node.get("description", "")}\n'
620
  if node.get('tips') and len(node['tips']) > 0:
621
+ md += f'{indent} > 💡 {node["tips"][0]}\n'
622
  md += '\n'
623
  return PlainTextResponse(content=md, status_code=200)
624
 
 
650
  # Launch
651
  # ---------------------------------------------------------------------------
652
  if __name__ == "__main__":
653
+ print(f"🧠 Overthinker — Local 4B Quantized Edition on port {PORT}")
654
+ print(f"🤖 Model: {MODEL_NAME}")
655
+ print("🔋 Mamba kernels: Disabled (using PyTorch fallback - no mamba-ssm/causal-conv1d needed)")
656
  if loaded_quantized:
657
+ print("💾 Quantization: 4-bit NF4 (BitsAndBytes)")
658
  else:
659
+ print("💾 Quantization: None (BF16 fallback - fits in 16GB VRAM)")
660
+ print(f"🌐 Open http://localhost:{PORT} in your browser")
661
  if not HF_TOKEN or not HF_DATASET_REPO:
662
+ print("⚠️ No HF_TOKEN or HF_DATASET_REPO set. Upload will fail.")
663
  app.launch(
664
  server_port=PORT,
665
  show_error=True,
666
  share=False
667
+ )