Danna8 commited on
Commit
d1f0c49
·
verified ·
1 Parent(s): bc3f3f1

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +58 -0
handler.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from typing import Dict, List, Any
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path: str = ""):
8
+ """
9
+ Initialize the model and tokenizer.
10
+ :param path: Path to the model repository (not used directly since we load from Hugging Face Hub).
11
+ """
12
+ # Define the base model and adapter model names
13
+ self.base_model_name = "mistralai/Mistral-7B-Instruct-v0.3"
14
+ self.adapter_model_name = "Danna8/MistralF"
15
+
16
+ # Load the tokenizer
17
+ self.tokenizer = AutoTokenizer.from_pretrained(self.adapter_model_name)
18
+
19
+ # Load the base model with optimizations
20
+ self.model = AutoModelForCausalLM.from_pretrained(
21
+ self.base_model_name,
22
+ torch_dtype=torch.float16, # Use FP16 for efficiency
23
+ device_map="auto" # Automatically map to GPU
24
+ )
25
+
26
+ # Load the adapter
27
+ self.model.load_adapter(self.adapter_model_name)
28
+ self.model.set_active_adapters("default") # Adjust the adapter name if needed
29
+
30
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
31
+ """
32
+ Handle inference requests.
33
+ :param data: Input data containing the text to process.
34
+ :return: List of generated outputs.
35
+ """
36
+ # Extract the input text from the request
37
+ inputs = data.get("inputs", "")
38
+ if not inputs:
39
+ return [{"error": "No input provided"}]
40
+
41
+ # Tokenize the input
42
+ tokenized_inputs = self.tokenizer(inputs, return_tensors="pt").to("cuda")
43
+
44
+ # Generate output
45
+ outputs = self.model.generate(
46
+ **tokenized_inputs,
47
+ max_new_tokens=50,
48
+ do_sample=True,
49
+ top_p=0.95,
50
+ temperature=0.7,
51
+ pad_token_id=self.tokenizer.eos_token_id # Ensure proper padding
52
+ )
53
+
54
+ # Decode the output
55
+ generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
56
+
57
+ # Return the result in the expected format
58
+ return [{"generated_text": generated_text}]