ariG23498 (HF Staff) committed
Commit 719c89d · verified · Parent: 41ad958

Create main.py

Files changed (1): main.py (+143, −0)
main.py ADDED
@@ -0,0 +1,143 @@
from fastapi import FastAPI, Query
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
from typing import Union, Optional, List
import torch

app = FastAPI()

device = "cuda"
model_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
# Note: recent transformers releases accept `dtype=`; older ones expect `torch_dtype=`.
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)

# The 3.2 model weights are paired with the 3.1 processor checkpoint for tokenization.
processor_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
tokenizer = AutoProcessor.from_pretrained(processor_id)

def format_text_input(prompts: List[str], system_message: Optional[str] = None):
    # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
    # when truncation is enabled. The processor counts [IMG] tokens and fails
    # if the count changes after truncation.
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]
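
# For example, format_text_input(["a cat on a table"], "Describe the scene.")
# yields a single two-message conversation:
# [[{"role": "system", "content": [{"type": "text", "text": "Describe the scene."}]},
#   {"role": "user", "content": [{"type": "text", "text": "a cat on a table"}]}]]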


def _get_mistral_3_small_prompt_embeds(
    text_encoder: Mistral3ForConditionalGeneration,
    tokenizer: AutoProcessor,
    prompt: Union[str, List[str]],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    max_sequence_length: int = 512,
    system_message: str = (
        "You are an AI that reasons about image descriptions. You give structured "
        "responses focusing on object relationships, object attribution and "
        "actions without speculation."
    ),
    hidden_states_layers: List[int] = (10, 20, 30),
):
    dtype = text_encoder.dtype if dtype is None else dtype
    device = text_encoder.device if device is None else device

    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Format input messages
    messages_batch = format_text_input(prompts=prompt, system_message=system_message)

    # Process all messages at once
    inputs = tokenizer.apply_chat_template(
        messages_batch,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_sequence_length,
    )

    # Move to device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Forward pass through the model
    output = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )

    # Only use outputs from intermediate layers and stack them
    out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
    out = out.to(dtype=dtype, device=device)

    # (B, num_layers, L, D) -> (B, L, num_layers * D)
    batch_size, num_channels, seq_len, hidden_dim = out.shape
    prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

    return prompt_embeds
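
# Shape sketch, assuming the 24B checkpoint's hidden size of 5120: with
# max_sequence_length=512 and the three layers above, prompt_embeds has shape
# (batch_size, 512, 3 * 5120) = (batch_size, 512, 15360).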

def _prepare_text_ids(
    x: torch.Tensor,  # (B, L, D) or (L, D)
    t_coord: Optional[torch.Tensor] = None,
):
    # Accept the unbatched (L, D) case documented above.
    if x.ndim == 2:
        x = x.unsqueeze(0)

    B, L, _ = x.shape
    out_ids = []

    for i in range(B):
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)

        # One (t, h, w, l) coordinate row per token position.
        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)

    return torch.stack(out_ids)
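
# For example, _prepare_text_ids(torch.zeros(2, 512, 15360)) returns a LongTensor
# of shape (2, 512, 4) with columns (t, h, w, l).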

def encode_prompt(
    prompt: Optional[Union[str, List[str]]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 512,
):
    if prompt is None:
        prompt = ""

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_embeds = _get_mistral_3_small_prompt_embeds(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=max_sequence_length,
        )

    # Duplicate the embeddings once per requested image, then flatten the
    # batch and repeat dimensions together.
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    text_ids = _prepare_text_ids(prompt_embeds)
    text_ids = text_ids.to(device)
    return prompt_embeds, text_ids

@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}

@app.get("/predict")
def predict(prompt: str = Query(...)):
    prompt_embeds, text_ids = encode_prompt(
        prompt=prompt,
        device=device,
    )
    # torch.Tensor is not JSON-serializable, so cast to plain nested lists
    # before returning (bfloat16 is upcast to float32 first).
    return {
        "prompt_embeds": prompt_embeds.to(torch.float32).cpu().tolist(),
        "text_ids": text_ids.cpu().tolist(),
    }
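
# Minimal smoke test, assuming a CUDA device and that both checkpoints are
# available locally; in deployment the app would instead be served with e.g.
# `uvicorn main:app --host 0.0.0.0 --port 8000` and queried via
# GET /predict?prompt=... as defined above.
if __name__ == "__main__":
    embeds, ids = encode_prompt(prompt="A red cube next to a blue sphere", device=device)
    print(embeds.shape, ids.shape)  # e.g. (1, 512, 15360) and (1, 512, 4)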