lea97338 commited on
Commit
75020f3
·
verified ·
1 Parent(s): 769c06d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -50
app.py CHANGED
@@ -1,67 +1,58 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModel
4
- import tempfile
5
- import os
6
-
7
- # ============================
8
- # CONFIG
9
- # ============================
10
-
11
- # Tu peux changer ce modèle par un Mistral quand tu en trouves un adapté CPU
12
- # Exemple possible : "mistralai/Mistral-7B-v0.1" (très lourd pour 12 Go CPU)
13
- # Pour rester safe sur CPU, je mets un modèle plus léger par défaut :
14
- MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
15
 
 
16
  device = "cpu"
17
- dtype = torch.float32 # sur CPU, reste en float32 pour éviter les emmerdes
18
-
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
- model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=dtype)
21
- model.to(device)
22
- model.eval()
23
-
24
- # ============================
25
- # FONCTION D'ENCODAGE
26
- # ============================
27
 
28
- def encode_text(prompt: str):
29
- if not prompt or not prompt.strip():
30
- raise gr.Error("Le prompt ne peut pas être vide.")
 
 
 
31
 
32
- with torch.no_grad():
33
- inputs = tokenizer(
34
- prompt,
35
- return_tensors="pt",
36
- truncation=True,
37
- max_length=256
38
- ).to(device)
39
 
40
- outputs = model(**inputs)
 
 
 
 
41
 
42
- # pooling simple : moyenne sur la séquence
43
- last_hidden = outputs.last_hidden_state # [1, seq_len, hidden]
44
- emb = last_hidden.mean(dim=1).squeeze(0) # [hidden]
45
 
46
- # sauvegarde dans un fichier temporaire
47
- fd, path = tempfile.mkstemp(suffix=".pt")
48
- os.close(fd)
49
- torch.save(emb.cpu(), path)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # IMPORTANT : on renvoie le chemin du fichier
52
  return path
53
 
54
- # ============================
55
- # INTERFACE GRADIO
56
- # ============================
57
-
58
  demo = gr.Interface(
59
  fn=encode_text,
60
- inputs=gr.Textbox(label="Prompt", placeholder="Texte à encoder..."),
61
- outputs=gr.File(label="Fichier .pt des embeddings"),
62
- title="Text EncoderCPU",
63
- description="Encode un texte en vecteur et renvoie un fichier .pt (PyTorch tensor).",
64
  )
65
 
66
- # API name pour gradio_client : /encode_text
67
  demo.launch()
 
1
  import torch
2
  import gradio as gr
3
+ from diffusers import Flux2Pipeline
4
+ import tempfile, os
 
 
 
 
 
 
 
 
 
 
5
 
6
+ REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
7
  device = "cpu"
8
+ dtype = torch.float32 # CPU-safe
 
 
 
 
 
 
 
 
 
9
 
10
+ # On charge la pipeline mais on supprime tout sauf le text encoder
11
+ pipe = Flux2Pipeline.from_pretrained(
12
+ REPO_ID,
13
+ torch_dtype=dtype,
14
+ low_cpu_mem_usage=True,
15
+ )
16
 
17
+ # On supprime tout ce qui n'est PAS le text encoder
18
+ pipe.transformer = None
19
+ pipe.vae = None
20
+ pipe.scheduler = None
 
 
 
21
 
22
+ # On garde :
23
+ # - pipe.tokenizer
24
+ # - pipe.text_encoder
25
+ # - pipe.text_encoder_2 (si présent)
26
+ # - encode_prompt()
27
 
28
+ pipe.to(device)
 
 
29
 
30
+ @torch.no_grad()
31
+ def encode_text(prompt: str):
32
+ if not prompt.strip():
33
+ raise gr.Error("Prompt vide")
34
+
35
+ # encode_prompt = embeddings EXACTS attendus par FLUX2
36
+ prompt_embeds, _, _ = pipe.encode_prompt(
37
+ prompt=prompt,
38
+ device=device,
39
+ num_images_per_prompt=1,
40
+ do_classifier_free_guidance=False,
41
+ )
42
+
43
+ # Sauvegarde dans un fichier .pt
44
+ fd, path = tempfile.mkstemp(suffix=".pt")
45
+ os.close(fd)
46
+ torch.save(prompt_embeds.cpu(), path)
47
 
 
48
  return path
49
 
 
 
 
 
50
  demo = gr.Interface(
51
  fn=encode_text,
52
+ inputs=gr.Textbox(label="Prompt"),
53
+ outputs=gr.File(label="Embeddings FLUX2 (.pt)"),
54
+ title="FLUX.2 KleinText Encoder Officiel",
55
+ description="Renvoie les embeddings EXACTS que FLUX2 Klein attend.",
56
  )
57
 
 
58
  demo.launch()