lea97338 commited on
Commit
dd4f494
·
verified ·
1 Parent(s): 9fa52ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -8,7 +8,7 @@ import spaces
8
 
9
  import torch
10
  import gradio as gr
11
- from transformers import Mistral3ForConditionalGeneration, AutoProcessor
12
 
13
  from dd import encode_prompt
14
 
@@ -34,13 +34,13 @@ DTYPE = torch.bfloat16
34
  logger.info("Loading models...")
35
 
36
  t0 = time.time()
37
- text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
38
  TEXT_ENCODER_ID,
39
  dtype=DTYPE,
40
  ).to("cpu")
41
 
42
  t1 = time.time()
43
- tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
44
  logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
45
 
46
  torch.set_grad_enabled(False)
@@ -132,7 +132,7 @@ with gr.Blocks(title="Mistral Text Encoder") as demo:
132
 
133
  encode_btn.click(
134
  fn=encode_text,
135
- inputs=[prompt_input],
136
  outputs=[output_file, status_output],
137
  )
138
 
 
8
 
9
  import torch
10
  import gradio as gr
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
  from dd import encode_prompt
14
 
 
34
  logger.info("Loading models...")
35
 
36
  t0 = time.time()
37
+ text_encoder = AutoModelForCausalLM.from_pretrained(
38
  TEXT_ENCODER_ID,
39
  dtype=DTYPE,
40
  ).to("cpu")
41
 
42
  t1 = time.time()
43
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
44
  logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
45
 
46
  torch.set_grad_enabled(False)
 
132
 
133
  encode_btn.click(
134
  fn=encode_text,
135
+ inputs=[{"role": "user",prompt_input}],
136
  outputs=[output_file, status_output],
137
  )
138