|
|
""" |
|
|
Hugging Face Inference API endpoint for mbaza-model |
|
|
This file exposes the model for API calls via HF Inference API. |
|
|
""" |
|
|
from typing import Dict, List, Any |
|
|
from assistant import get_assistant |
|
|
|
|
|
|
|
|
# Build the assistant once, when this module is imported, so that model() and
# predict() below share the same instance instead of creating one per call.
assistant = get_assistant()
|
|
|
|
|
|
|
|
def model(inputs: List[str]) -> Dict[str, Any]:
    """
    Main inference function called by Hugging Face Inference API.

    Args:
        inputs: Non-empty list of strings. ``inputs[0]`` is the prompt/query.
            An optional ``inputs[1]`` is used as the user identifier; when it
            is absent the identifier defaults to ``"api_user"``.

    Returns:
        On success, whatever ``assistant.handle_query`` returns.
        On malformed input (not a list, empty, or a non-string prompt), a
        dict with a single ``"error"`` key.
        If handling the query raises, a dict with ``"error"`` (containing the
        exception message) and a generic user-facing ``"text"`` message.
    """
    invalid_input = {"error": "Invalid input format. Expected list of strings."}

    # Reject None, empty containers, and anything that is not a list.
    if not inputs or not isinstance(inputs, list):
        return invalid_input

    prompt = inputs[0]
    # The validation message promises a list of strings, so enforce it for the
    # prompt here rather than handing an arbitrary object to the assistant.
    if not isinstance(prompt, str):
        return invalid_input

    user_id = inputs[1] if len(inputs) > 1 else "api_user"

    try:
        return assistant.handle_query(user_id, prompt)
    except Exception as e:
        # API boundary: convert any failure into an error payload so the
        # caller always receives a dict instead of an unhandled exception.
        return {
            "error": f"Processing failed: {str(e)}",
            "text": "An error occurred while processing your request.",
        }
|
|
|
|
|
|
|
|
def predict(prompt: str, user_id: str = "api_user") -> Dict[str, Any]:
    """
    Run a single query through the assistant using explicit arguments.

    Args:
        prompt: User query or greeting.
        user_id: Optional user identifier forwarded to the assistant;
            defaults to ``"api_user"``.

    Returns:
        The value produced by ``assistant.handle_query``. If that call
        raises, a dict with an ``"error"`` entry (including the exception
        message) and a generic ``"text"`` entry is returned instead.
    """
    try:
        return assistant.handle_query(user_id, prompt)
    except Exception as exc:
        # Any failure is reported back to the caller as a payload, not raised.
        failure = {
            "error": f"Processing failed: {exc!s}",
            "text": "An error occurred while processing your request.",
        }
        return failure
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: send a few sample prompts through model() and print
    # the "text" field of each response (or the whole response if absent).
    test_queries = (
        "Mwaramutse neza",
        "Ibihano by'ubujura ni ibihe?",
        "Kwinjira aho umuntu atuye bitemewe namategeko",
        "What are the laws about corruption?",
    )
    divider = "-" * 80

    print("Testing inference.py locally...\n")
    for sample in test_queries:
        print(f"Query: {sample}")
        reply = model([sample])
        shown = reply.get('text', reply)
        print(f"Response: {shown}")
        print(divider)
|
|
|