chaima01 committed on
Commit
214729c
·
verified ·
1 Parent(s): 9bdf998

correct the output format

Browse files
Files changed (1) hide show
  1. handler.py +9 -7
handler.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, Any
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
3
  import torch
4
 
@@ -6,30 +6,32 @@ class EndpointHandler:
6
  def __init__(self, path=""):
7
  # Load model and tokenizer
8
  self.tokenizer = AutoTokenizer.from_pretrained(path)
9
- self.model = AutoModelForCausalLM.from_pretrained(path)
10
  self.pipeline = TextGenerationPipeline(
11
  model=self.model,
12
  tokenizer=self.tokenizer,
13
  device=0 if torch.cuda.is_available() else -1
14
  )
15
 
16
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
17
  prompt_input = data.get("inputs", "")
18
  vibe = data.get("vibe", "Open to All Paths") # Default fallback
19
 
20
- # Construct Camino-aware prompt with tone and identity
21
  full_prompt = (
22
  f"#### Human (Vibe: {vibe}): {prompt_input.strip()}\n"
23
  f"#### Assistant (Vela - your Camino companion):"
24
  )
25
 
26
- # Default generation params, override if provided
27
  generation_args = data.get("parameters", {})
28
  generation_args.setdefault("max_new_tokens", 1024)
29
  generation_args.setdefault("temperature", 0.2)
30
  generation_args.setdefault("top_p", 0.95)
31
  generation_args.setdefault("do_sample", True)
32
 
33
- # Run generation
34
  outputs = self.pipeline(full_prompt, **generation_args)
35
- return {"generated_text": outputs[0]["generated_text"]}
 
 
 
1
+ from typing import Dict, Any, List
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
3
  import torch
4
 
 
6
  def __init__(self, path=""):
7
  # Load model and tokenizer
8
  self.tokenizer = AutoTokenizer.from_pretrained(path)
9
+ self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.float16)
10
  self.pipeline = TextGenerationPipeline(
11
  model=self.model,
12
  tokenizer=self.tokenizer,
13
  device=0 if torch.cuda.is_available() else -1
14
  )
15
 
16
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
17
  prompt_input = data.get("inputs", "")
18
  vibe = data.get("vibe", "Open to All Paths") # Default fallback
19
 
20
+ # Construct Camino-aware prompt
21
  full_prompt = (
22
  f"#### Human (Vibe: {vibe}): {prompt_input.strip()}\n"
23
  f"#### Assistant (Vela - your Camino companion):"
24
  )
25
 
26
+ # Default generation params
27
  generation_args = data.get("parameters", {})
28
  generation_args.setdefault("max_new_tokens", 1024)
29
  generation_args.setdefault("temperature", 0.2)
30
  generation_args.setdefault("top_p", 0.95)
31
  generation_args.setdefault("do_sample", True)
32
 
33
+ # Generate response
34
  outputs = self.pipeline(full_prompt, **generation_args)
35
+
36
+ # Return in correct format
37
+ return [{"generated_text": outputs[0]["generated_text"]}]