TymaaHammouda committed on
Commit
c3dacfc
·
verified ·
1 Parent(s): 0948bff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -45,8 +45,8 @@ def build_prompt(req1, req2, prompt_type="zero-shot"):
45
  # -----------------------------
46
  @app.on_event("startup")
47
  def load_models():
48
- print("Loading DeepSeek model into memory...")
49
- deepseek_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
50
  app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
51
  app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
52
  app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
@@ -68,16 +68,12 @@ def run_gpt4(req1, req2, prompt_type, api_key):
68
  )
69
  return completion.choices[0].message.content.strip()
70
 
 
71
  def run_deepseek(req1, req2, prompt_type):
72
  tokenizer = app.state.deepseek_tokenizer
73
  model = app.state.deepseek_model
74
  prompt = build_prompt(req1, req2, prompt_type)
75
- inputs = tokenizer(
76
- [prompt],
77
- return_tensors="pt",
78
- padding=True,
79
- truncation=True
80
- )
81
  outputs = model.generate(
82
  input_ids=inputs.input_ids,
83
  attention_mask=inputs.attention_mask,
 
45
  # -----------------------------
46
  @app.on_event("startup")
47
  def load_models():
48
+ print("Loading smaller DeepSeek model into memory...")
49
+ deepseek_name = "deepseek-ai/deepseek-vl2-small" # smaller model
50
  app.state.deepseek_tokenizer = AutoTokenizer.from_pretrained(deepseek_name)
51
  app.state.deepseek_tokenizer.pad_token = app.state.deepseek_tokenizer.eos_token
52
  app.state.deepseek_model = AutoModelForCausalLM.from_pretrained(
 
68
  )
69
  return completion.choices[0].message.content.strip()
70
 
71
+
72
  def run_deepseek(req1, req2, prompt_type):
73
  tokenizer = app.state.deepseek_tokenizer
74
  model = app.state.deepseek_model
75
  prompt = build_prompt(req1, req2, prompt_type)
76
+ inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
 
 
 
 
 
77
  outputs = model.generate(
78
  input_ids=inputs.input_ids,
79
  attention_mask=inputs.attention_mask,