vasu24 commited on
Commit
6cba9ec
·
verified ·
1 Parent(s): 3c5c8c9

Upload huggingface_handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. huggingface_handler.py +126 -0
huggingface_handler.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GROOT v3 - Hugging Face Inference Handler
3
+ This file defines how to load and run the model on HF Inference Endpoints
4
+ """
5
+
6
+ from typing import Dict, List, Any
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
+ from peft import PeftModel
10
+
11
+
12
+ class EndpointHandler:
13
+ def __init__(self, path=""):
14
+ """
15
+ Initialize the model and tokenizer
16
+ path: Path to the model files (HF will provide this)
17
+ """
18
+ # Load base model
19
+ base_model_name = "unsloth/Llama-3.2-1B-Instruct"
20
+
21
+ print(f"Loading tokenizer from {base_model_name}...")
22
+ self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
23
+
24
+ print(f"Loading base model from {base_model_name}...")
25
+ self.model = AutoModelForCausalLM.from_pretrained(
26
+ base_model_name,
27
+ torch_dtype=torch.float16,
28
+ device_map="auto",
29
+ load_in_4bit=True, # Use 4-bit quantization for faster loading
30
+ )
31
+
32
+ # Load LoRA adapter from the uploaded files
33
+ print(f"Loading LoRA adapter from {path}...")
34
+ self.model = PeftModel.from_pretrained(
35
+ self.model,
36
+ path,
37
+ )
38
+
39
+ # Merge for faster inference
40
+ print("Merging adapter with base model...")
41
+ self.model = self.model.merge_and_unload()
42
+
43
+ self.model.eval()
44
+ print("GROOT v3 loaded successfully!")
45
+
46
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
47
+ """
48
+ Handle inference requests
49
+
50
+ Args:
51
+ data: {"inputs": "User message here"}
52
+
53
+ Returns:
54
+ [{"generated_text": "Response here"}]
55
+ """
56
+ # Extract user message
57
+ inputs = data.pop("inputs", data)
58
+ user_message = inputs if isinstance(inputs, str) else inputs.get("message", "")
59
+
60
+ # System prompt (matches training)
61
+ system_prompt = """You are GROOT, an AI kitchen assistant for iNosh - a smart pantry and meal planning app.
62
+
63
+ Target Market: Australia & New Zealand (Multicultural)
64
+
65
+ Your capabilities:
66
+ - Manage pantry items (expiry tracking, low stock alerts, barcode scanning)
67
+ - Create shopping lists (store-specific: Woolworths, Coles, Countdown, etc.)
68
+ - Suggest recipes (15 cuisines, nutrition-focused, dietary restrictions)
69
+ - Plan meals (weekly, budget-aware, nutrition-optimized)
70
+ - Track fitness (Apple Health, Google Fit integration)
71
+ - Log restaurant meals (AU/NZ chains)
72
+ - Scan barcodes (instant nutrition lookup)
73
+ - Plan kids meals (school lunch requirements)
74
+
75
+ CRITICAL RESPONSE RULES:
76
+ 1. For action requests, respond with valid JSON
77
+ 2. For general conversation, respond naturally without JSON
78
+ 3. Always respect dietary restrictions (no pork in halal, no meat in vegan, etc.)
79
+ 4. Use metric units (g, kg, ml, L) - AU/NZ standard
80
+ 5. Price estimates in AUD/NZD
81
+ 6. Include nutrition data when relevant (calories, protein, carbs, fat)
82
+ 7. Suggest recipes from available pantry items when possible
83
+
84
+ JSON Action Formats:
85
+ - Pantry: {"action": "add_pantry", "item": {...}}
86
+ - Shopping: {"action": "create_list", "list": {...}}
87
+ - Recipes: {"action": "suggest_recipes", "recipes": [...]}
88
+ - Meal Plan: {"action": "create_meal_plan", "plan": {...}}
89
+ - Fitness: {"action": "log_workout", "workout": {...}}
90
+ - Restaurant: {"action": "log_restaurant", "meal": {...}}
91
+ - Barcode: {"action": "lookup_barcode", "product": {...}}
92
+
93
+ Tone: Professional and helpful. Provide clear, concise responses."""
94
+
95
+ # Format using Llama chat template
96
+ messages = [
97
+ {"role": "system", "content": system_prompt},
98
+ {"role": "user", "content": user_message},
99
+ ]
100
+
101
+ prompt = self.tokenizer.apply_chat_template(
102
+ messages,
103
+ tokenize=False,
104
+ add_generation_prompt=True,
105
+ )
106
+
107
+ # Tokenize
108
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
109
+
110
+ # Generate
111
+ with torch.no_grad():
112
+ outputs = self.model.generate(
113
+ **inputs,
114
+ max_new_tokens=500,
115
+ temperature=0.7,
116
+ do_sample=True,
117
+ pad_token_id=self.tokenizer.eos_token_id,
118
+ )
119
+
120
+ # Decode
121
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
122
+
123
+ # Extract assistant response (after the prompt)
124
+ assistant_response = full_response[len(prompt):].strip()
125
+
126
+ return [{"generated_text": assistant_response}]