sachiniyer committed on
Commit
2d01b34
·
verified ·
1 Parent(s): 78310b8

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. __pycache__/backend.cpython-312.pyc +0 -0
  2. backend.py +12 -8
__pycache__/backend.cpython-312.pyc CHANGED
Binary files a/__pycache__/backend.cpython-312.pyc and b/__pycache__/backend.cpython-312.pyc differ
 
backend.py CHANGED
@@ -88,15 +88,18 @@ class Inference:
88
  tokenizer = self.models[model_id]["tokenizer"]
89
  model = self.models[model_id]["model"]
90
 
91
- conversation = ""
92
  for msg in history:
93
  role = msg.get("role", "user")
94
  content = msg.get("content", "")
95
- if role == "user":
96
- conversation += f"User: {content}\n"
97
- else:
98
- conversation += f"Assistant: {content}\n"
99
- conversation += f"User: {message}\nAssistant:"
 
 
 
100
 
101
  try:
102
  inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
@@ -114,8 +117,9 @@ class Inference:
114
  )
115
  logger.info(f"Generated output shape: {outputs.shape}")
116
 
117
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
118
- response = response.split("Assistant:")[-1].strip()
 
119
  logger.info(f"Final response length: {len(response)}")
120
  logger.info(f"Response: {response}")
121
 
 
88
  tokenizer = self.models[model_id]["tokenizer"]
89
  model = self.models[model_id]["model"]
90
 
91
+ messages = []
92
  for msg in history:
93
  role = msg.get("role", "user")
94
  content = msg.get("content", "")
95
+ messages.append({"role": role, "content": content})
96
+ messages.append({"role": "user", "content": message})
97
+
98
+ conversation = tokenizer.apply_chat_template(
99
+ messages,
100
+ tokenize=False,
101
+ add_generation_prompt=True,
102
+ )
103
 
104
  try:
105
  inputs = tokenizer(conversation, return_tensors="pt").to("cuda")
 
117
  )
118
  logger.info(f"Generated output shape: {outputs.shape}")
119
 
120
+ # Extract only the newly generated tokens (skip the input)
121
+ new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
122
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
123
  logger.info(f"Final response length: {len(response)}")
124
  logger.info(f"Response: {response}")
125