shakeel143 commited on
Commit
d640e7d
·
verified ·
1 Parent(s): d3ef616

Update model_service.py

Browse files
Files changed (1) hide show
  1. model_service.py +96 -40
model_service.py CHANGED
@@ -14,85 +14,141 @@ class ModelService:
14
  def __init__(self):
15
  self.loaded_models = {}
16
  self.drive_service = DriveService()
17
-
18
  def load_model(self, model_name: str, temperature: float = 0.7):
19
  """Load a model from Google Drive."""
20
  try:
21
  logger.info(f"Loading model: {model_name} with temperature: {temperature}")
22
-
23
  # Download model files from Google Drive
24
  logger.info("Downloading model files from Google Drive...")
25
  self.drive_service.download_model_files_from_subfolder(
26
  parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
27
  subfolder_name=model_name
28
  )
29
-
30
  # Load the downloaded model
31
- model_path = os.path.join(BASE_MODEL_PATH, model_name)
32
  logger.info(f"Model path: {model_path}")
33
-
 
 
 
 
34
  # Initialize embeddings and load vector store
35
  logger.info("Initializing embeddings...")
36
  embeddings = GoogleGenerativeAIEmbeddings(
37
  model="models/embedding-001",
38
  google_api_key=os.getenv("GOOGLE_API_KEY")
39
  )
40
-
41
  logger.info("Loading FAISS vector store...")
42
  vector_store = FAISS.load_local(
43
- model_path,
44
  embeddings,
45
  allow_dangerous_deserialization=True
46
  )
47
-
48
  # Configure the QA chain
49
  logger.info("Configuring QA chain...")
50
  chain = self.configure_chain(temperature)
51
-
52
  # Store the loaded model in memory
53
  self.loaded_models[model_name] = {
54
  "vector_store": vector_store,
55
  "chain": chain
56
  }
57
-
58
  logger.info(f"Model '{model_name}' loaded successfully")
59
  return {
60
  "status": "success",
61
  "message": f"Model '{model_name}' loaded successfully"
62
  }
 
 
 
63
  except Exception as e:
64
  logger.error(f"Error loading model: {str(e)}")
65
  raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
66
-
67
- def chat_with_model(self, model_name: str, question: str):
68
- """Generate a response using the loaded model."""
69
- if model_name not in self.loaded_models:
70
- raise HTTPException(
71
- status_code=404,
72
- detail=f"Model '{model_name}' not loaded. Please load it first."
73
- )
74
-
75
- try:
76
- model_data = self.loaded_models[model_name]
77
- docs = model_data["vector_store"].similarity_search(question)
78
- response = model_data["chain"](
79
- {
80
- "input_documents": docs,
81
- "question": question
82
- },
83
- return_only_outputs=True
84
- )
85
-
86
- return {
87
- "status": "success",
88
- "response": response["output_text"]
89
- }
90
- except Exception as e:
91
- logger.error(f"Error generating response: {str(e)}")
92
- raise HTTPException(
93
- status_code=500,
94
- detail=f"Failed to generate response: {str(e)}"
95
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def configure_chain(self, temperature: float):
98
  """Configure the QA chain with the updated prompt template."""
 
14
  def __init__(self):
15
  self.loaded_models = {}
16
  self.drive_service = DriveService()
17
+
18
  def load_model(self, model_name: str, temperature: float = 0.7):
19
  """Load a model from Google Drive."""
20
  try:
21
  logger.info(f"Loading model: {model_name} with temperature: {temperature}")
22
+
23
  # Download model files from Google Drive
24
  logger.info("Downloading model files from Google Drive...")
25
  self.drive_service.download_model_files_from_subfolder(
26
  parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
27
  subfolder_name=model_name
28
  )
29
+
30
  # Load the downloaded model
31
+ model_path = os.path.join(BASE_MODEL_PATH, model_name, "faiss_index") # Add "faiss_index" to the path
32
  logger.info(f"Model path: {model_path}")
33
+
34
+ # Verify the model files exist
35
+ if not os.path.exists(os.path.join(model_path, "index.faiss")):
36
+ raise FileNotFoundError(f"FAISS index not found at {model_path}")
37
+
38
  # Initialize embeddings and load vector store
39
  logger.info("Initializing embeddings...")
40
  embeddings = GoogleGenerativeAIEmbeddings(
41
  model="models/embedding-001",
42
  google_api_key=os.getenv("GOOGLE_API_KEY")
43
  )
44
+
45
  logger.info("Loading FAISS vector store...")
46
  vector_store = FAISS.load_local(
47
+ model_path, # This path should now point to the faiss_index directory
48
  embeddings,
49
  allow_dangerous_deserialization=True
50
  )
51
+
52
  # Configure the QA chain
53
  logger.info("Configuring QA chain...")
54
  chain = self.configure_chain(temperature)
55
+
56
  # Store the loaded model in memory
57
  self.loaded_models[model_name] = {
58
  "vector_store": vector_store,
59
  "chain": chain
60
  }
61
+
62
  logger.info(f"Model '{model_name}' loaded successfully")
63
  return {
64
  "status": "success",
65
  "message": f"Model '{model_name}' loaded successfully"
66
  }
67
+ except FileNotFoundError as e:
68
+ logger.error(f"File not found error: {str(e)}")
69
+ raise HTTPException(status_code=404, detail=str(e))
70
  except Exception as e:
71
  logger.error(f"Error loading model: {str(e)}")
72
  raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
73
+
74
+ # def load_model(self, model_name: str, temperature: float = 0.7):
75
+ # """Load a model from Google Drive."""
76
+ # try:
77
+ # logger.info(f"Loading model: {model_name} with temperature: {temperature}")
78
+
79
+ # # Download model files from Google Drive
80
+ # logger.info("Downloading model files from Google Drive...")
81
+ # self.drive_service.download_model_files_from_subfolder(
82
+ # parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
83
+ # subfolder_name=model_name
84
+ # )
85
+
86
+ # # Load the downloaded model
87
+ # model_path = os.path.join(BASE_MODEL_PATH, model_name)
88
+ # logger.info(f"Model path: {model_path}")
89
+
90
+ # # Initialize embeddings and load vector store
91
+ # logger.info("Initializing embeddings...")
92
+ # embeddings = GoogleGenerativeAIEmbeddings(
93
+ # model="models/embedding-001",
94
+ # google_api_key=os.getenv("GOOGLE_API_KEY")
95
+ # )
96
+
97
+ # logger.info("Loading FAISS vector store...")
98
+ # vector_store = FAISS.load_local(
99
+ # model_path,
100
+ # embeddings,
101
+ # allow_dangerous_deserialization=True
102
+ # )
103
+
104
+ # # Configure the QA chain
105
+ # logger.info("Configuring QA chain...")
106
+ # chain = self.configure_chain(temperature)
107
+
108
+ # # Store the loaded model in memory
109
+ # self.loaded_models[model_name] = {
110
+ # "vector_store": vector_store,
111
+ # "chain": chain
112
+ # }
113
+
114
+ # logger.info(f"Model '{model_name}' loaded successfully")
115
+ # return {
116
+ # "status": "success",
117
+ # "message": f"Model '{model_name}' loaded successfully"
118
+ # }
119
+ # except Exception as e:
120
+ # logger.error(f"Error loading model: {str(e)}")
121
+ # raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
122
+
123
+ # def chat_with_model(self, model_name: str, question: str):
124
+ # """Generate a response using the loaded model."""
125
+ # if model_name not in self.loaded_models:
126
+ # raise HTTPException(
127
+ # status_code=404,
128
+ # detail=f"Model '{model_name}' not loaded. Please load it first."
129
+ # )
130
+
131
+ # try:
132
+ # model_data = self.loaded_models[model_name]
133
+ # docs = model_data["vector_store"].similarity_search(question)
134
+ # response = model_data["chain"](
135
+ # {
136
+ # "input_documents": docs,
137
+ # "question": question
138
+ # },
139
+ # return_only_outputs=True
140
+ # )
141
+
142
+ # return {
143
+ # "status": "success",
144
+ # "response": response["output_text"]
145
+ # }
146
+ # except Exception as e:
147
+ # logger.error(f"Error generating response: {str(e)}")
148
+ # raise HTTPException(
149
+ # status_code=500,
150
+ # detail=f"Failed to generate response: {str(e)}"
151
+ # )
152
 
153
  def configure_chain(self, temperature: float):
154
  """Configure the QA chain with the updated prompt template."""