File size: 4,471 Bytes
6cba9ec 59125f8 6cba9ec 59125f8 6cba9ec 59125f8 6cba9ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""
iNosh AI - Hugging Face Inference Handler
This file defines how to load and run the model on HF Inference Endpoints
"""
from typing import Any, Dict, List

import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for iNosh AI.

    Loads the Llama 3.2 1B Instruct base model in 4-bit, applies the
    fine-tuned LoRA adapter uploaded with the endpoint, and serves
    chat-style generation requests.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, quantized base model, and LoRA adapter.

        Args:
            path: Directory containing the adapter files (supplied by the
                HF Inference Endpoints runtime).
        """
        base_model_name = "unsloth/Llama-3.2-1B-Instruct"
        print(f"Loading tokenizer from {base_model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading base model from {base_model_name}...")
        # Bare `load_in_4bit=True` as a from_pretrained kwarg is deprecated
        # in recent transformers; pass an explicit BitsAndBytesConfig.
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )

        # Load the LoRA adapter from the uploaded files.
        print(f"Loading LoRA adapter from {path}...")
        self.model = PeftModel.from_pretrained(self.model, path)

        # Merge the adapter into the base weights for faster inference.
        # NOTE(review): merging into a 4-bit quantized base dequantizes the
        # affected modules and may cost some precision — confirm the merged
        # model matches eval accuracy from training.
        print("Merging adapter with base model...")
        self.model = self.model.merge_and_unload()
        self.model.eval()
        print("iNosh AI loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Handle one inference request.

        Args:
            data: ``{"inputs": "User message"}`` or
                ``{"inputs": {"message": "..."}}``. An optional
                ``"parameters"`` dict (standard HF endpoint schema) is
                forwarded to ``generate()`` as overrides.

        Returns:
            ``[{"generated_text": "<assistant response>"}]``
        """
        # Extract the user message from either request shape.
        inputs = data.pop("inputs", data)
        user_message = inputs if isinstance(inputs, str) else inputs.get("message", "")

        # Optional generation overrides; absent in older clients, so this is
        # backward compatible.
        gen_overrides = data.pop("parameters", {})
        if not isinstance(gen_overrides, dict):
            gen_overrides = {}

        # System prompt (matches training)
        system_prompt = """You are iNosh AI, a smart kitchen assistant that helps with pantry management and meal planning.
Target Market: Australia & New Zealand (Multicultural)
Your capabilities:
- Manage pantry items (expiry tracking, low stock alerts, barcode scanning)
- Create shopping lists (store-specific: Woolworths, Coles, Countdown, etc.)
- Suggest recipes (15 cuisines, nutrition-focused, dietary restrictions)
- Plan meals (weekly, budget-aware, nutrition-optimized)
- Track fitness (Apple Health, Google Fit integration)
- Log restaurant meals (AU/NZ chains)
- Scan barcodes (instant nutrition lookup)
- Plan kids meals (school lunch requirements)
CRITICAL RESPONSE RULES:
1. For action requests, respond with valid JSON
2. For general conversation, respond naturally without JSON
3. Always respect dietary restrictions (no pork in halal, no meat in vegan, etc.)
4. Use metric units (g, kg, ml, L) - AU/NZ standard
5. Price estimates in AUD/NZD
6. Include nutrition data when relevant (calories, protein, carbs, fat)
7. Suggest recipes from available pantry items when possible
JSON Action Formats:
- Pantry: {"action": "add_pantry", "item": {...}}
- Shopping: {"action": "create_list", "list": {...}}
- Recipes: {"action": "suggest_recipes", "recipes": [...]}
- Meal Plan: {"action": "create_meal_plan", "plan": {...}}
- Fitness: {"action": "log_workout", "workout": {...}}
- Restaurant: {"action": "log_restaurant", "meal": {...}}
- Barcode: {"action": "lookup_barcode", "product": {...}}
Tone: Professional and helpful. Provide clear, concise responses."""

        # Format using the Llama chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": 500,
            "temperature": 0.7,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        gen_kwargs.update(gen_overrides)

        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **gen_kwargs)

        # BUG FIX: previously the full sequence was decoded with
        # skip_special_tokens=True and then sliced by len(prompt). Stripping
        # special tokens shortens the decoded text, so that character offset
        # was wrong and could clip the reply or leave prompt residue.
        # Instead, slice off the prompt at the *token* level and decode only
        # the newly generated ids.
        prompt_token_count = model_inputs["input_ids"].shape[1]
        new_token_ids = outputs[0][prompt_token_count:]
        assistant_response = self.tokenizer.decode(
            new_token_ids, skip_special_tokens=True
        ).strip()
        return [{"generated_text": assistant_response}]
|