istiak101 commited on
Commit
69f1dd8
·
verified ·
1 Parent(s): d3923ff

Update src/prediction.py

Browse files
Files changed (1) hide show
  1. src/prediction.py +4 -2
src/prediction.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
  from textwrap import dedent
3
  from build_db_index import query_vectorstore
4
  from langdetect import detect
@@ -88,7 +88,9 @@ class Prediction:
88
  return prediction
89
 
90
  def configure_model(self, model_name):
91
- self.model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
92
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
93
  # self.tokenizer.chat_template = """{% for message in messages %}
94
  # {% if message['role'] == 'system' -%}
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoConfig
2
  from textwrap import dedent
3
  from build_db_index import query_vectorstore
4
  from langdetect import detect
 
88
  return prediction
89
 
90
  def configure_model(self, model_name):
91
+ config = AutoConfig.from_pretrained(model_name)
92
+ config.model_type = "qwen2"
93
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
94
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
95
  # self.tokenizer.chat_template = """{% for message in messages %}
96
  # {% if message['role'] == 'system' -%}