yplam committed on
Commit
ca3ab6d
·
1 Parent(s): 5c4d92a

fix proxy and agent output

Browse files
Files changed (4) hide show
  1. .env.template +6 -0
  2. agent.py +30 -6
  3. app.py +1 -0
  4. tool/youtube.py +7 -5
.env.template ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # OpenAI API Configuration
2
+ OPENAI_API_KEY=your_openai_api_key_here
3
+ OPENAI_API_BASE=https://api.openai.com/v1
4
+ OPENAI_PROXY=http://127.0.0.1:7899
5
+ PROXY_URL=http://127.0.0.1:7899
6
+ # Add other configuration variables below
agent.py CHANGED
@@ -19,19 +19,22 @@ tools = [
19
  read_file
20
  ]
21
 
22
- llm_with_tools = init_chat_model(
23
  model="gpt-4o",
24
  model_provider="openai",
25
  max_retries=2,
26
  openai_api_base=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
27
  openai_api_key=os.getenv("OPENAI_API_KEY"),
28
  openai_proxy=os.getenv("OPENAI_PROXY"),
29
- ).bind_tools(tools)
 
 
30
 
31
 
32
  class State(TypedDict):
33
  input_file: Optional[str]
34
  messages: Annotated[list[AnyMessage], add_messages]
 
35
 
36
  def should_continue(state: State):
37
  messages = state["messages"]
@@ -40,15 +43,36 @@ def should_continue(state: State):
40
  return "tools"
41
  return END
42
 
43
-
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def agent(state: State):
46
- system_message_content = "You are a helpful assistant that can read files and calling tools to answer questions. You should output results directly, without any additional text or explanation."
 
 
 
 
 
 
 
47
  if state["input_file"]:
48
  system_message_content += f"\nYou are given a file: {state['input_file']}"
49
  system_message = SystemMessage(content=system_message_content)
50
  messages = [system_message] + state["messages"]
51
- return {"messages": [llm_with_tools.invoke(messages)]}
 
52
 
53
 
54
  class Agent:
@@ -77,4 +101,4 @@ class Agent:
77
 
78
  def __call__(self, question: str, file_name: str|None) -> str:
79
  result = self.graph.invoke({"input_file": file_name, "messages": [HumanMessage(content=question)]})
80
- return result["messages"][-1].content
 
19
  read_file
20
  ]
21
 
22
+ llm = init_chat_model(
23
  model="gpt-4o",
24
  model_provider="openai",
25
  max_retries=2,
26
  openai_api_base=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
27
  openai_api_key=os.getenv("OPENAI_API_KEY"),
28
  openai_proxy=os.getenv("OPENAI_PROXY"),
29
+ )
30
+
31
+ llm_with_tools = llm.bind_tools(tools)
32
 
33
 
34
  class State(TypedDict):
35
  input_file: Optional[str]
36
  messages: Annotated[list[AnyMessage], add_messages]
37
+ answer: str
38
 
39
  def should_continue(state: State):
40
  messages = state["messages"]
 
43
  return "tools"
44
  return END
45
 
46
def format_answer(last_message: str):
    """Validate and reformat a raw agent reply into the strict
    'FINAL ANSWER: ...' template via one extra LLM pass.

    Args:
        last_message: the content string of the agent's last message.

    Returns:
        The reformatted answer text produced by the model.
    """
    system_message_content = "You are a general AI assistant. \
Check the user's answer and validate and format it with the following rules: \
The output should be in the following format: \
FINAL ANSWER: [YOUR FINAL ANSWER]. \
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. \
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. \
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
    system_message = SystemMessage(content=system_message_content)
    messages = [system_message] + [last_message]
    # Use the plain model, NOT llm_with_tools: this is a pure formatting
    # step, and a tool-bound model may respond with a tool call whose
    # `content` is empty, silently losing the final answer.
    answer = llm.invoke(messages)
    return answer.content
60
 
61
def agent(state: State):
    """LLM node of the graph: build the system prompt, invoke the
    tool-bound model on the conversation, and record the response.

    Args:
        state: graph state with `input_file`, `messages` (and `answer`).

    Returns:
        Partial state update: always the new message; additionally the
        formatted `answer` once the model produced a final (non-tool-call)
        reply. Formatting only on the final turn avoids one wasted LLM
        round-trip per tool-loop iteration, where `content` is empty anyway.
    """
    system_message_content = "You are a general AI assistant. I will ask you a question. \
Report your thoughts, and finish your answer with the following template: \
FINAL ANSWER: [YOUR FINAL ANSWER]. \
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. \
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. \
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. \
Your answer should only start with 'FINAL ANSWER: ', then follows with the answer. "
    if state["input_file"]:
        system_message_content += f"\nYou are given a file: {state['input_file']}"
    system_message = SystemMessage(content=system_message_content)
    messages = [system_message] + state["messages"]
    response = llm_with_tools.invoke(messages)
    update = {"messages": [response]}
    if not getattr(response, "tool_calls", None):
        # Final reply (no pending tool calls): format it for submission.
        update["answer"] = format_answer(response.content)
    return update
76
 
77
 
78
  class Agent:
 
101
 
102
  def __call__(self, question: str, file_name: str|None) -> str:
103
  result = self.graph.invoke({"input_file": file_name, "messages": [HumanMessage(content=question)]})
104
+ return result["answer"]
app.py CHANGED
@@ -99,6 +99,7 @@ def run_all( username: str|None, submit: bool = True):
99
  print("-"*100)
100
  print(f"Running agent on task {task_id}: {question_text}")
101
  submitted_answer = agent(question_text, "")
 
102
  print(f"Submitted answer: {submitted_answer}")
103
  print("-"*100)
104
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
 
99
  print("-"*100)
100
  print(f"Running agent on task {task_id}: {question_text}")
101
  submitted_answer = agent(question_text, "")
102
+ print("-"*30)
103
  print(f"Submitted answer: {submitted_answer}")
104
  print("-"*100)
105
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
tool/youtube.py CHANGED
@@ -8,12 +8,14 @@ def youtube_transcript(video_id: str) -> str:
8
  """
9
  print(f"Extracting transcript from: {video_id}")
10
  try:
11
- ytt_api = YouTubeTranscriptApi(
12
- proxy_config=GenericProxyConfig(
13
- http_url=os.getenv("PROXY_URL"),
14
- https_url=os.getenv("PROXY_URL"),
 
 
 
15
  )
16
- )
17
  transcript = ytt_api.fetch(video_id)
18
  print(f"Transcript: {transcript}")
19
  return transcript
 
8
  """
9
  print(f"Extracting transcript from: {video_id}")
10
  try:
11
+ ytt_api = YouTubeTranscriptApi()
12
+ if os.getenv("PROXY_URL"):
13
+ ytt_api = YouTubeTranscriptApi(
14
+ proxy_config=GenericProxyConfig(
15
+ http_url=os.getenv("PROXY_URL"),
16
+ https_url=os.getenv("PROXY_URL"),
17
+ )
18
  )
 
19
  transcript = ytt_api.fetch(video_id)
20
  print(f"Transcript: {transcript}")
21
  return transcript