HayatoHongoEveryonesAI committed
Commit 6b1d95e · 1 Parent(s): ebce892
Files changed (3):
  1. app.py +32 -113
  2. inference.py +0 -75
  3. vlm_inference.py +40 -21
app.py CHANGED
@@ -2,57 +2,22 @@
 import gradio as gr
 import spaces
 import torch
-import tiktoken
-from huggingface_hub import hf_hub_download
 from PIL import Image
 
-from model import GPT, ModelConfig
-from inference import generate_stream
 from vlm_inference import (
-    build_vlm_model,
+    load_vlm_model,
    vlm_infer_stream,
    image_processor,
 )
 
 # =====================================================
-# Text-only LLM load (CPU)
+# Load model on CPU (ZeroGPU)
 # =====================================================
-TEXT_REPO_ID = "HayatoHongo/everyoneschat-checkpoints"
-TEXT_FILENAME = "model_sft.pt"
-
-text_ckpt_path = hf_hub_download(
-    repo_id=TEXT_REPO_ID,
-    filename=TEXT_FILENAME,
-)
-
-text_state_dict = torch.load(text_ckpt_path, map_location="cpu")
-
-text_config = ModelConfig(
-    embedding_dim=1280,
-    hidden_dim=5120,
-    num_attention_heads=10,
-    layer_count=20,
-    max_sequence_length=2048,
-    rope_theta=1_000_000.0,
-    vocab_size=50257,
-)
-
-text_model = GPT(text_config)
-text_model.load_state_dict(text_state_dict)
-text_model.eval()
-
-tokenizer = tiktoken.get_encoding("gpt2")
-EOS_ID = 50256
-
-
-# =====================================================
-# Vision-Language Model load (CPU)
-# =====================================================
-vlm_model = build_vlm_model()  # CPU load, frozen
+model = load_vlm_model()  # CPU load, eval
 
 
 # =====================================================
-# Router (GPU only here)
+# GPU inference (VLM only)
 # =====================================================
 @spaces.GPU
 def chat_fn(
@@ -63,80 +28,37 @@ def chat_fn(
    top_p,
    top_k,
 ):
-    device = "cuda"
-
-    # ==============================
-    # Text-only route
-    # ==============================
    if image is None:
-        model_gpu = text_model.to(device)
-
-        # reset KV cache
-        for block in model_gpu.blocks:
-            block.multihead_attention.reset_cache()
-
-        prompt = (
-            "<user>\n"
-            f"{message}"
-            "<assistant>\n"
-        )
+        return "Please upload an image."
 
-        input_ids = torch.tensor(
-            [tokenizer.encode(prompt, allowed_special="all")],
-            device=device
-        )
-
-        output = ""
-
-        with torch.no_grad(), torch.autocast(
-            device_type="cuda",
-            dtype=torch.bfloat16,
+    device = "cuda"
+    model_gpu = model.to(device)
+
+    image_tensor = image_processor(
+        images=image.convert("RGB"),
+        return_tensors="pt"
+    )["pixel_values"].to(device)
+
+    prompt = (
+        f"{message}"
+    )
+
+    def stream():
+        for chunk in vlm_infer_stream(
+            model=model_gpu,
+            image_tensor=image_tensor,
+            prompt=prompt,
+            max_new_tokens=256,
+            temperature=temperature,
+            top_p=top_p if top_p > 0 else None,
+            top_k=top_k if top_k > 0 else None,
        ):
-            for tid in generate_stream(
-                model_gpu,
-                input_ids,
-                max_new_tokens=256,
-                temperature=temperature,
-                top_p=top_p if top_p > 0 else None,
-                top_k=top_k if top_k > 0 else None,
-            ):
-                if tid == EOS_ID:
-                    break
-                output += tokenizer.decode([tid])
+            yield chunk
 
    model_gpu.to("cpu")
    torch.cuda.empty_cache()
-        return output
-
-    # ==============================
-    # Vision route
-    # ==============================
-    else:
-        model_gpu = vlm_model.to(device)
-
-        image_tensor = image_processor(
-            images=image.convert("RGB"),
-            return_tensors="pt"
-        )["pixel_values"].to(device)
-
-        prompt = ({message})
-
-        def stream():
-            for chunk in vlm_infer_stream(
-                model=model_gpu,
-                image_tensor=image_tensor,
-                prompt=prompt,
-                max_new_tokens=256,
-                temperature=temperature,
-                top_p=top_p if top_p > 0 else None,
-                top_k=top_k if top_k > 0 else None,
-            ):
-                yield chunk
-
-        model_gpu.to("cpu")
-        torch.cuda.empty_cache()
 
-        return stream()
+    return stream()
 
 
 # =====================================================
@@ -145,14 +67,11 @@ def chat_fn(
 demo = gr.ChatInterface(
    fn=chat_fn,
    multimodal=True,
-    title="EveryonesGPT (Text + Vision)",
-    description=(
-        "- Text only → fast LLM\n"
-        "- Image + Text → CLIP-VLM\n"
-    ),
+    title="EveryonesGPT Vision (CLIP)",
+    description="Vision-only VLM demo (CLIP ViT-L/14)",
    additional_inputs=[
-        gr.Image(type="pil", label="Image (optional)"),
-        gr.Slider(0.1, 2.0, value=0.7, step=0.05, label="Temperature"),
+        gr.Image(type="pil", label="Image"),
+        gr.Slider(0.1, 2.0, value=0.5, step=0.05, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(0, 200, value=0, step=1, label="Top-k"),
    ],
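The new `chat_fn` keeps the model on CPU between requests and borrows the GPU only inside the `@spaces.GPU` call, moving the weights back and emptying the CUDA cache afterwards. A minimal, self-contained sketch of that shuttle pattern (a toy `nn.Linear` stands in for the VLM, and the `@spaces.GPU` decorator is omitted so the sketch runs anywhere):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # lives on CPU between requests

def run_once(x: torch.Tensor) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_gpu = model.to(device)      # borrow the accelerator for this call
    try:
        with torch.no_grad():
            return model_gpu(x.to(device)).cpu()
    finally:
        model_gpu.to("cpu")           # free VRAM for the next request
        if device == "cuda":
            torch.cuda.empty_cache()

print(run_once(torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```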
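Assuming this repo's `model.py` and `vlm_inference.py` are importable and a CUDA device is available, the new vision-only path can be smoke-tested outside Gradio roughly like this (untested sketch; `cat.jpg` is a placeholder path):

```python
from PIL import Image
from vlm_inference import load_vlm_model, vlm_infer_stream, image_processor

model = load_vlm_model().to("cuda")
pixel_values = image_processor(
    images=Image.open("cat.jpg").convert("RGB"),  # placeholder image
    return_tensors="pt",
)["pixel_values"].to("cuda")

# Stream decoded text chunks, mirroring the arguments chat_fn passes.
for chunk in vlm_infer_stream(
    model=model,
    image_tensor=pixel_values,
    prompt="Describe the image.",
    max_new_tokens=64,
    temperature=0.7,
    top_p=0.9,
):
    print(chunk, end="", flush=True)
```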
inference.py DELETED
@@ -1,75 +0,0 @@
-# inference.py
-import torch
-import torch.nn.functional as F
-
-def generate_stream(
-    model,
-    input_ids,
-    max_new_tokens,
-    temperature,
-    top_p=None,
-    top_k=None,
-):
-    """
-    Streaming generation (batch size fixed at 1)
-    - same logic as GPT.generate
-    - uses the KV cache
-    - supports top-k / top-p
-    """
-    model.eval()
-    next_token = None
-
-    with torch.no_grad():
-        for i in range(max_new_tokens):
-
-            # ===== forward =====
-            if i == 0:
-                logits, _ = model(input_ids, None, use_cache=True)
-            else:
-                logits, _ = model(next_token, None, use_cache=True)
-
-            # last token logits
-            last_logits = logits[:, -1, :] / temperature  # [1, vocab]
-
-            # ===== top-k =====
-            if top_k is not None:
-                top_k = min(top_k, last_logits.size(-1))
-                values, _ = torch.topk(last_logits, top_k)
-                min_value = values[:, -1].unsqueeze(-1)
-                last_logits = torch.where(
-                    last_logits < min_value,
-                    torch.full_like(last_logits, float("-inf")),
-                    last_logits,
-                )
-
-            # ===== top-p (nucleus) =====
-            if top_p is not None:
-                sorted_logits, sorted_indices = torch.sort(
-                    last_logits, descending=True
-                )
-                sorted_probs = F.softmax(sorted_logits, dim=-1)
-                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-                sorted_mask = cumulative_probs > top_p
-                # ★ important: the clone() here matters
-                sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
-                sorted_mask[..., 0] = False
-
-                sorted_logits = torch.where(
-                    sorted_mask,
-                    torch.full_like(sorted_logits, float("-inf")),
-                    sorted_logits,
-                )
-
-                last_logits = torch.zeros_like(last_logits).scatter(
-                    -1, sorted_indices, sorted_logits
-                )
-
-            # ===== sample =====
-            probs = F.softmax(last_logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
-
-            yield int(next_token.item())
-
-            # concatenate for the next step
-            input_ids = torch.cat([input_ids, next_token], dim=1)
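The deleted sampler's top-k filter is the standard one: every logit below the k-th largest is set to -inf before softmax, so only the k most likely tokens can be sampled. Isolated for clarity (values are made up):

```python
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
top_k = 2

values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
min_value = values[:, -1].unsqueeze(-1)   # k-th largest logit
filtered = torch.where(
    logits < min_value,
    torch.full_like(logits, float("-inf")),
    logits,
)
print(filtered)  # tensor([[2., -inf, 1., -inf]])
```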
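The ★-marked `clone()` in the deleted top-p branch is needed because `sorted_mask[..., 1:]` and `sorted_mask[..., :-1]` are overlapping views of the same storage, so the assignment could read values it has already overwritten; the shift-right itself keeps the first token that crosses the `top_p` threshold inside the nucleus. A standalone sketch with made-up probabilities:

```python
import torch

sorted_probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # descending order
mask = sorted_probs.cumsum(dim=-1) > 0.8               # [F, F, T, T]

# Shift right so the token that crosses the threshold stays sampleable;
# clone() avoids copying from a view the assignment is overwriting.
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False                                   # always keep the top token
print(mask)  # tensor([[False, False, False,  True]]) -> nucleus = top 3 tokens
```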
vlm_inference.py CHANGED
@@ -6,13 +6,13 @@ import tiktoken
 from huggingface_hub import hf_hub_download
 from transformers import CLIPVisionModel, CLIPImageProcessor
 
-from model import GPT, ModelConfig
+from model import GPT
 
 # =====================================================
 # Constants
 # =====================================================
-VISION_REPO_ID = "HayatoHongo/everyoneschat-checkpoints"
-VISION_FILENAME = "checkpoint_015000_vision_pretrained.pt"
+REPO_ID = "HayatoHongo/everyoneschat-checkpoints"
+FILENAME = "checkpoint_015000_vision_pretrained.pt"
 
 VISION_ENCODER = "openai/clip-vit-large-patch14"
 NUM_IMAGE_PATCHES = 256
@@ -24,6 +24,26 @@ tokenizer = tiktoken.get_encoding("gpt2")
 image_processor = CLIPImageProcessor.from_pretrained(VISION_ENCODER)
 
 
+# =====================================================
+# ModelConfig (same as Colab)
+# =====================================================
+from dataclasses import dataclass, fields
+
+@dataclass
+class ModelConfig:
+    input_sequence_length: int
+    max_sequence_length: int
+    embedding_dim: int
+    hidden_dim: int
+    num_attention_heads: int
+    layer_count: int
+    rope_theta: float
+    vocab_size: int
+    device_type: str
+    random_seed_value: int
+    autocast_dtype: torch.dtype
+
+
 # =====================================================
 # VLM wrapper
 # =====================================================
@@ -47,12 +67,13 @@ class VLM(nn.Module):
 
 
 # =====================================================
-# Build model (CPU)
+# Load model (CPU)
 # =====================================================
-def build_vlm_model():
+def load_vlm_model():
    ckpt_path = hf_hub_download(
-        repo_id=VISION_REPO_ID,
-        filename=VISION_FILENAME,
+        repo_id=REPO_ID,
+        filename=FILENAME,
+        repo_type="model"
    )
 
    checkpoint = torch.load(ckpt_path, map_location="cpu")
@@ -63,10 +84,9 @@ def build_vlm_model():
        torch, config_dict["autocast_dtype"].split(".")[-1]
    )
 
-    config = ModelConfig(**{
-        k: v for k, v in config_dict.items()
-        if k in ModelConfig.__annotations__
-    })
+    model_config_fields = {f.name for f in fields(ModelConfig)}
+    filtered = {k: v for k, v in config_dict.items() if k in model_config_fields}
+    config = ModelConfig(**filtered)
 
    llm = GPT(config)
    model = VLM(llm)
@@ -77,7 +97,7 @@
 
 
 # =====================================================
-# Inference helpers
+# Inference helpers (matches Colab)
 # =====================================================
 @torch.no_grad()
 def vlm_prefill(model, image_tensor, input_ids):
@@ -105,12 +125,12 @@ def vlm_next_token(model, input_ids, temperature, top_k, top_p):
    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature
 
    if top_k:
-        v, _ = torch.topk(logits, top_k)
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[:, -1:], -float("inf"), logits)
 
    if top_p:
        s_logits, s_idx = torch.sort(logits, descending=True)
-        probs = torch.softmax(s_logits, dim=-1)
+        probs = F.softmax(s_logits, dim=-1)
        cum = probs.cumsum(dim=-1)
        mask = cum > top_p
        mask[..., 1:] = mask[..., :-1].clone()
@@ -118,7 +138,7 @@
        s_logits[mask] = -float("inf")
        logits = torch.zeros_like(logits).scatter(-1, s_idx, s_logits)
 
-    probs = torch.softmax(logits, dim=-1)
+    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)
 
 
@@ -126,15 +146,15 @@ def vlm_infer_stream(
    model,
    image_tensor,
    prompt,
-    max_new_tokens,
-    temperature,
+    max_new_tokens=256,
+    temperature=0.7,
    top_k=None,
    top_p=None,
    stop_ids={50256},
 ):
    device = next(model.parameters()).device
-
    prompt_ids = tokenizer.encode(prompt, allowed_special="all")
+
    input_ids = (
        [PAD_TOKEN_ID] * NUM_IMAGE_PATCHES + prompt_ids
    )
@@ -145,11 +165,10 @@
 
    x = vlm_prefill(model, image_tensor, input_ids)
    logits = model.llm.vocab_projection(x)[:, -1, :] / temperature
-    probs = torch.softmax(logits, dim=-1)
+    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, 1)
 
-    acc = []
-    last = ""
+    acc, last = [], ""
 
    for _ in range(max_new_tokens):
        tid = int(next_token.item())
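The switch from `ModelConfig.__annotations__` to `dataclasses.fields()` performs the same key filtering through the dataclass API, dropping training-only keys the checkpoint config may carry before constructing the config. Roughly (toy config, made-up keys):

```python
from dataclasses import dataclass, fields

@dataclass
class ToyConfig:
    embedding_dim: int
    layer_count: int

# A checkpoint's config dict may carry extra training-only entries.
config_dict = {"embedding_dim": 1280, "layer_count": 20, "learning_rate": 3e-4}

allowed = {f.name for f in fields(ToyConfig)}
config = ToyConfig(**{k: v for k, v in config_dict.items() if k in allowed})
print(config)  # ToyConfig(embedding_dim=1280, layer_count=20)
```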
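For reference, `vlm_infer_stream` builds its input as `NUM_IMAGE_PATCHES` placeholder token ids followed by the encoded prompt, presumably so `vlm_prefill` can inject the CLIP ViT-L/14 patch features at those positions. A sketch of that layout (`PAD_TOKEN_ID = 50256` is an assumption here; the real constant is defined earlier in the file, outside this diff):

```python
import tiktoken

NUM_IMAGE_PATCHES = 256
PAD_TOKEN_ID = 50256  # assumed value; see the constants section of the file

tokenizer = tiktoken.get_encoding("gpt2")
prompt_ids = tokenizer.encode("Describe the image.", allowed_special="all")

# 256 slots reserved for image patch embeddings, then the text prompt.
input_ids = [PAD_TOKEN_ID] * NUM_IMAGE_PATCHES + prompt_ids
print(len(input_ids))  # 256 + number of prompt tokens
```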