"""
Hugging Face Inference API endpoint for mbaza-model
This file exposes the model for API calls via HF Inference API.
"""
from typing import Dict, List, Any
from assistant import get_assistant
# Initialize the assistant once
assistant = get_assistant()


def model(inputs: List[str]) -> Dict[str, Any]:
    """
    Main inference function called by the Hugging Face Inference API.

    Args:
        inputs: List of input strings (prompts/queries). An optional second
            element is treated as a user identifier.

    Returns:
        Dict with the model response.
    """
    if not inputs or not isinstance(inputs, list):
        return {"error": "Invalid input format. Expected list of strings."}

    prompt = inputs[0]
    user_id = inputs[1] if len(inputs) > 1 else "api_user"

    try:
        result = assistant.handle_query(user_id, prompt)
        return result
    except Exception as e:
        return {
            "error": f"Processing failed: {str(e)}",
            "text": "An error occurred while processing your request.",
        }
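
# Example call (assumed convention, per the body above: the optional second
# list element is the user id; "user123" is an illustrative value):
#     model(["Ibihano by'ubujura ni ibihe?", "user123"])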


def predict(prompt: str, user_id: str = "api_user") -> Dict[str, Any]:
    """
    Alternative prediction function with explicit parameters.

    Args:
        prompt: User query or greeting.
        user_id: Optional user identifier for context tracking.

    Returns:
        Dict with the response, intent, and any matched data.
    """
    try:
        result = assistant.handle_query(user_id, prompt)
        return result
    except Exception as e:
        return {
            "error": f"Processing failed: {str(e)}",
            "text": "An error occurred while processing your request.",
        }
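

# A minimal client-side sketch for calling the deployed endpoint over HTTP.
# Everything below is an assumption, not part of the deployed handler: the
# api_url (replace <namespace> with the actual repo owner), the HF_TOKEN
# environment variable, and the payload shape, which mirrors model()'s
# expected [prompt, user_id] list.
def call_remote_endpoint(prompt: str, user_id: str = "api_user") -> Dict[str, Any]:
    import os

    import requests

    api_url = "https://api-inference.huggingface.co/models/<namespace>/mbaza-model"  # assumed URL
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}  # assumed token variable
    payload = {"inputs": [prompt, user_id]}
    resp = requests.post(api_url, headers=headers, json=payload, timeout=30)
    resp.raise_for_status()
    return resp.json()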


# For testing locally
if __name__ == "__main__":
    test_queries = [
        "Mwaramutse neza",                               # "Good morning"
        "Ibihano by'ubujura ni ibihe?",                  # "What are the penalties for theft?"
        "Kwinjira aho umuntu atuye bitemewe namategeko", # "Entering someone's home unlawfully"
        "What are the laws about corruption?",
    ]

    print("Testing inference.py locally...\n")
    for query in test_queries:
        print(f"Query: {query}")
        response = model([query])
        print(f"Response: {response.get('text', response)}")
        print("-" * 80)