lea97338 commited on
Commit
68c7eda
·
verified ·
1 Parent(s): da6eaaf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -50
app.py CHANGED
@@ -1,60 +1,149 @@
 
 
 
 
 
 
 
 
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import tempfile, os
5
 
6
- # Qwen 2.5 1.5B Instruct
7
- REPO_ID = "Qwen/Qwen2.5-1.5B-Instruct"
8
 
9
- device = "cpu"
10
- dtype = torch.float32
 
 
 
 
 
 
11
 
12
- # Charger UNIQUEMENT le CausalLM
13
- tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
14
- model = AutoModelForCausalLM.from_pretrained(
15
- REPO_ID,
16
- torch_dtype=dtype,
17
- low_cpu_mem_usage=True,
18
  )
19
- model.to(device)
20
- model.eval()
21
 
22
- # Projection FLUX2 Klein : 4096 → 7680
23
- project_out = torch.nn.Linear(4096, 7680, bias=False)
 
 
24
 
25
- @torch.no_grad()
26
- def encode_text(prompt: str):
27
- if not prompt.strip():
28
- raise gr.Error("Prompt vide")
29
-
30
- # Tokenisation simple
31
- inputs = tokenizer(
32
- prompt,
33
- return_tensors="pt",
34
- truncation=True,
35
- max_length=256
36
- ).to(device)
37
-
38
- # Sortie Qwen : hidden_states = [1, seq_len, 4096]
39
- outputs = model.model(**inputs, output_hidden_states=True)
40
- #hidden = outputs.hidden_states[-1]
41
-
42
- # Projection FLUX2 Klein
43
- #projected = project_out(hidden) # [1, seq_len, 7680]
44
-
45
- # Sauvegarde
46
- fd, path = tempfile.mkstemp(suffix=".pt")
47
- os.close(fd)
48
- torch.save(outputs, path)
49
-
50
- return path
51
-
52
- demo = gr.Interface(
53
- fn=encode_text,
54
- inputs=gr.Textbox(label="Prompt"),
55
- outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
56
- title="FLUX2 Klein — Encoder Qwen2.5 1.5B",
57
- description="Encode le texte avec Qwen2.5 1.5B + projection FLUX2 (4096→7680).",
58
  )
59
 
60
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import io
4
+ import time
5
+ import tempfile
6
+ import logging
7
+ import spaces
8
+
9
  import torch
10
  import gradio as gr
11
+ from transformers import Mistral3ForConditionalGeneration, AutoProcessor
 
12
 
13
+ from mistral_text_encoding_core import encode_prompt
 
14
 
15
+ # ------------------------------------------------------
16
+ # Logging
17
+ # ------------------------------------------------------
18
+ logging.basicConfig(
19
+ level=os.getenv("LOG_LEVEL", "INFO"),
20
+ format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
21
+ )
22
+ logger = logging.getLogger("mistral-text-encoding-gradio")
23
 
24
+ # ------------------------------------------------------
25
+ # Config
26
+ # ------------------------------------------------------
27
+ TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "mistralai/Mistral-Small-3.2-24B-Instruct-2506")
28
+ TOKENIZER_ID = os.getenv(
29
+ "TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
30
  )
31
+ DTYPE = torch.bfloat16
 
32
 
33
+ # ------------------------------------------------------
34
+ # Global model references
35
+ # ------------------------------------------------------
36
+ logger.info("Loading models...")
37
 
38
+ t0 = time.time()
39
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
40
+ TEXT_ENCODER_ID,
41
+ dtype=DTYPE,
42
+ ).to("cpu")
43
+ logger.info(
44
+ "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
45
+ time.time() - t0,
46
+ text_encoder.dtype,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
 
49
+ t1 = time.time()
50
+ tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
51
+ logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
52
+
53
+ torch.set_grad_enabled(False)
54
+
55
+
56
+ def get_vram_info():
57
+ """Get current VRAM usage info."""
58
+ if torch.cuda.is_available():
59
+ return {
60
+ "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
61
+ "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
62
+ "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
63
+ }
64
+ return {"vram": "CUDA not available"}
65
+
66
+ @spaces.GPU()
67
+ def encode_text(prompt: str):
68
+ """Encode text and return a downloadable pytorch file."""
69
+ global text_encoder, tokenizer
70
+
71
+ if text_encoder is None or tokenizer is None:
72
+ return None, "Model not loaded"
73
+
74
+ t0 = time.time()
75
+
76
+ # Handle multiple prompts (one per line)
77
+ prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
78
+ if not prompts:
79
+ return None, "Please enter at least one prompt"
80
+
81
+ num_prompts = len(prompts)
82
+ prompt_input = prompts[0] if num_prompts == 1 else prompts
83
+
84
+ logger.info("Encoding %d prompt(s)", num_prompts)
85
+
86
+ prompt_embeds, text_ids = encode_prompt(
87
+ text_encoder=text_encoder,
88
+ tokenizer=tokenizer,
89
+ prompt=prompt_input,
90
+ )
91
+
92
+ duration = (time.time() - t0) * 1000.0
93
+
94
+ logger.info(
95
+ "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
96
+ duration,
97
+ tuple(prompt_embeds.shape),
98
+ tuple(text_ids.shape),
99
+ )
100
+
101
+ # Save tensor to a temporary file
102
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
103
+ torch.save(prompt_embeds.cpu(), temp_file.name)
104
+
105
+ # Clean up GPU tensors
106
+ del prompt_embeds, text_ids
107
+ gc.collect()
108
+ if torch.cuda.is_available():
109
+ torch.cuda.empty_cache()
110
+
111
+ vram = get_vram_info()
112
+ status = (
113
+ f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
114
+ f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
115
+ f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
116
+ )
117
+
118
+ return temp_file.name, status
119
+
120
+
121
+ # ------------------------------------------------------
122
+ # Gradio Interface
123
+ # ------------------------------------------------------
124
+ with gr.Blocks(title="Mistral Text Encoder") as demo:
125
+ gr.Markdown("# Mistral Text Encoder")
126
+ gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")
127
+
128
+ with gr.Row():
129
+ with gr.Column():
130
+ prompt_input = gr.Textbox(
131
+ label="Prompt(s)",
132
+ placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
133
+ lines=5,
134
+ )
135
+ encode_btn = gr.Button("Encode", variant="primary")
136
+
137
+ with gr.Column():
138
+ output_file = gr.File(label="Download Embeddings (.pt)")
139
+ status_output = gr.Textbox(label="Status", interactive=False)
140
+
141
+ encode_btn.click(
142
+ fn=encode_text,
143
+ inputs=[prompt_input],
144
+ outputs=[output_file, status_output],
145
+ )
146
+
147
+
148
+ if __name__ == "__main__":
149
+ demo.launch(server_name="0.0.0.0", server_port=7860)