shakeel143 commited on
Commit
452f242
·
verified ·
1 Parent(s): acf6711

Update model_service.py

Browse files
Files changed (1) hide show
  1. model_service.py +14 -34
model_service.py CHANGED
@@ -1,6 +1,8 @@
1
  # model_service.py
2
 
3
  import os
 
 
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
6
  from langchain.chains.question_answering import load_qa_chain
@@ -8,32 +10,37 @@ from langchain.prompts import PromptTemplate
8
  from config import BASE_MODEL_PATH, GOOGLE_DRIVE_FOLDER_ID
9
  from drive_service import DriveService
10
 
 
 
11
  class ModelService:
12
  def __init__(self):
13
  self.loaded_models = {}
 
14
 
15
  def load_model(self, model_name: str, temperature: float = 0.7):
16
  """Load a model from Google Drive."""
17
  try:
18
- # Initialize Google Drive API
19
- drive_service = DriveService()
20
 
21
  # Download model files from Google Drive
22
- drive_service.download_model_files_from_subfolder(
 
23
  parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
24
  subfolder_name=model_name
25
  )
26
 
27
  # Load the downloaded model
28
  model_path = os.path.join(BASE_MODEL_PATH, model_name)
 
29
 
30
  # Initialize embeddings and load vector store
 
31
  embeddings = GoogleGenerativeAIEmbeddings(
32
  model="models/embedding-001",
33
  google_api_key=os.getenv("GOOGLE_API_KEY")
34
  )
35
 
36
- # Load the local FAISS index and vector store
37
  vector_store = FAISS.load_local(
38
  model_path,
39
  embeddings,
@@ -41,6 +48,7 @@ class ModelService:
41
  )
42
 
43
  # Configure the QA chain
 
44
  chain = self.configure_chain(temperature)
45
 
46
  # Store the loaded model in memory
@@ -49,6 +57,7 @@ class ModelService:
49
  "chain": chain
50
  }
51
 
 
52
  return {
53
  "status": "success",
54
  "message": f"Model '{model_name}' loaded successfully"
@@ -113,33 +122,4 @@ class ModelService:
113
  return load_qa_chain(model, chain_type="stuff", prompt=prompt)
114
  except Exception as e:
115
  logger.error(f"Error configuring chain: {str(e)}")
116
- raise HTTPException(status_code=500, detail="Failed to configure model chain")
117
-
118
- def chat_with_model(self, model_name: str, question: str):
119
- """Generate a response using the loaded model."""
120
- if model_name not in self.loaded_models:
121
- raise HTTPException(
122
- status_code=404,
123
- detail=f"Model '{model_name}' not loaded. Please load it first."
124
- )
125
-
126
- try:
127
- model_data = self.loaded_models[model_name]
128
- docs = model_data["vector_store"].similarity_search(question)
129
- response = model_data["chain"](
130
- {
131
- "input_documents": docs,
132
- "question": question
133
- },
134
- return_only_outputs=True
135
- )
136
-
137
- return {
138
- "status": "success",
139
- "response": response["output_text"]
140
- }
141
- except Exception as e:
142
- logger.error(f"Error generating response: {str(e)}")
143
- raise HTTPException(
144
- status_code=500,
145
- detail=f"Failed to generate response: {str(e)}")
 
1
  # model_service.py
2
 
3
  import os
4
+ import logging
5
+ from fastapi import HTTPException
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
8
  from langchain.chains.question_answering import load_qa_chain
 
10
  from config import BASE_MODEL_PATH, GOOGLE_DRIVE_FOLDER_ID
11
  from drive_service import DriveService
12
 
13
+ logger = logging.getLogger(__name__)
14
+
15
  class ModelService:
16
  def __init__(self):
17
  self.loaded_models = {}
18
+ self.drive_service = DriveService()
19
 
20
  def load_model(self, model_name: str, temperature: float = 0.7):
21
  """Load a model from Google Drive."""
22
  try:
23
+ logger.info(f"Loading model: {model_name} with temperature: {temperature}")
 
24
 
25
  # Download model files from Google Drive
26
+ logger.info("Downloading model files from Google Drive...")
27
+ self.drive_service.download_model_files_from_subfolder(
28
  parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
29
  subfolder_name=model_name
30
  )
31
 
32
  # Load the downloaded model
33
  model_path = os.path.join(BASE_MODEL_PATH, model_name)
34
+ logger.info(f"Model path: {model_path}")
35
 
36
  # Initialize embeddings and load vector store
37
+ logger.info("Initializing embeddings...")
38
  embeddings = GoogleGenerativeAIEmbeddings(
39
  model="models/embedding-001",
40
  google_api_key=os.getenv("GOOGLE_API_KEY")
41
  )
42
 
43
+ logger.info("Loading FAISS vector store...")
44
  vector_store = FAISS.load_local(
45
  model_path,
46
  embeddings,
 
48
  )
49
 
50
  # Configure the QA chain
51
+ logger.info("Configuring QA chain...")
52
  chain = self.configure_chain(temperature)
53
 
54
  # Store the loaded model in memory
 
57
  "chain": chain
58
  }
59
 
60
+ logger.info(f"Model '{model_name}' loaded successfully")
61
  return {
62
  "status": "success",
63
  "message": f"Model '{model_name}' loaded successfully"
 
122
  return load_qa_chain(model, chain_type="stuff", prompt=prompt)
123
  except Exception as e:
124
  logger.error(f"Error configuring chain: {str(e)}")
125
+ raise HTTPException(status_code=500, detail="Failed to configure model chain")