VcRlAgent committed on
Commit
6bace28
·
1 Parent(s): 6efd79e
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -52,27 +52,41 @@ def load_llamaindex_stack(model_id: str, max_new_tokens: int, temperature: float
52
  tok = AutoTokenizer.from_pretrained(model_id)
53
  mdl = AutoModelForSeq2SeqLM.from_pretrained(model_id)
54
  text2text = pipeline(
55
- task="text2text-generation",
56
  model=mdl,
57
  tokenizer=tok,
58
  max_new_tokens=max_new_tokens,
59
  temperature=float(temperature)
60
  )
61
  """
 
62
 
63
  # Wrap the same tiny HF model for LlamaIndex
64
 
65
- llm = HuggingFaceLLM(
 
 
 
 
 
 
 
66
  model_name=model_id,
67
- tokenizer_name=model_id,
68
- model_cls=AutoModelForSeq2SeqLM,
69
  context_window=2048,
70
  generate_kwargs={"max_new_tokens": max_new_tokens, "temperature": temperature},
71
  device_map="cpu",
72
- )
73
-
74
- #llm = HuggingFaceLLM(pipeline=text2text)
75
-
 
 
 
 
 
 
76
  Settings.embed_model = embed
77
  Settings.llm = llm
78
 
 
52
  tok = AutoTokenizer.from_pretrained(model_id)
53
  mdl = AutoModelForSeq2SeqLM.from_pretrained(model_id)
54
  text2text = pipeline(
55
+ "text2text-generation",
56
  model=mdl,
57
  tokenizer=tok,
58
  max_new_tokens=max_new_tokens,
59
  temperature=float(temperature)
60
  )
61
  """
62
+ #llm = HuggingFaceLLM(pipeline=text2text)
63
 
64
  # Wrap the same tiny HF model for LlamaIndex
65
 
66
+ config = AutoConfig.from_pretrained(model_id)
67
+ if config.model_type in ["t5", "mt5", "bart", "mbart", "pegasus", "marian", "prophetnet"]:
68
+ task = "text2text-generation" # encoder-decoder / seq2seq
69
+ else:
70
+ task = "text-generation"
71
+
72
+ try:
73
+ llm = HuggingFaceLLM(
74
  model_name=model_id,
75
+ tokenizer_name=model_id,
76
+ task=task,
77
  context_window=2048,
78
  generate_kwargs={"max_new_tokens": max_new_tokens, "temperature": temperature},
79
  device_map="cpu",
80
+ )
81
+ except TypeError:
82
+ llm = HuggingFaceLLM(
83
+ model_name=model_id,
84
+ tokenizer_name=model_id,
85
+ context_window=2048,
86
+ generate_kwargs={"max_new_tokens": max_new_tokens, "temperature": float(temperature)},
87
+ device_map="cpu",
88
+ )
89
+
90
  Settings.embed_model = embed
91
  Settings.llm = llm
92