Gowtham122 commited on
Commit
ed18722
·
verified ·
1 Parent(s): b242263

Update app/models.py

Browse files
Files changed (1) hide show
  1. app/models.py +3 -3
app/models.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
  from pathlib import Path
4
  from typing import Optional
5
  from pydantic import BaseModel
6
- from transformers import AutoTokenizer, AutoConfig, AutoModel
7
 
8
  import torch
9
 
@@ -38,7 +38,7 @@ class DataLocation(BaseModel):
38
  self.cloud_uri, token=AUTH_TOKEN
39
  )
40
 
41
- model = AutoModel.from_pretrained(
42
  self.cloud_uri, token=AUTH_TOKEN
43
  )
44
  # Save the model and tokenizer locally
@@ -73,7 +73,7 @@ class QAModel:
73
 
74
  # Load the tokenizer and model from the local path
75
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
76
- self.model = AutoModel.from_pretrained(model_path,return_dict=True).to(self.device)
77
  logger.info(f"Loaded QA model: {self.model_name}")
78
 
79
  def inference_qa(self, context: str, question: str):
 
3
  from pathlib import Path
4
  from typing import Optional
5
  from pydantic import BaseModel
6
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
7
 
8
  import torch
9
 
 
38
  self.cloud_uri, token=AUTH_TOKEN
39
  )
40
 
41
+ model = AutoModelForQuestionAnswering.from_pretrained(
42
  self.cloud_uri, token=AUTH_TOKEN
43
  )
44
  # Save the model and tokenizer locally
 
73
 
74
  # Load the tokenizer and model from the local path
75
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
76
+ self.model = AutoModelForQuestionAnswering.from_pretrained(model_path,return_dict=True).to(self.device)
77
  logger.info(f"Loaded QA model: {self.model_name}")
78
 
79
  def inference_qa(self, context: str, question: str):