QuantaSparkLabs committed
Commit 56a3e64 · verified · 1 Parent(s): c0caea2

Upload pipeline.py with huggingface_hub
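
For context, a commit message like this one is typically produced by pushing the file through the huggingface_hub client. A minimal sketch of such a push; only the file name and commit message come from this page, and the repo id below is a hypothetical stand-in:

from huggingface_hub import HfApi

api = HfApi()  # reads the access token from HF_TOKEN or a prior `huggingface-cli login`
api.upload_file(
    path_or_fileobj="pipeline.py",                # local file to upload
    path_in_repo="pipeline.py",                   # destination path inside the repo
    repo_id="QuantaSparkLabs/JujutsuKaiserver",   # hypothetical repo id, not stated on this page
    commit_message="Upload pipeline.py with huggingface_hub",
)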

Files changed (1)
  1. pipeline.py +20 -19
pipeline.py CHANGED
@@ -1,13 +1,11 @@
-import json
-import numpy as np
-from sentence_transformers import SentenceTransformer
+
+import json, torch, numpy as np
+from sentence_transformers import SentenceTransformer, CrossEncoder
 import faiss
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 class JujutsuKaiserver:
-    def __init__(self, model_dir="./upload_model"):
-        # Load config
+    def __init__(self, model_dir="."):
         with open(f"{model_dir}/rag_config.json") as f:
             config = json.load(f)
         self.embedder = SentenceTransformer(config["embedder_model"])
@@ -15,30 +13,33 @@ class JujutsuKaiserver:
         with open(f"{model_dir}/chunks.txt", "r", encoding="utf-8") as f:
             raw = f.read().split("<|CHUNK_END|>")
         self.chunks = [c.strip() for c in raw if c.strip()]
-        # Load model
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type='nf4',
-            bnb_4bit_compute_dtype=torch.float16
-        )
+        self.reranker = CrossEncoder(f"{model_dir}/cross_encoder_model")
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_dir,
-            quantization_config=bnb_config,
+            torch_dtype=torch.float16,
             device_map='auto',
             trust_remote_code=True
         )
+
     def ask(self, question, max_tokens=300):
+        q_lower = question.strip().lower()
+        if q_lower in ('hi', 'hello', 'hey', 'yo', 'sup', 'hi there'):
+            return "Hey there! I'm JujutsuKaiserver, your all-knowing JJK assistant. Ask me anything!"
         q_emb = self.embedder.encode([question]).astype('float32')
-        _, indices = self.index.search(q_emb, 5)
-        context = "\n\n".join([self.chunks[i] for i in indices[0]])
+        _, indices = self.index.search(q_emb, 30)
+        candidates = [self.chunks[i] for i in indices[0]]
+        pairs = [(question, c) for c in candidates]
+        scores = self.reranker.predict(pairs)
+        reranked = sorted(zip(scores, candidates), reverse=True)[:4]
+        best = [c for _, c in reranked]
+        context = "\n\n".join(best)
         messages = [
-            {"role": "system", "content": "You are JujutsuKaiserver, an expert on Jujutsu Kaisen. Answer using ONLY the context. If unsure, say you don't know."},
+            {"role": "system", "content": "You are JujutsuKaiserver, an expert on Jujutsu Kaisen. Answer using ONLY the provided context. Be friendly and concise."},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"}
         ]
         prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
         outputs = self.model.generate(**inputs, max_new_tokens=max_tokens, temperature=0.7, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)
         answer = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-        return answer.strip()
+        return answer.strip()
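
The net effect of the commit is to replace single-stage retrieval (FAISS top-5 straight into the prompt) with a retrieve-then-rerank pipeline: the bi-encoder/FAISS pass over-fetches 30 candidate chunks cheaply, then a cross-encoder scores each (question, chunk) pair jointly and only the top 4 reach the context window. A self-contained sketch of that pattern, using public checkpoints as stand-ins for the repo's embedder and cross_encoder_model (the model names and toy corpus below are assumptions, not taken from this commit):

from sentence_transformers import SentenceTransformer, CrossEncoder

# Toy corpus standing in for chunks.txt.
chunks = [
    "Satoru Gojo teaches first-years at Tokyo Jujutsu High.",
    "Ryomen Sukuna is known as the King of Curses.",
    "Megumi Fushiguro fights with the Ten Shadows technique.",
]
embedder = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

question = "Who teaches at Tokyo Jujutsu High?"

# Stage 1: cheap bi-encoder similarity, standing in for the FAISS search
# (self.index in the class serves the same purpose over the real chunks).
q_emb = embedder.encode([question])
c_emb = embedder.encode(chunks)
order = (q_emb @ c_emb.T)[0].argsort()[::-1]
candidates = [chunks[i] for i in order]  # over-fetch: keep every candidate

# Stage 2: the cross-encoder reads question and chunk together, so its
# relevance scores are sharper than raw embedding similarity. Sorting on
# the score alone avoids falling back to string comparison on ties.
scores = reranker.predict([(question, ch) for ch in candidates])
best = [ch for _, ch in sorted(zip(scores, candidates), key=lambda p: p[0], reverse=True)[:1]]
print(best[0])  # expected: the Gojo chunk

The two-stage split is the usual trade-off: the cross-encoder is far too slow to score every chunk in a large corpus, but applied to a small over-fetched candidate set it noticeably improves which chunks end up in the prompt.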