Spaces:
Sleeping
Sleeping
Update model_service.py
Browse files- model_service.py +106 -70
model_service.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
-
# model_service.py
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
import logging
|
| 5 |
from fastapi import HTTPException
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from drive_service import DriveService # Import DriveService
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
|
@@ -12,106 +15,139 @@ class ModelService:
|
|
| 12 |
self.loaded_models = {}
|
| 13 |
self.drive_service = DriveService()
|
| 14 |
|
| 15 |
-
|
| 16 |
-
@app.post("/load-model/")
|
| 17 |
-
async def load_model(request: LoadModelRequest):
|
| 18 |
"""Load a model from Google Drive."""
|
| 19 |
try:
|
| 20 |
-
logger.info(f"Loading model: {
|
| 21 |
-
|
| 22 |
# Download model files from Google Drive
|
| 23 |
logger.info("Downloading model files from Google Drive...")
|
| 24 |
-
drive_service.download_model_files_from_subfolder(
|
| 25 |
parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
|
| 26 |
-
subfolder_name=
|
| 27 |
)
|
| 28 |
-
|
| 29 |
# Load the downloaded model
|
| 30 |
-
model_path = os.path.join(BASE_MODEL_PATH,
|
| 31 |
logger.info(f"Model path: {model_path}")
|
| 32 |
-
|
| 33 |
# Initialize embeddings and load vector store
|
| 34 |
logger.info("Initializing embeddings...")
|
| 35 |
embeddings = GoogleGenerativeAIEmbeddings(
|
| 36 |
model="models/embedding-001",
|
| 37 |
google_api_key=os.getenv("GOOGLE_API_KEY")
|
| 38 |
)
|
| 39 |
-
|
| 40 |
logger.info("Loading FAISS vector store...")
|
| 41 |
vector_store = FAISS.load_local(
|
| 42 |
model_path,
|
| 43 |
embeddings,
|
| 44 |
allow_dangerous_deserialization=True
|
| 45 |
)
|
| 46 |
-
|
| 47 |
# Configure the QA chain
|
| 48 |
logger.info("Configuring QA chain...")
|
| 49 |
-
chain =
|
| 50 |
-
|
| 51 |
# Store the loaded model in memory
|
| 52 |
-
|
| 53 |
"vector_store": vector_store,
|
| 54 |
"chain": chain
|
| 55 |
}
|
| 56 |
-
|
| 57 |
-
logger.info(f"Model '{
|
| 58 |
return {
|
| 59 |
"status": "success",
|
| 60 |
-
"message": f"Model '{
|
| 61 |
}
|
| 62 |
except Exception as e:
|
| 63 |
logger.error(f"Error loading model: {str(e)}")
|
| 64 |
raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
|
| 65 |
|
| 66 |
-
# def load_model(self, model_name: str, temperature: float = 0.7):
|
| 67 |
-
# """Load a model from Google Drive."""
|
| 68 |
-
# try:
|
| 69 |
-
# logger.info(f"Loading model: {model_name} with temperature: {temperature}")
|
| 70 |
-
|
| 71 |
-
# # Download model files from Google Drive
|
| 72 |
-
# logger.info("Downloading model files from Google Drive...")
|
| 73 |
-
# self.drive_service.download_model_files_from_subfolder(
|
| 74 |
-
# parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
|
| 75 |
-
# subfolder_name=model_name
|
| 76 |
-
# )
|
| 77 |
-
|
| 78 |
-
# # Load the downloaded model
|
| 79 |
-
# model_path = os.path.join(BASE_MODEL_PATH, model_name)
|
| 80 |
-
# logger.info(f"Model path: {model_path}")
|
| 81 |
-
|
| 82 |
-
# # Initialize embeddings and load vector store
|
| 83 |
-
# logger.info("Initializing embeddings...")
|
| 84 |
-
# embeddings = GoogleGenerativeAIEmbeddings(
|
| 85 |
-
# model="models/embedding-001",
|
| 86 |
-
# google_api_key=os.getenv("GOOGLE_API_KEY")
|
| 87 |
-
# )
|
| 88 |
-
|
| 89 |
-
# logger.info("Loading FAISS vector store...")
|
| 90 |
-
# vector_store = FAISS.load_local(
|
| 91 |
-
# model_path,
|
| 92 |
-
# embeddings,
|
| 93 |
-
# allow_dangerous_deserialization=True
|
| 94 |
-
# )
|
| 95 |
-
|
| 96 |
-
# # Configure the QA chain
|
| 97 |
-
# logger.info("Configuring QA chain...")
|
| 98 |
-
# chain = self.configure_chain(temperature)
|
| 99 |
-
|
| 100 |
-
# # Store the loaded model in memory
|
| 101 |
-
# self.loaded_models[model_name] = {
|
| 102 |
-
# "vector_store": vector_store,
|
| 103 |
-
# "chain": chain
|
| 104 |
-
# }
|
| 105 |
-
|
| 106 |
-
# logger.info(f"Model '{model_name}' loaded successfully")
|
| 107 |
-
# return {
|
| 108 |
-
# "status": "success",
|
| 109 |
-
# "message": f"Model '{model_name}' loaded successfully"
|
| 110 |
-
# }
|
| 111 |
-
# except Exception as e:
|
| 112 |
-
# logger.error(f"Error loading model: {str(e)}")
|
| 113 |
-
# raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
|
| 114 |
-
|
| 115 |
def chat_with_model(self, model_name: str, question: str):
|
| 116 |
"""Generate a response using the loaded model."""
|
| 117 |
-
if model_name not in self.loaded_models:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import logging
|
| 3 |
from fastapi import HTTPException
|
| 4 |
+
from langchain_community.vectorstores import FAISS
|
| 5 |
+
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
|
| 6 |
+
from langchain.chains.question_answering import load_qa_chain
|
| 7 |
+
from langchain.prompts import PromptTemplate
|
| 8 |
+
from config import BASE_MODEL_PATH, GOOGLE_DRIVE_FOLDER_ID
|
| 9 |
from drive_service import DriveService # Import DriveService
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
|
|
|
| 15 |
self.loaded_models = {}
|
| 16 |
self.drive_service = DriveService()
|
| 17 |
|
| 18 |
+
def load_model(self, model_name: str, temperature: float = 0.7):
    """Download a model's files from Google Drive and load it into memory.

    Downloads the files for ``model_name`` from the configured Drive
    folder, builds the Google embedding client, loads the FAISS vector
    store from disk, configures a QA chain, and caches both under
    ``self.loaded_models[model_name]``.

    Args:
        model_name: Name of the Drive subfolder holding the model files;
            also used as the cache key and the on-disk directory name.
        temperature: Sampling temperature forwarded to the QA chain.

    Returns:
        dict: ``{"status": "success", "message": ...}`` on success.

    Raises:
        HTTPException: 500 when any download/load/configure step fails.
    """
    try:
        # Lazy %-style args: the logging module formats only if the record is emitted.
        logger.info("Loading model: %s with temperature: %s", model_name, temperature)

        # Download model files from Google Drive
        logger.info("Downloading model files from Google Drive...")
        self.drive_service.download_model_files_from_subfolder(
            parent_folder_id=GOOGLE_DRIVE_FOLDER_ID,
            subfolder_name=model_name
        )

        # Load the downloaded model
        model_path = os.path.join(BASE_MODEL_PATH, model_name)
        logger.info("Model path: %s", model_path)

        # Initialize embeddings and load vector store
        logger.info("Initializing embeddings...")
        embeddings = GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )

        logger.info("Loading FAISS vector store...")
        # allow_dangerous_deserialization: the index was produced by this
        # service, so the pickle payload is trusted — do not load indexes
        # from untrusted sources through this path.
        vector_store = FAISS.load_local(
            model_path,
            embeddings,
            allow_dangerous_deserialization=True
        )

        # Configure the QA chain
        logger.info("Configuring QA chain...")
        chain = self.configure_chain(temperature)

        # Store the loaded model in memory
        self.loaded_models[model_name] = {
            "vector_store": vector_store,
            "chain": chain
        }

        logger.info("Model '%s' loaded successfully", model_name)
        return {
            "status": "success",
            "message": f"Model '{model_name}' loaded successfully"
        }
    except Exception as e:
        # logger.exception keeps the traceback; "from e" preserves the
        # causal chain on the HTTP error for upstream debugging.
        logger.exception("Error loading model: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}") from e
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def chat_with_model(self, model_name: str, question: str):
    """Generate a response to ``question`` using a previously loaded model.

    Runs a similarity search over the model's vector store, feeds the
    retrieved documents and the question into the cached QA chain, and
    returns the chain's text output.

    Args:
        model_name: Key of a model previously cached by ``load_model``.
        question: User question to answer against the vector store.

    Returns:
        dict: ``{"status": "success", "response": <answer text>}``.

    Raises:
        HTTPException: 404 if the model is not loaded, 500 if retrieval
            or generation fails.
    """
    # Guard clause outside the try block so the 404 is not re-wrapped as a 500.
    if model_name not in self.loaded_models:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not loaded. Please load it first."
        )

    try:
        model_data = self.loaded_models[model_name]
        docs = model_data["vector_store"].similarity_search(question)
        response = model_data["chain"](
            {
                "input_documents": docs,
                "question": question
            },
            return_only_outputs=True
        )

        return {
            "status": "success",
            "response": response["output_text"]
        }
    except Exception as e:
        # logger.exception keeps the traceback; "from e" preserves the
        # causal chain on the HTTP error for upstream debugging.
        logger.exception("Error generating response: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to generate response: {str(e)}"
        ) from e
|
| 96 |
+
|
| 97 |
+
def configure_chain(self, temperature: float):
    """Build the QA chain (Gemini model + prompt template) for this service.

    Args:
        temperature: Sampling temperature for the ``gemini-pro`` chat model.

    Returns:
        A LangChain "stuff" QA chain wired with the university assistant
        prompt; it expects ``input_documents`` and ``question`` inputs.

    Raises:
        HTTPException: 500 if the model or chain cannot be constructed.
    """
    # NOTE: this is a runtime prompt string — its wording is part of the
    # service's behavior and is kept verbatim.
    prompt_template = """
    You are an AI assistant for SBBU SBA university. Your task is to provide clear, accurate, and helpful responses based on the context provided, as well as to respond to basic greetings and conversational queries. However, if the user makes inappropriate or offensive remarks, you should respond politely and professionally, redirecting the conversation back to helpful topics.

    Instructions:
    1. **Greeting Responses**: If the user greets you (e.g., "Hello," "Hi," "Hey," "Salam," etc.), respond warmly and politely. Example responses could be:
       - "Hello! How can I assist you today?"
       - "Hi there! How can I help you?"
       - "Salam! What can I do for you today?"

    2. **Casual and Playful Inquiries**: If the user says something playful or informal like "I kiss you" or similar, acknowledge it politely but redirect the conversation back to the main topic. Example:
       - "Thank you for the kind words! How can I assist you further?"
       - "I appreciate your enthusiasm! How can I help you today?"

    3. **Inappropriate or Offensive Remarks**: If the user makes inappropriate, disrespectful, or offensive comments, such as offensive language or sexually explicit remarks, respond politely but firmly, maintaining professionalism:
       - "I strive to maintain a respectful conversation. How can I assist you with your queries?"
       - "Let's keep the conversation respectful. How can I help you today?"
       - "I apologize, but I cannot engage in that kind of discussion. Please ask a relevant question related to the university."

    4. **Contextual Responses**:
       - If the context contains relevant information to the question, provide a clear and direct answer.
       - If the context only provides partial information, provide a helpful response based on available data and related details.
       - If the context has no relevant information, respond with: "I apologize, but I don't have specific information about that. Could you please ask something else about the university?"

    5. **Accuracy and Clarity**: Ensure your responses are clear, concise, and accurate. Avoid unnecessary details or over-explanation.

    6. **Clarification**: If the user's question is unclear or lacks sufficient context, ask for clarification. For example:
       - "Could you please clarify your question?"
       - "I'm not sure I understand. Can you rephrase your question?"

    Context Information:
    ---------------------
    {context}

    Question:
    {question}

    Response:
    Provide a friendly, clear, and direct response based on the context. Always aim to be helpful, especially for greetings or casual inquiries, and suggest follow-up questions or clarifications if needed.
    no preamble
    """

    try:
        model = ChatGoogleGenerativeAI(
            model="gemini-pro",
            temperature=temperature,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )
        # "stuff" chain: all retrieved documents are concatenated into the
        # {context} slot of the prompt in a single LLM call.
        return load_qa_chain(model, chain_type="stuff", prompt=prompt)
    except Exception as e:
        # logger.exception keeps the traceback; "from e" preserves the
        # causal chain on the HTTP error for upstream debugging.
        logger.exception("Error configuring chain: %s", e)
        raise HTTPException(status_code=500, detail="Failed to configure model chain") from e
|