disLodge committed on
Commit
ec2661f
·
1 Parent(s): 56b2056
Files changed (2) hide show
  1. app.py +1 -3
  2. indexer.py +11 -12
app.py CHANGED
@@ -2,9 +2,7 @@ import gradio as gr
2
  from indexer import answer_query
3
 
4
  def rag_system(input_text):
5
- answer = answer_query(input_text)
6
-
7
- return answer
8
 
9
  iface = gr.Interface(
10
  fn=rag_system,
 
2
  from indexer import answer_query
3
 
4
  def rag_system(input_text):
5
+ return answer_query(input_text)
 
 
6
 
7
  iface = gr.Interface(
8
  fn=rag_system,
indexer.py CHANGED
@@ -3,11 +3,12 @@ from typing_extensions import TypedDict
3
  from langgraph.graph import StateGraph,START,END
4
  from langgraph.graph.message import add_messages
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
- from langchain.llms import HuggingFacePipeline
 
 
7
  import torch
8
  import os
9
- from dotenv import load_dotenv
10
- from langchain_core.messages import BaseMessage, HumanMessage
11
 
12
 
13
  load_dotenv()
@@ -17,15 +18,14 @@ api = os.getenv("HF_TOKEN")
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(
19
  base_model,
20
- trust_remote_code=True,
21
- token=api,
22
  cache_dir=local_dir,
23
  )
24
 
25
  model = AutoModelForCausalLM.from_pretrained(
26
  base_model,
 
27
  torch_dtype=torch.float16,
28
- token=api,
29
  cache_dir=local_dir,
30
  device_map="auto",
31
  )
@@ -55,17 +55,16 @@ graph_builder=StateGraph(State)
55
  # Node Functionality
56
  def chatbot(state:State):
57
  messages = state["messages"]
58
- if isinstance(messages[-1], BaseMessage):
59
  prompt = messages[-1].content
60
  elif isinstance(messages, str):
61
  prompt = messages
62
  else:
63
  raise ValueError(f"Unsupported message format: {type(messages)}")
64
  response = llm(prompt)
65
- return {"messages":[response]}
66
 
67
 
68
- graph_builder=StateGraph(State)
69
 
70
  # Adding Node
71
  graph_builder.add_node("my_chat",chatbot)
@@ -76,10 +75,10 @@ graph_builder.add_edge("my_chat",END)
76
  graph=graph_builder.compile()
77
 
78
 
79
- def answer_query(query, vectorstore):
80
- response = graph.invoke({"messages": [HumanMessage(content="Hi there")]})
81
 
82
- return response
83
 
84
  def multiply(a:int,b:int)->int:
85
  """
 
3
  from langgraph.graph import StateGraph,START,END
4
  from langgraph.graph.message import add_messages
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
+ from langchain_community.llms import HuggingFacePipeline
7
+ from dotenv import load_dotenv
8
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
9
  import torch
10
  import os
11
+
 
12
 
13
 
14
  load_dotenv()
 
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(
20
  base_model,
21
+ use_auth_token=api,
 
22
  cache_dir=local_dir,
23
  )
24
 
25
  model = AutoModelForCausalLM.from_pretrained(
26
  base_model,
27
+ use_auth_token=api,
28
  torch_dtype=torch.float16,
 
29
  cache_dir=local_dir,
30
  device_map="auto",
31
  )
 
55
  # Node Functionality
56
  def chatbot(state:State):
57
  messages = state["messages"]
58
+ if isinstance(messages[-1], HumanMessage):
59
  prompt = messages[-1].content
60
  elif isinstance(messages, str):
61
  prompt = messages
62
  else:
63
  raise ValueError(f"Unsupported message format: {type(messages)}")
64
  response = llm(prompt)
65
+ return {"messages":[AIMessage(content=response)]}
66
 
67
 
 
68
 
69
  # Adding Node
70
  graph_builder.add_node("my_chat",chatbot)
 
75
  graph=graph_builder.compile()
76
 
77
 
78
+ def answer_query(query):
79
+ response = graph.invoke({"messages": [HumanMessage(content=query)]})
80
 
81
+ return response["messages"][-1].content
82
 
83
  def multiply(a:int,b:int)->int:
84
  """