lea97338 committed on
Commit
ea08658
·
verified ·
1 Parent(s): eb8610d

Update dd.py

Browse files
Files changed (1) hide show
  1. dd.py +23 -64
dd.py CHANGED
@@ -1,84 +1,45 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
- from typing import Union, List, Optional
3
  import torch
 
 
4
 
5
 
6
def format_text_input(prompts: List[str], system_message: Optional[str] = None):
    """Wrap each prompt in a chat-style conversation.

    Args:
        prompts: User prompts; one conversation is produced per entry.
        system_message: Text for the leading system turn. When ``None`` the
            system turn is left out, rather than being emitted with a
            ``None`` text payload.

    Returns:
        A list with one conversation per prompt. Each conversation is a list
        of ``{"role": ..., "content": [{"type": "text", "text": ...}]}``
        dicts: an optional system turn followed by the user turn.
    """
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_prompts = [prompt.replace("[IMG]", "") for prompt in prompts]

    conversations = []
    for prompt in cleaned_prompts:
        messages = []
        if system_message is not None:
            messages.append(
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_message}],
                }
            )
        messages.append(
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        )
        conversations.append(messages)
    return conversations
22
-
23
-
24
def get_mistral_3_small_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    system_message: str = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation.""",
    hidden_states_layers: List[int] = (10, 20, 30),
):
    """Encode prompts by concatenating several hidden layers of the encoder.

    Args:
        text_encoder: Model called with ``output_hidden_states=True``; its
            ``device`` and ``dtype`` attributes are used for the result.
        tokenizer: Object providing ``apply_chat_template``.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: ``max_length`` handed to the chat template.
        system_message: System turn placed before every prompt.
        hidden_states_layers: Indices into ``output.hidden_states`` whose
            activations are concatenated along the feature axis.

    Returns:
        A tensor reshaped to
        ``(batch_size, seq_len, len(hidden_states_layers) * hidden_dim)``.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Format input messages
    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

    # Process all messages at once. ``padding`` and ``truncation`` must be
    # enabled: without them ``max_length`` is not enforced and prompts of
    # different lengths cannot be packed into a single "pt" tensor.
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
    )

    # Move to device
    input_ids = inputs["input_ids"].to(text_encoder.device)
    attention_mask = inputs["attention_mask"].to(text_encoder.device)

    # Forward pass through the model
    with torch.inference_mode():
        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

    # Only use outputs from intermediate layers and stack them on a new axis 1
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=text_encoder.dtype, device=text_encoder.device)

    # Fold the stacked-layer axis into the feature axis
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(
        batch_size, seq_len, num_channels * hidden_dim
    )

    return prompt_embeds
71
 
72
 
73
- def prepare_text_ids(
74
- x: torch.Tensor, # (B, L, D) or (L, D)
75
- t_coord: Optional[torch.Tensor] = None,
76
- ):
77
  B, L, _ = x.shape
78
  out_ids = []
79
 
80
  for i in range(B):
81
- t = torch.arange(1) if t_coord is None else t_coord[i]
82
  h = torch.arange(1)
83
  w = torch.arange(1)
84
  l = torch.arange(L)
@@ -97,23 +58,21 @@ def encode_prompt(
97
  prompt_embeds: Optional[torch.Tensor] = None,
98
  max_sequence_length: int = 512,
99
  ):
100
- if prompt is None:
101
- prompt = ""
102
-
103
- prompt = [prompt] if isinstance(prompt, str) else prompt
104
-
105
  if prompt_embeds is None:
106
- prompt_embeds = get_mistral_3_small_prompt_embeds(
107
  text_encoder=text_encoder,
108
  tokenizer=tokenizer,
109
  prompt=prompt,
110
  max_sequence_length=max_sequence_length,
111
  )
112
 
113
- batch_size, seq_len, _ = prompt_embeds.shape
 
 
114
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
115
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
116
 
117
  text_ids = prepare_text_ids(prompt_embeds)
118
  text_ids = text_ids.to(text_encoder.device)
 
119
  return prompt_embeds, text_ids
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from typing import List, Union, Optional
4
 
5
 
6
def get_qwen_prompt_embeds(
    text_encoder: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 512,
    hidden_layer: int = -1,  # last layer
):
    """Return one hidden-state layer of the text encoder for the prompts.

    Args:
        text_encoder: Model called with ``output_hidden_states=True``; its
            ``device`` attribute decides where the tokenized inputs go.
        tokenizer: Callable tokenizer, invoked directly on the prompt list.
        prompt: A single prompt or a list of prompts.
        max_sequence_length: ``max_length`` used for truncation.
        hidden_layer: Index into ``hidden_states``; ``-1`` selects the last.

    Returns:
        The selected entry of ``hidden_states``, returned as-is (noted as
        ``[B, L, D]`` by the original author).
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    # Plain tokenization (no chat template)
    encoded = tokenizer(
        prompt,
        max_length=max_sequence_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded = encoded.to(text_encoder.device)

    with torch.inference_mode():
        model_out = text_encoder(
            **encoded,
            use_cache=False,
            output_hidden_states=True,
        )

    # No concatenation and no reshape: the chosen layer is returned untouched.
    return model_out.hidden_states[hidden_layer]
35
 
36
 
37
+ def prepare_text_ids(x: torch.Tensor):
 
 
 
38
  B, L, _ = x.shape
39
  out_ids = []
40
 
41
  for i in range(B):
42
+ t = torch.arange(1)
43
  h = torch.arange(1)
44
  w = torch.arange(1)
45
  l = torch.arange(L)
 
58
  prompt_embeds: Optional[torch.Tensor] = None,
59
  max_sequence_length: int = 512,
60
  ):
 
 
 
 
 
61
  if prompt_embeds is None:
62
+ prompt_embeds = get_qwen_prompt_embeds(
63
  text_encoder=text_encoder,
64
  tokenizer=tokenizer,
65
  prompt=prompt,
66
  max_sequence_length=max_sequence_length,
67
  )
68
 
69
+ B, L, D = prompt_embeds.shape
70
+
71
+ # répéter pour plusieurs images
72
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
73
+ prompt_embeds = prompt_embeds.view(B * num_images_per_prompt, L, D)
74
 
75
  text_ids = prepare_text_ids(prompt_embeds)
76
  text_ids = text_ids.to(text_encoder.device)
77
+
78
  return prompt_embeds, text_ids