Subha95 committed on
Commit
fcbada7
·
verified ·
1 Parent(s): 4b022c1

Update ai_assistant.py

Browse files
Files changed (1) hide show
  1. ai_assistant.py +30 -39
ai_assistant.py CHANGED
@@ -1,60 +1,48 @@
1
- import os
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
- import torch
4
  from langchain_community.tools import WikipediaQueryRun, ArxivQueryRun
5
  from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
6
  from langchain_huggingface import HuggingFacePipeline
7
  from langchain.agents import initialize_agent, AgentType
 
 
 
 
8
 
9
- # ✅ Set your HF token in environment
10
- HF_TOKEN = os.getenv("HF_TOKEN")
11
- if not HF_TOKEN:
12
- raise ValueError("Please set HF_TOKEN in your environment.")
13
 
14
- # ---- Build the agent ----
15
  def build_qa():
16
- # ---- Tools ----
17
- wiki_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=200)
18
- wiki_tool = WikipediaQueryRun(api_wrapper=wiki_wrapper)
19
 
20
  arxiv_wrapper = ArxivAPIWrapper(top_k_results=1, doc_content_chars_max=200)
21
- arxiv_tool = ArxivQueryRun(api_wrapper=arxiv_wrapper)
22
 
23
- tools = [wiki_tool, arxiv_tool]
24
 
25
- # ---- Model ----
26
- model_name = "mistralai/Mistral-7B-Instruct-v0.3" # HF repo
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
28
 
29
- # Load the model in **FP16 if possible** and then apply **dynamic quantization** for CPU
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_name,
32
- device_map="cpu",
33
- )
34
-
35
- # ⚡ Apply dynamic quantization (CPU only)
36
- model = torch.quantization.quantize_dynamic(
37
- model, {torch.nn.Linear}, dtype=torch.qint8
38
- )
39
-
40
- # ---- HuggingFace pipeline ----
41
- llm_pipeline = pipeline(
42
- "text-generation",
43
  model=model,
44
  tokenizer=tokenizer,
 
45
  max_new_tokens=256,
46
  temperature=0.2,
47
  do_sample=False,
48
  top_p=0.9,
49
  repetition_penalty=1.2,
50
- device=-1, # CPU
51
- return_full_text=False,
52
  )
53
 
54
- # ---- Wrap in LangChain HuggingFacePipeline ----
55
- hf_llm = HuggingFacePipeline(pipeline=llm_pipeline)
56
 
57
- # ---- Initialize Agent ----
58
  agent = initialize_agent(
59
  tools=tools,
60
  llm=hf_llm,
@@ -65,8 +53,11 @@ def build_qa():
65
 
66
  return agent
67
 
68
- # ---- Example usage ----
69
- if __name__ == "__main__":
70
- agent = build_qa()
71
- response = agent.invoke({"input": "What is query, key, value in attention mechanism?"})
72
- print("\n🤖 Answer:", response)
 
 
 
 
1
+ # ai_assistant.py
 
 
2
  from langchain_community.tools import WikipediaQueryRun, ArxivQueryRun
3
  from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
4
  from langchain_huggingface import HuggingFacePipeline
5
  from langchain.agents import initialize_agent, AgentType
6
+ import os
7
+ from huggingface_hub import login
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
9
+ import torch
10
 
11
+ # HF token login
12
+ token = os.getenv("HF_TOKEN")
13
+ if token:
14
+ login(token=token)
15
 
16
+ # build agent
17
  def build_qa():
18
+ api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=200)
19
+ wiki = WikipediaQueryRun(api_wrapper=api_wrapper)
 
20
 
21
  arxiv_wrapper = ArxivAPIWrapper(top_k_results=1, doc_content_chars_max=200)
22
+ arxiv = ArxivQueryRun(api_wrapper=arxiv_wrapper)
23
 
24
+ tools = [wiki, arxiv]
25
 
26
+ # Load model
27
+ model_name = "mistralai/Mistral-7B-Instruct-v0.3"
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ model = AutoModelForCausalLM.from_pretrained(model_name)
30
 
31
+ llm = TextGenerationPipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  model=model,
33
  tokenizer=tokenizer,
34
+ task="text-generation",
35
  max_new_tokens=256,
36
  temperature=0.2,
37
  do_sample=False,
38
  top_p=0.9,
39
  repetition_penalty=1.2,
40
+ eos_token_id=tokenizer.eos_token_id,
41
+ return_full_text=False
42
  )
43
 
44
+ hf_llm = HuggingFacePipeline(pipeline=llm)
 
45
 
 
46
  agent = initialize_agent(
47
  tools=tools,
48
  llm=hf_llm,
 
53
 
54
  return agent
55
 
56
+ # Define get_response function
57
+ _agent_instance = None
58
+ def get_response(user_input: str):
59
+ global _agent_instance
60
+ if _agent_instance is None:
61
+ _agent_instance = build_qa()
62
+ result = _agent_instance.invoke({"input": user_input})
63
+ return result