xiezhe22 committed on
Commit
a9a6d6c
·
1 Parent(s): 38ecb1b

Add 8b model

Browse files
Files changed (2) hide show
  1. app.py +77 -15
  2. app_legacy.py +0 -53
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
3
  import pandas as pd
4
  import numpy as np
5
  import torch
6
- import subprocess
7
  from threading import Thread
8
  from transformers import (
9
  AutoModelForCausalLM,
@@ -13,21 +12,62 @@ from transformers import (
13
  )
14
 
15
  # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
16
- MODEL_NAME = "bytedance-research/ChatTS-14B"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- tokenizer = AutoTokenizer.from_pretrained(
19
- MODEL_NAME, trust_remote_code=True
20
- )
21
- processor = AutoProcessor.from_pretrained(
22
- MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
23
- )
24
- model = AutoModelForCausalLM.from_pretrained(
25
- MODEL_NAME,
26
- trust_remote_code=True,
27
- device_map="auto",
28
- torch_dtype=torch.float16
29
- )
30
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
33
 
@@ -290,6 +330,21 @@ with gr.Blocks(title="ChatTS Demo") as demo:
290
 
291
  with gr.Row():
292
  with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  upload = gr.File(
294
  label="Upload CSV File",
295
  file_types=[".csv"],
@@ -355,5 +410,12 @@ with gr.Blocks(title="ChatTS Demo") as demo:
355
  outputs=[text_out]
356
  )
357
 
 
 
 
 
 
 
 
358
  if __name__ == '__main__':
359
  demo.launch()
 
3
  import pandas as pd
4
  import numpy as np
5
  import torch
 
6
  from threading import Thread
7
  from transformers import (
8
  AutoModelForCausalLM,
 
12
  )
13
 
14
  # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
15
+ # Default to 8B but keep both variants resident on the GPU.
16
+ DEFAULT_MODEL_NAME = "bytedance-research/ChatTS-8B"
17
+ AVAILABLE_MODEL_NAMES = [
18
+ "bytedance-research/ChatTS-8B",
19
+ "bytedance-research/ChatTS-14B"
20
+ ]
21
+
22
+ MODEL_REGISTRY = {}
23
+
24
+ for name in AVAILABLE_MODEL_NAMES:
25
+ print(f"Loading model into memory: {name}")
26
+ tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
27
+ proc = AutoProcessor.from_pretrained(name, trust_remote_code=True, tokenizer=tok)
28
+ mdl = AutoModelForCausalLM.from_pretrained(
29
+ name,
30
+ trust_remote_code=True,
31
+ device_map="auto",
32
+ torch_dtype=torch.float16
33
+ )
34
+ mdl.eval()
35
+ MODEL_REGISTRY[name] = {
36
+ "tokenizer": tok,
37
+ "processor": proc,
38
+ "model": mdl
39
+ }
40
 
41
+ CURRENT_MODEL_NAME = DEFAULT_MODEL_NAME
42
+
43
+ tokenizer = MODEL_REGISTRY[CURRENT_MODEL_NAME]["tokenizer"]
44
+ processor = MODEL_REGISTRY[CURRENT_MODEL_NAME]["processor"]
45
+ model = MODEL_REGISTRY[CURRENT_MODEL_NAME]["model"]
46
+
47
+
48
+ def load_model_by_name(name: str):
49
+ """Activate the preloaded model by name without reloading weights."""
50
+ global tokenizer, processor, model, CURRENT_MODEL_NAME
51
+
52
+ if name not in MODEL_REGISTRY:
53
+ return f"Model not available: {name}"
54
+
55
+ if name == CURRENT_MODEL_NAME:
56
+ return f"Model already selected: {name}"
57
+
58
+ CURRENT_MODEL_NAME = name
59
+ tokenizer = MODEL_REGISTRY[name]["tokenizer"]
60
+ processor = MODEL_REGISTRY[name]["processor"]
61
+ model = MODEL_REGISTRY[name]["model"]
62
+ model.eval()
63
+
64
+ print(f"Activated model: {name}")
65
+ return f"Active model: {name}"
66
+
67
+
68
+ def switch_model(selected_model_name: str):
69
+ """Wrapper for Gradio to switch models; returns status text."""
70
+ return load_model_by_name(selected_model_name)
71
 
72
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
73
 
 
330
 
331
  with gr.Row():
332
  with gr.Column(scale=1):
333
+ # Model selection UI
334
+ model_radio = gr.Radio(
335
+ choices=["bytedance-research/ChatTS-8B", "bytedance-research/ChatTS-14B"],
336
+ value=CURRENT_MODEL_NAME,
337
+ label="Model Version"
338
+ )
339
+
340
+ model_btn = gr.Button("Load Model")
341
+
342
+ model_status = gr.Textbox(
343
+ label="Model Status",
344
+ value=f"Models in memory: {', '.join(AVAILABLE_MODEL_NAMES)}. Active: {CURRENT_MODEL_NAME}",
345
+ interactive=False
346
+ )
347
+
348
  upload = gr.File(
349
  label="Upload CSV File",
350
  file_types=[".csv"],
 
410
  outputs=[text_out]
411
  )
412
 
413
+ # Wire model loading button
414
+ model_btn.click(
415
+ fn=switch_model,
416
+ inputs=[model_radio],
417
+ outputs=[model_status]
418
+ )
419
+
420
  if __name__ == '__main__':
421
  demo.launch()
app_legacy.py DELETED
@@ -1,53 +0,0 @@
1
- import spaces # for ZeroGPU support
2
- import gradio as gr
3
- import pandas as pd
4
- import numpy as np
5
- import torch
6
- import subprocess
7
- from transformers import (
8
- AutoModelForCausalLM,
9
- AutoTokenizer,
10
- AutoProcessor,
11
- )
12
-
13
- # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
14
- MODEL_NAME = "bytedance-research/ChatTS-14B"
15
-
16
- tokenizer = AutoTokenizer.from_pretrained(
17
- MODEL_NAME, trust_remote_code=True
18
- )
19
- processor = AutoProcessor.from_pretrained(
20
- MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
21
- )
22
- model = AutoModelForCausalLM.from_pretrained(
23
- MODEL_NAME,
24
- trust_remote_code=True,
25
- device_map="auto",
26
- torch_dtype=torch.float16
27
- )
28
- model.eval()
29
-
30
-
31
- # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
32
- @spaces.GPU
33
- def generate_text(prompt):
34
- inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
35
- outputs = model.generate(
36
- **inputs,
37
- max_new_tokens=512,
38
- do_sample=True,
39
- temperature=0.2,
40
- top_p=0.9
41
- )
42
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
43
-
44
- demo = gr.Interface(
45
- fn=generate_text,
46
- inputs=gr.Textbox(lines=2, label="Prompt"),
47
- outputs=gr.Textbox(lines=6, label="Generated Text")
48
- )
49
-
50
-
51
- if __name__ == '__main__':
52
- subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
53
- demo.launch()