SamarthPujari committed on
Commit
ff4dce1
·
verified ·
1 Parent(s): d9d5d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -90
app.py CHANGED
@@ -1,79 +1,76 @@
1
- from smolagents import CodeAgent, DuckDuckGoSearchTool, load_tool, tool
2
  import datetime
3
  import requests
4
  import pytz
5
  import yaml
6
  import os
7
- from huggingface_hub import login
8
  from tools.final_answer import FinalAnswerTool
9
  from Gradio_UI import GradioUI
10
  import fitz # PyMuPDF
11
  from sentence_transformers import SentenceTransformer, util
12
  from transformers import pipeline
13
- from diffusers import StableDiffusionPipeline
14
- import torch
15
 
16
  # API Key for weather
17
  API_KEY = os.getenv("Weather_Token")
18
 
19
- hf_token = os.getenv("HF_TOKEN")
20
- login(token=hf_token)
21
-
22
- sd_pipe = StableDiffusionPipeline.from_pretrained(
23
- "rupeshs/LCM-runwayml-stable-diffusion-v1-5",
24
- use_auth_token=hf_token,
25
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
26
- ).to(device)
27
-
28
  # -------------------- TOOL 1: Get Weather --------------------
29
  @tool
30
  def get_current_weather(place: str) -> str:
31
  """
32
- Get the current weather for a given location.
33
  Args:
34
- place (str): Name of the city or location (e.g., "London" or "New York").
35
  Returns:
36
- str: Weather condition, temperature, humidity, and wind speed.
37
  """
 
38
  url = "https://api.openweathermap.org/data/2.5/weather"
39
  params = {
40
  "q": place,
41
- "appid": API_KEY,
42
  "units": "metric"
43
  }
 
44
  try:
45
  response = requests.get(url, params=params)
46
  data = response.json()
 
47
  if response.status_code == 200:
 
 
 
 
 
48
  return (
49
  f"Weather in {place}:\n"
50
- f"- Condition: {data['weather'][0]['description']}\n"
51
- f"- Temperature: {data['main']['temp']}°C\n"
52
- f"- Humidity: {data['main']['humidity']}%\n"
53
- f"- Wind Speed: {data['wind']['speed']} m/s"
54
  )
55
  else:
56
- return f"Error: {data.get('message', 'Unknown error')}"
57
  except Exception as e:
58
- return f"Error fetching weather data: {str(e)}"
 
59
 
60
  # -------------------- TOOL 2: Get Time --------------------
61
  @tool
62
  def get_current_time_in_timezone(timezone: str) -> str:
63
  """
64
- Get the current local time in a given timezone.
65
  Args:
66
- timezone (str): Timezone string in the format 'Region/City',
67
- e.g., "America/New_York".
68
  Returns:
69
- str: Formatted local time string.
70
  """
71
  try:
72
  tz = pytz.timezone(timezone)
73
  local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
74
  return f"The current local time in {timezone} is: {local_time}"
75
  except Exception as e:
76
- return f"Error fetching time: {str(e)}"
 
77
 
78
  # -------------------- TOOL 3: Document QnA --------------------
79
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
@@ -82,104 +79,99 @@ qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
82
  @tool
83
  def document_qna_tool(pdf_path: str, question: str) -> str:
84
  """
85
- Answer a natural language question based on the content of a PDF document.
86
  Args:
87
  pdf_path (str): Path to the local PDF file.
88
- question (str): Natural language question to ask about the PDF.
89
  Returns:
90
- str: Answer generated from the most relevant section of the document.
91
  """
 
 
 
 
92
  try:
 
 
 
93
  if not os.path.exists(pdf_path):
94
  return f"[ERROR] File not found: {pdf_path}"
95
- doc = fitz.open(pdf_path)
96
- text_chunks = [page.get_text() for page in doc if page.get_text().strip()]
 
 
 
 
 
 
 
 
 
 
97
  doc.close()
 
98
  if not text_chunks:
99
  return "[ERROR] No readable text in the PDF."
100
 
 
 
 
101
  embeddings = embedding_model.encode(text_chunks, convert_to_tensor=True)
102
  question_embedding = embedding_model.encode(question, convert_to_tensor=True)
 
 
103
  scores = util.pytorch_cos_sim(question_embedding, embeddings)[0]
104
- best_context = text_chunks[scores.argmax().item()]
 
105
 
 
106
  prompt = f"Context: {best_context}\nQuestion: {question}"
 
107
  answer = qa_pipeline(prompt, max_new_tokens=500)[0]['generated_text']
 
108
  return f"Answer: {answer.strip()}"
 
109
  except Exception as e:
110
- return f"[EXCEPTION] {type(e).__name__}: {str(e)}"
111
 
112
- # -------------------- TOOL 4: Local Image Generation --------------------
113
- @tool
114
- def image_generator(prompt: str) -> str:
115
- """
116
- Generate an image from a given text prompt using Stable Diffusion.
117
- Args:
118
- prompt (str): Description of the image to generate.
119
- Returns:
120
- str: Path to the saved generated image.
121
- """
122
- image = sd_pipe(prompt).images[0]
123
- output_path = "generated_image.png"
124
- image.save(output_path)
125
- return f"Image saved at {output_path}"
126
-
127
- # -------------------- Local LLM (Replaces HfApiModel) --------------------
128
- from transformers import AutoModelForCausalLM, AutoTokenizer
129
-
130
- class LocalModel:
131
- def __init__(self):
132
- model_name = "openlm-research/open_llama_3b"
133
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
134
- self.model = AutoModelForCausalLM.from_pretrained(
135
- model_name,
136
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
137
- device_map="auto" if torch.cuda.is_available() else None,
138
- )
139
-
140
- def generate(self, prompt, max_new_tokens=500, **kwargs):
141
- """
142
- Generate text from the given prompt.
143
- Extra kwargs like 'stop_sequences' are accepted for compatibility.
144
- """
145
- stop_sequences = kwargs.pop("stop_sequences", None)
146
-
147
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
148
- output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
149
- text = self.tokenizer.decode(output[0], skip_special_tokens=True)
150
-
151
- # If stop_sequences provided, truncate output
152
- if stop_sequences:
153
- for stop in stop_sequences:
154
- if stop in text:
155
- text = text.split(stop)[0]
156
- break
157
-
158
- return text
159
-
160
- def __call__(self, prompt, **kwargs):
161
- return self.generate(prompt, **kwargs)
162
-
163
- # -------------------- Agent Setup --------------------
164
  final_answer = FinalAnswerTool()
165
  search_tool = DuckDuckGoSearchTool()
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  with open("prompts.yaml", 'r') as stream:
168
  prompt_templates = yaml.safe_load(stream)
169
 
170
- model = LocalModel()
171
  agent = CodeAgent(
172
  model=model,
173
  tools=[
174
  get_current_time_in_timezone,
175
  get_current_weather,
176
- image_generator,
177
  search_tool,
178
- document_qna_tool,
179
  final_answer
180
  ],
181
  max_steps=6,
182
  verbosity_level=1,
 
 
 
 
183
  prompt_templates=prompt_templates
184
  )
185
 
 
1
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool
2
  import datetime
3
  import requests
4
  import pytz
5
  import yaml
6
  import os
 
7
  from tools.final_answer import FinalAnswerTool
8
  from Gradio_UI import GradioUI
9
  import fitz # PyMuPDF
10
  from sentence_transformers import SentenceTransformer, util
11
  from transformers import pipeline
 
 
12
 
13
  # API Key for weather
14
  API_KEY = os.getenv("Weather_Token")
15
 
 
 
 
 
 
 
 
 
 
16
  # -------------------- TOOL 1: Get Weather --------------------
17
@tool
def get_current_weather(place: str) -> str:
    """
    A tool that fetches the current weather of a particular place.

    Args:
        place (str): A string representing a valid place (e.g., 'London/Paris').

    Returns:
        str: Weather description including condition, temperature, humidity,
        and wind speed, or an error message string on failure.
    """
    # Fail fast with a clear message when the API key env var is unset;
    # otherwise OpenWeather would return an opaque 401.
    if not API_KEY:
        return "Error: weather API key is not configured (Weather_Token env var)."

    url = "https://api.openweathermap.org/data/2.5/weather"
    params = {
        "q": place,
        "appid": API_KEY,
        "units": "metric",  # Celsius / m/s
    }

    try:
        # Bounded timeout so the agent never hangs on a slow network.
        response = requests.get(url, params=params, timeout=10)
        data = response.json()

        if response.status_code == 200:
            weather_desc = data["weather"][0]["description"]
            temperature = data["main"]["temp"]
            humidity = data["main"]["humidity"]
            wind_speed = data["wind"]["speed"]

            return (
                f"Weather in {place}:\n"
                f"- Condition: {weather_desc}\n"
                f"- Temperature: {temperature}°C\n"
                f"- Humidity: {humidity}%\n"
                f"- Wind Speed: {wind_speed} m/s"
            )
        else:
            # .get() avoids a KeyError when the error payload lacks 'message'.
            return f"Error: {data.get('message', 'Unknown error')}"
    except Exception as e:
        return f"Error fetching weather data for '{place}': {str(e)}"
55
+
56
 
57
  # -------------------- TOOL 2: Get Time --------------------
58
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """
    A tool that fetches the current local time in a specified timezone.

    Args:
        timezone (str): A string representing a valid timezone (e.g., 'America/New_York').

    Returns:
        str: The current local time formatted as a string, or an error message
        if the timezone name is not recognized.
    """
    try:
        zone = pytz.timezone(timezone)
        now = datetime.datetime.now(zone)
        stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        # Unknown timezone names (pytz.UnknownTimeZoneError) land here too.
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
    return f"The current local time in {timezone} is: {stamp}"
73
+
74
 
75
  # -------------------- TOOL 3: Document QnA --------------------
76
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
79
@tool
def document_qna_tool(pdf_path: str, question: str) -> str:
    """
    A tool that answers natural language questions about a given PDF document.

    Extracts per-page text with PyMuPDF, selects the page most semantically
    similar to the question via sentence-transformer embeddings, and feeds
    that page plus the question to a flan-t5 text2text model.

    Args:
        pdf_path (str): Path to the local PDF file.
        question (str): Question about the content of the PDF.

    Returns:
        str: Answer to the question based on the content, or an
        "[ERROR]"/"[EXCEPTION]" message string on failure.
    """
    # Local imports keep the tool self-contained for the agent sandbox.
    import os, fitz, traceback
    from sentence_transformers import SentenceTransformer, util
    from transformers import pipeline

    global _qna_embedder, _qna_pipeline

    try:
        if not os.path.exists(pdf_path):
            return f"[ERROR] File not found: {pdf_path}"

        try:
            doc = fitz.open(pdf_path)
        except RuntimeError as e:
            return f"[ERROR] Could not open PDF. It may be corrupted or encrypted. Details: {str(e)}"

        # One chunk per page; skip pages with no extractable text.
        text_chunks = []
        for page in doc:
            text = page.get_text()
            if text.strip():
                text_chunks.append(text)
        doc.close()

        if not text_chunks:
            return "[ERROR] No readable text in the PDF."

        # Lazily create the heavyweight models ONCE per process and cache
        # them in module globals — the previous code re-instantiated both
        # models on every call, which re-downloads/re-initializes them.
        try:
            embedder, qa = _qna_embedder, _qna_pipeline
        except NameError:
            embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            qa = pipeline("text2text-generation", model="google/flan-t5-base")
            _qna_embedder, _qna_pipeline = embedder, qa

        embeddings = embedder.encode(text_chunks, convert_to_tensor=True)
        question_embedding = embedder.encode(question, convert_to_tensor=True)

        # Cosine similarity of the question against every page embedding;
        # the highest-scoring page becomes the QA context.
        scores = util.pytorch_cos_sim(question_embedding, embeddings)[0]
        best_match_idx = scores.argmax().item()
        best_context = text_chunks[best_match_idx]

        prompt = f"Context: {best_context}\nQuestion: {question}"
        answer = qa(prompt, max_new_tokens=500)[0]['generated_text']

        return f"Answer: {answer.strip()}"

    except Exception as e:
        return f"[EXCEPTION] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
136
 
137
# -------------------- Other Components --------------------

# Canned final-answer tool and DuckDuckGo web search from smolagents.
final_answer = FinalAnswerTool()
search_tool = DuckDuckGoSearchTool()

# Remote LLM served via the Hugging Face Inference API.
# NOTE(review): max_tokens=2096 looks like a typo for 2048 — confirm intent.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)

from smolagents import Tool

# Wrap a hosted Gradio Space as an agent tool for image generation
# (replaces the previous local Stable Diffusion pipeline).
image_generation_tool = Tool.from_space(
    "black-forest-labs/FLUX.1-schnell",
    name="image_generator",  # You can name it whatever makes sense for your agent
    description="Generate an image from a prompt"
)

# Prompt templates for the agent, loaded from the repo-local YAML file.
with open("prompts.yaml", 'r') as stream:
    prompt_templates = yaml.safe_load(stream)

# Code-writing agent wired with every tool defined above.
agent = CodeAgent(
    model=model,
    tools=[
        get_current_time_in_timezone,
        get_current_weather,
        image_generation_tool,
        search_tool,
        document_qna_tool,  # ← New Tool Added
        final_answer
    ],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)
177