[FIX] fixes max_len and min_len dynamic.
Browse files
tool.py
CHANGED
|
@@ -17,7 +17,7 @@ class TranscriptSummarizer(Tool):
|
|
| 17 |
|
| 18 |
def __init__(self, *args, **kwargs):
|
| 19 |
super().__init__(*args, **kwargs)
|
| 20 |
-
self.summarizer = pipeline("summarization", model="facebook/
|
| 21 |
self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
|
| 22 |
self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
|
| 23 |
|
|
@@ -27,12 +27,30 @@ class TranscriptSummarizer(Tool):
|
|
| 27 |
|
| 28 |
def forward(self, transcript: str) -> str:
|
| 29 |
try:
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
image_bytes = self.query({"inputs": image_prompt})
|
| 34 |
image = Image.open(io.BytesIO(image_bytes))
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
image.save(image_url) # Save the image to a file
|
| 37 |
return f"{summary}\n\nImage URL: {image_url}" # Return the file path
|
| 38 |
except Exception as e:
|
|
|
|
| 17 |
|
| 18 |
def __init__(self, *args, **kwargs):
|
| 19 |
super().__init__(*args, **kwargs)
|
| 20 |
+
self.summarizer = pipeline("summarization", model="facebook/usin")
|
| 21 |
self.api_url = "https://api-inference.huggingface.co/models/ZB-Tech/Text-to-Image"
|
| 22 |
self.headers = {"Authorization": f"Bearer {os.getenv('HF_API_KEY')}"}
|
| 23 |
|
|
|
|
| 27 |
|
| 28 |
def forward(self, transcript: str) -> str:
|
| 29 |
try:
|
| 30 |
+
transcript_length = len(transcript)
|
| 31 |
+
|
| 32 |
+
def get_summary_lengths(length):
|
| 33 |
+
if length <= 1000:
|
| 34 |
+
max_length = 300
|
| 35 |
+
min_length = 100
|
| 36 |
+
elif length <= 3000:
|
| 37 |
+
max_length = 750
|
| 38 |
+
min_length = 250
|
| 39 |
+
else:
|
| 40 |
+
max_length = 1500
|
| 41 |
+
min_length = 500
|
| 42 |
+
return max_length, min_length
|
| 43 |
+
|
| 44 |
+
max_length, min_length = get_summary_lengths(transcript_length)
|
| 45 |
+
summary = self.summarizer(transcript, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']
|
| 46 |
+
key_entities = summary.split()[:3] # Extract first 3 words as key entities
|
| 47 |
+
image_prompt = f"Generate an image related to: {' '.join(key_entities)}, cartoon style"
|
| 48 |
image_bytes = self.query({"inputs": image_prompt})
|
| 49 |
image = Image.open(io.BytesIO(image_bytes))
|
| 50 |
+
image_folder = "Image"
|
| 51 |
+
if not os.path.exists(image_folder):
|
| 52 |
+
os.makedirs(image_folder)
|
| 53 |
+
image_url = os.path.join(image_folder, "image.jpg") # Specify the folder path
|
| 54 |
image.save(image_url) # Save the image to a file
|
| 55 |
return f"{summary}\n\nImage URL: {image_url}" # Return the file path
|
| 56 |
except Exception as e:
|