SamarthPujari commited on
Commit
19797b1
·
verified ·
1 Parent(s): 64269dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -26,10 +26,8 @@ sd_pipe = StableDiffusionPipeline.from_pretrained(
26
  def get_current_weather(place: str) -> str:
27
  """
28
  Get the current weather for a given location.
29
-
30
  Args:
31
  place (str): Name of the city or location (e.g., "London" or "New York").
32
-
33
  Returns:
34
  str: Weather condition, temperature, humidity, and wind speed.
35
  """
@@ -60,11 +58,9 @@ def get_current_weather(place: str) -> str:
60
  def get_current_time_in_timezone(timezone: str) -> str:
61
  """
62
  Get the current local time in a given timezone.
63
-
64
  Args:
65
  timezone (str): Timezone string in the format 'Region/City',
66
  e.g., "America/New_York".
67
-
68
  Returns:
69
  str: Formatted local time string.
70
  """
@@ -83,11 +79,9 @@ qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
83
  def document_qna_tool(pdf_path: str, question: str) -> str:
84
  """
85
  Answer a natural language question based on the content of a PDF document.
86
-
87
  Args:
88
  pdf_path (str): Path to the local PDF file.
89
  question (str): Natural language question to ask about the PDF.
90
-
91
  Returns:
92
  str: Answer generated from the most relevant section of the document.
93
  """
@@ -128,7 +122,6 @@ def image_generator(prompt: str) -> str:
128
 
129
  # -------------------- Local LLM (Replaces HfApiModel) --------------------
130
  from transformers import AutoModelForCausalLM, AutoTokenizer
131
- import torch
132
 
133
  class LocalModel:
134
  def __init__(self):
@@ -140,10 +133,25 @@ class LocalModel:
140
  device_map="auto" if torch.cuda.is_available() else None,
141
  )
142
 
143
- def generate(self, prompt, max_new_tokens=500):
 
 
 
 
 
 
144
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
145
  output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
146
- return self.tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
147
 
148
  def __call__(self, prompt, **kwargs):
149
  return self.generate(prompt, **kwargs)
 
26
  def get_current_weather(place: str) -> str:
27
  """
28
  Get the current weather for a given location.
 
29
  Args:
30
  place (str): Name of the city or location (e.g., "London" or "New York").
 
31
  Returns:
32
  str: Weather condition, temperature, humidity, and wind speed.
33
  """
 
58
  def get_current_time_in_timezone(timezone: str) -> str:
59
  """
60
  Get the current local time in a given timezone.
 
61
  Args:
62
  timezone (str): Timezone string in the format 'Region/City',
63
  e.g., "America/New_York".
 
64
  Returns:
65
  str: Formatted local time string.
66
  """
 
79
  def document_qna_tool(pdf_path: str, question: str) -> str:
80
  """
81
  Answer a natural language question based on the content of a PDF document.
 
82
  Args:
83
  pdf_path (str): Path to the local PDF file.
84
  question (str): Natural language question to ask about the PDF.
 
85
  Returns:
86
  str: Answer generated from the most relevant section of the document.
87
  """
 
122
 
123
  # -------------------- Local LLM (Replaces HfApiModel) --------------------
124
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
125
 
126
  class LocalModel:
127
  def __init__(self):
 
133
  device_map="auto" if torch.cuda.is_available() else None,
134
  )
135
 
136
+ def generate(self, prompt, max_new_tokens=500, **kwargs):
137
+ """
138
+ Generate text from the given prompt.
139
+ Extra kwargs like 'stop_sequences' are accepted for compatibility.
140
+ """
141
+ stop_sequences = kwargs.pop("stop_sequences", None)
142
+
143
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
144
  output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
145
+ text = self.tokenizer.decode(output[0], skip_special_tokens=True)
146
+
147
+ # If stop_sequences provided, truncate output
148
+ if stop_sequences:
149
+ for stop in stop_sequences:
150
+ if stop in text:
151
+ text = text.split(stop)[0]
152
+ break
153
+
154
+ return text
155
 
156
  def __call__(self, prompt, **kwargs):
157
  return self.generate(prompt, **kwargs)