Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,43 +25,29 @@ TOKEN_BUCKET_REFILL_RATE = MAX_MODEL_CALLS_PER_MINUTE / 60.0 # Tokens per secon
|
|
| 25 |
storage = MemoryStorage()
|
| 26 |
token_bucket = Limiter(rate=TOKEN_BUCKET_REFILL_RATE, capacity=TOKEN_BUCKET_CAPACITY, storage=storage)
|
| 27 |
|
| 28 |
-
async def check_n_load_attach(session: aiohttp.ClientSession, task_id
|
| 29 |
-
: str, api_url: str = DEFAULT_API_URL) -> Optional[str]:
|
| 30 |
file_url = f"{api_url}/files/{task_id}"
|
| 31 |
try:
|
| 32 |
async with session.get(file_url, timeout=15) as response:
|
| 33 |
if response.status == 200:
|
| 34 |
-
|
| 35 |
-
# Determine file extension from Content-Type
|
| 36 |
content_type = str(response.headers.get("Content-Type", "")).lower()
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
else:
|
| 51 |
print(f"Unsupported content type: {content_type} for task {task_id}")
|
| 52 |
return None
|
| 53 |
-
|
| 54 |
-
# Use task_id as the filename to ensure uniqueness
|
| 55 |
-
filename = f"{task_id}{extension}"
|
| 56 |
-
local_file_path = os.path.join("downloads", filename)
|
| 57 |
-
os.makedirs("downloads", exist_ok=True)
|
| 58 |
-
|
| 59 |
-
# Save the file
|
| 60 |
-
async with aiofiles.open(local_file_path, "wb") as file:
|
| 61 |
-
async for chunk in response.content.iter_chunked(8192):
|
| 62 |
-
await file.write(chunk)
|
| 63 |
-
print(f"File downloaded successfully: {local_file_path}")
|
| 64 |
-
return local_file_path
|
| 65 |
else:
|
| 66 |
print(f"Failed to download file for task {task_id}: HTTP {response.status}")
|
| 67 |
return None
|
|
@@ -69,6 +55,35 @@ async def check_n_load_attach(session: aiohttp.ClientSession, task_id
|
|
| 69 |
print(f"Error downloading attachment for task {task_id}: {str(e)}")
|
| 70 |
return None
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
async def fetch_questions(session: aiohttp.ClientSession, questions_url: str) -> list:
|
| 74 |
"""Fetch questions asynchronously."""
|
|
|
|
| 25 |
storage = MemoryStorage()
|
| 26 |
token_bucket = Limiter(rate=TOKEN_BUCKET_REFILL_RATE, capacity=TOKEN_BUCKET_CAPACITY, storage=storage)
|
| 27 |
|
| 28 |
+
async def check_n_load_attach(session: aiohttp.ClientSession, task_id: str, question: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> Optional[str]:
|
|
|
|
| 29 |
file_url = f"{api_url}/files/{task_id}"
|
| 30 |
try:
|
| 31 |
async with session.get(file_url, timeout=15) as response:
|
| 32 |
if response.status == 200:
|
|
|
|
|
|
|
| 33 |
content_type = str(response.headers.get("Content-Type", "")).lower()
|
| 34 |
+
content = await response.read() # Read the file content
|
| 35 |
+
|
| 36 |
+
# Determine extension based on content_type, content, or question
|
| 37 |
+
extension = await determine_extension(content_type, content, question)
|
| 38 |
+
|
| 39 |
+
if extension:
|
| 40 |
+
filename = f"{task_id}{extension}"
|
| 41 |
+
local_file_path = os.path.join("downloads", filename)
|
| 42 |
+
os.makedirs("downloads", exist_ok=True)
|
| 43 |
+
|
| 44 |
+
async with aiofiles.open(local_file_path, "wb") as file:
|
| 45 |
+
await file.write(content)
|
| 46 |
+
print(f"File downloaded successfully: {local_file_path}")
|
| 47 |
+
return local_file_path
|
| 48 |
else:
|
| 49 |
print(f"Unsupported content type: {content_type} for task {task_id}")
|
| 50 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
else:
|
| 52 |
print(f"Failed to download file for task {task_id}: HTTP {response.status}")
|
| 53 |
return None
|
|
|
|
| 55 |
print(f"Error downloading attachment for task {task_id}: {str(e)}")
|
| 56 |
return None
|
| 57 |
|
| 58 |
+
async def determine_extension(content_type: str, content: bytes, question: str) -> Optional[str]:
|
| 59 |
+
# Check if the question mentions Excel
|
| 60 |
+
if "excel" in question.lower():
|
| 61 |
+
# Check for XLS signature
|
| 62 |
+
if content.startswith(b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1'):
|
| 63 |
+
return ".xls"
|
| 64 |
+
# Check for XLSX signature (ZIP archive)
|
| 65 |
+
elif content.startswith(b'\x50\x4B\x03\x04'):
|
| 66 |
+
return ".xlsx"
|
| 67 |
+
else:
|
| 68 |
+
return ".xlsx" # Default to XLSX if unsure
|
| 69 |
+
# Standard MIME type checks
|
| 70 |
+
if "image/png" in content_type:
|
| 71 |
+
return ".png"
|
| 72 |
+
elif "jpeg" in content_type or "jpg" in content_type:
|
| 73 |
+
return ".jpg"
|
| 74 |
+
elif "spreadsheetml.sheet" in content_type:
|
| 75 |
+
return ".xlsx"
|
| 76 |
+
elif "vnd.ms-excel" in content_type:
|
| 77 |
+
return ".xls"
|
| 78 |
+
elif "audio/mpeg" in content_type:
|
| 79 |
+
return ".mp3"
|
| 80 |
+
elif "application/pdf" in content_type:
|
| 81 |
+
return ".pdf"
|
| 82 |
+
elif "text/x-python" in content_type:
|
| 83 |
+
return ".py"
|
| 84 |
+
else:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
|
| 88 |
async def fetch_questions(session: aiohttp.ClientSession, questions_url: str) -> list:
|
| 89 |
"""Fetch questions asynchronously."""
|