# groot-inosh-v3 / huggingface_handler.py
# Uploaded by vasu24 with huggingface_hub (commit 59125f8, verified)
"""
iNosh AI - Hugging Face Inference Handler
This file defines how to load and run the model on HF Inference Endpoints
"""
from typing import Dict, List, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for iNosh AI.

    Loads the Llama 3.2 1B Instruct base model, applies the fine-tuned
    LoRA adapter uploaded alongside this file, merges the adapter for
    fast inference, and serves chat-style requests.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, base model, and LoRA adapter.

        Args:
            path: Directory containing the LoRA adapter files
                  (supplied by the HF Endpoints runtime).
        """
        base_model_name = "unsloth/Llama-3.2-1B-Instruct"

        print(f"Loading tokenizer from {base_model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base model from {base_model_name}...")
        # FIX: the original passed load_in_4bit=True here, but
        # merge_and_unload() below cannot merge LoRA deltas into a
        # bitsandbytes 4-bit-quantized base model (the quantized weights
        # are not directly updatable). Load in fp16 so the merge is valid.
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Load the LoRA adapter from the uploaded files.
        print(f"Loading LoRA adapter from {path}...")
        self.model = PeftModel.from_pretrained(self.model, path)

        # Fold the adapter weights into the base model so inference pays
        # no per-layer LoRA overhead.
        print("Merging adapter with base model...")
        self.model = self.model.merge_and_unload()
        self.model.eval()
        print("iNosh AI loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle an inference request.

        Args:
            data: {"inputs": "User message here"} or
                  {"inputs": {"message": "User message here"}}

        Returns:
            [{"generated_text": "Response here"}]
        """
        # Extract the user message; accept either a plain string or a
        # dict with a "message" key (empty string if neither is present).
        inputs = data.pop("inputs", data)
        user_message = inputs if isinstance(inputs, str) else inputs.get("message", "")

        # System prompt (matches training).
        system_prompt = """You are iNosh AI, a smart kitchen assistant that helps with pantry management and meal planning.
Target Market: Australia & New Zealand (Multicultural)
Your capabilities:
- Manage pantry items (expiry tracking, low stock alerts, barcode scanning)
- Create shopping lists (store-specific: Woolworths, Coles, Countdown, etc.)
- Suggest recipes (15 cuisines, nutrition-focused, dietary restrictions)
- Plan meals (weekly, budget-aware, nutrition-optimized)
- Track fitness (Apple Health, Google Fit integration)
- Log restaurant meals (AU/NZ chains)
- Scan barcodes (instant nutrition lookup)
- Plan kids meals (school lunch requirements)
CRITICAL RESPONSE RULES:
1. For action requests, respond with valid JSON
2. For general conversation, respond naturally without JSON
3. Always respect dietary restrictions (no pork in halal, no meat in vegan, etc.)
4. Use metric units (g, kg, ml, L) - AU/NZ standard
5. Price estimates in AUD/NZD
6. Include nutrition data when relevant (calories, protein, carbs, fat)
7. Suggest recipes from available pantry items when possible
JSON Action Formats:
- Pantry: {"action": "add_pantry", "item": {...}}
- Shopping: {"action": "create_list", "list": {...}}
- Recipes: {"action": "suggest_recipes", "recipes": [...]}
- Meal Plan: {"action": "create_meal_plan", "plan": {...}}
- Fitness: {"action": "log_workout", "workout": {...}}
- Restaurant: {"action": "log_restaurant", "meal": {...}}
- Barcode: {"action": "lookup_barcode", "product": {...}}
Tone: Professional and helpful. Provide clear, concise responses."""

        # Format the conversation with the Llama chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize and move to the model's device.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Generate (sampling; pad with EOS since Llama has no pad token).
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=500,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # FIX: decode ONLY the newly generated tokens. The original sliced
        # the decoded string by len(prompt), but skip_special_tokens=True
        # removes the chat-template special tokens from the decoded text,
        # so the character offset was wrong and the reply came back
        # shifted/truncated. Slicing by token position is exact.
        prompt_len = encoded["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_len:]
        assistant_response = self.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        ).strip()

        return [{"generated_text": assistant_response}]