Update app.py
Browse files
app.py
CHANGED
|
@@ -24,10 +24,8 @@ logger = logging.getLogger("mistral-text-encoding-gradio")
|
|
| 24 |
# ------------------------------------------------------
|
| 25 |
# Config
|
| 26 |
# ------------------------------------------------------
|
| 27 |
-
TEXT_ENCODER_ID
|
| 28 |
-
TOKENIZER_ID =
|
| 29 |
-
"TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
| 30 |
-
)
|
| 31 |
DTYPE = torch.bfloat16
|
| 32 |
|
| 33 |
# ------------------------------------------------------
|
|
@@ -40,11 +38,6 @@ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
|
| 40 |
TEXT_ENCODER_ID,
|
| 41 |
dtype=DTYPE,
|
| 42 |
).to("cpu")
|
| 43 |
-
logger.info(
|
| 44 |
-
"Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
|
| 45 |
-
time.time() - t0,
|
| 46 |
-
text_encoder.dtype,
|
| 47 |
-
)
|
| 48 |
|
| 49 |
t1 = time.time()
|
| 50 |
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
|
|
@@ -90,7 +83,6 @@ def encode_text(prompt: str):
|
|
| 90 |
)
|
| 91 |
|
| 92 |
duration = (time.time() - t0) * 1000.0
|
| 93 |
-
|
| 94 |
logger.info(
|
| 95 |
"Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
|
| 96 |
duration,
|
|
@@ -116,7 +108,7 @@ def encode_text(prompt: str):
|
|
| 116 |
)
|
| 117 |
|
| 118 |
return temp_file.name, status
|
| 119 |
-
|
| 120 |
|
| 121 |
# ------------------------------------------------------
|
| 122 |
# Gradio Interface
|
|
|
|
| 24 |
# ------------------------------------------------------
|
| 25 |
# Config
|
| 26 |
# ------------------------------------------------------
|
| 27 |
+
TEXT_ENCODER_ID="Qwen/Qwen2.5-7B-Instruct-1M"
|
| 28 |
+
TOKENIZER_ID = "Qwen/Qwen2.5-7B-Instruct-1M"
|
|
|
|
|
|
|
| 29 |
DTYPE = torch.bfloat16
|
| 30 |
|
| 31 |
# ------------------------------------------------------
|
|
|
|
| 38 |
TEXT_ENCODER_ID,
|
| 39 |
dtype=DTYPE,
|
| 40 |
).to("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
t1 = time.time()
|
| 43 |
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
duration = (time.time() - t0) * 1000.0
|
|
|
|
| 86 |
logger.info(
|
| 87 |
"Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
|
| 88 |
duration,
|
|
|
|
| 108 |
)
|
| 109 |
|
| 110 |
return temp_file.name, status
|
| 111 |
+
|
| 112 |
|
| 113 |
# ------------------------------------------------------
|
| 114 |
# Gradio Interface
|