# AI_Research_Assistant / ai_assistant.py
import os
import traceback
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain_community.tools import WikipediaQueryRun, ArxivQueryRun
from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
from langchain.agents import initialize_agent, AgentType
# ✅ Login to HF Hub
token = os.getenv("HF_TOKEN")
print("🔑 HF_TOKEN available?", token is not None)
if token:
    login(token=token)
else:
    print("❌ No HF_TOKEN found in environment")
def build_qa():
    print("🚀 Starting QA pipeline...")
    try:
        # ---- TOOLS ----
        api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=200)
        wiki = WikipediaQueryRun(api_wrapper=api_wrapper)
        arxiv_wrapper = ArxivAPIWrapper(top_k_results=1, doc_content_chars_max=200)
        arxiv = ArxivQueryRun(api_wrapper=arxiv_wrapper)
        tools = [wiki, arxiv]
        print("🔹 Tools initialized:", [type(t).__name__ for t in tools])
        # ---- MODEL ----
        model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # HF PyTorch checkpoint
        print("🔹 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        print("🔹 Loading model with 8-bit quantization...")
        # Note: bitsandbytes 8-bit quantization requires a CUDA GPU; it does not
        # run on CPU. device_map="auto" places layers on whatever devices are
        # available (GPU first, with CPU/disk offload as fallback).
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            trust_remote_code=True,
        )
        print("✅ Model loaded")
        # ---- PIPELINE ----
        llm_pipeline = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding; temperature/top_p would be ignored here
            repetition_penalty=1.2,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=False,  # return only the newly generated text, not the prompt
        )
        hf_llm = HuggingFacePipeline(pipeline=llm_pipeline)
        print("✅ Pipeline ready")
        # ---- AGENT ----
        # initialize_agent is LangChain's legacy agent API (deprecated in newer
        # releases), but still builds a ReAct-style tool-using agent here.
        agent = initialize_agent(
            tools=tools,
            llm=hf_llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
            handle_parsing_errors=True,
        )
        print("✅ Agent initialized")
        return agent
    except Exception:
        print("❌ Failed to build QA pipeline")
        traceback.print_exc()
        return None
# Build pipeline at import
agent = build_qa()
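# Note: building at import time front-loads the slow model download/load onto
# whatever process first imports this module. A lazy alternative (a sketch,
# not part of the original flow) would defer it to the first call:
#
#     _agent = None
#     def _get_agent():
#         global _agent
#         if _agent is None:
#             _agent = build_qa()
#         return _agent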
def get_response(query: str) -> str:
    if agent is None:
        return "⚠️ QA pipeline not initialized."
    try:
        # agent.invoke returns a dict like {"input": ..., "output": ...};
        # return just the final answer string.
        result = agent.invoke({"input": query})
        return result.get("output", str(result))
    except Exception as e:
        return f"❌ QA run failed: {e}"