SergeyO7 committed on
Commit
5bbfb2e
·
verified ·
1 Parent(s): bc4f755

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -29
app.py CHANGED
@@ -25,43 +25,29 @@ TOKEN_BUCKET_REFILL_RATE = MAX_MODEL_CALLS_PER_MINUTE / 60.0 # Tokens per secon
25
  storage = MemoryStorage()
26
  token_bucket = Limiter(rate=TOKEN_BUCKET_REFILL_RATE, capacity=TOKEN_BUCKET_CAPACITY, storage=storage)
27
 
28
- async def check_n_load_attach(session: aiohttp.ClientSession, task_id
29
- : str, api_url: str = DEFAULT_API_URL) -> Optional[str]:
30
  file_url = f"{api_url}/files/{task_id}"
31
  try:
32
  async with session.get(file_url, timeout=15) as response:
33
  if response.status == 200:
34
-
35
- # Determine file extension from Content-Type
36
  content_type = str(response.headers.get("Content-Type", "")).lower()
37
- extension = ""
38
- if "image/png" in content_type:
39
- extension = ".png"
40
- elif "image/jpeg" in content_type:
41
- extension = ".jpg"
42
- elif "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" in content_type:
43
- extension = ".xlsx"
44
- elif "audio/mpeg" in content_type:
45
- extension = ".mp3"
46
- elif "application/pdf" in content_type:
47
- extension = ".pdf"
48
- elif "text/x-python" in content_type:
49
- extension = ".py"
 
50
  else:
51
  print(f"Unsupported content type: {content_type} for task {task_id}")
52
  return None
53
-
54
- # Use task_id as the filename to ensure uniqueness
55
- filename = f"{task_id}{extension}"
56
- local_file_path = os.path.join("downloads", filename)
57
- os.makedirs("downloads", exist_ok=True)
58
-
59
- # Save the file
60
- async with aiofiles.open(local_file_path, "wb") as file:
61
- async for chunk in response.content.iter_chunked(8192):
62
- await file.write(chunk)
63
- print(f"File downloaded successfully: {local_file_path}")
64
- return local_file_path
65
  else:
66
  print(f"Failed to download file for task {task_id}: HTTP {response.status}")
67
  return None
@@ -69,6 +55,35 @@ async def check_n_load_attach(session: aiohttp.ClientSession, task_id
69
  print(f"Error downloading attachment for task {task_id}: {str(e)}")
70
  return None
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  async def fetch_questions(session: aiohttp.ClientSession, questions_url: str) -> list:
74
  """Fetch questions asynchronously."""
 
25
  storage = MemoryStorage()
26
  token_bucket = Limiter(rate=TOKEN_BUCKET_REFILL_RATE, capacity=TOKEN_BUCKET_CAPACITY, storage=storage)
27
 
28
+ async def check_n_load_attach(session: aiohttp.ClientSession, task_id: str, question: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> Optional[str]:
 
29
  file_url = f"{api_url}/files/{task_id}"
30
  try:
31
  async with session.get(file_url, timeout=15) as response:
32
  if response.status == 200:
 
 
33
  content_type = str(response.headers.get("Content-Type", "")).lower()
34
+ content = await response.read() # Read the file content
35
+
36
+ # Determine extension based on content_type, content, or question
37
+ extension = await determine_extension(content_type, content, question)
38
+
39
+ if extension:
40
+ filename = f"{task_id}{extension}"
41
+ local_file_path = os.path.join("downloads", filename)
42
+ os.makedirs("downloads", exist_ok=True)
43
+
44
+ async with aiofiles.open(local_file_path, "wb") as file:
45
+ await file.write(content)
46
+ print(f"File downloaded successfully: {local_file_path}")
47
+ return local_file_path
48
  else:
49
  print(f"Unsupported content type: {content_type} for task {task_id}")
50
  return None
 
 
 
 
 
 
 
 
 
 
 
 
51
  else:
52
  print(f"Failed to download file for task {task_id}: HTTP {response.status}")
53
  return None
 
55
  print(f"Error downloading attachment for task {task_id}: {str(e)}")
56
  return None
57
 
58
async def determine_extension(content_type: str, content: bytes, question: str) -> Optional[str]:
    """Choose a file extension for a downloaded attachment.

    Resolution order:

    1. If the question mentions "excel", trust the file signature: an OLE2
       compound-document header maps to ``.xls`` and a ZIP local-file header
       maps to ``.xlsx``.
    2. Otherwise match known substrings of the Content-Type header.
    3. If nothing matched but the question mentioned "excel", fall back to
       ``.xlsx``.

    The function is a coroutine only so that existing ``await`` call sites
    keep working; it performs no I/O.

    Args:
        content_type: Value of the Content-Type header; matched
            case-insensitively. ``None`` or empty is treated as unknown.
        content: Raw bytes of the attachment; only the leading bytes are
            inspected. ``None`` or empty is treated as having no signature.
        question: Text of the task question, used as a hint for Excel files.
            ``None`` or empty means no hint.

    Returns:
        The extension including the leading dot (for example ``".png"``), or
        ``None`` when the type cannot be determined.
    """
    ole2_signature = b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1"  # legacy .xls container
    zip_signature = b"\x50\x4B\x03\x04"  # .xlsx is a ZIP archive

    # (substrings to look for in the Content-Type, resulting extension);
    # order matters because the first match wins.
    mime_table = (
        (("image/png",), ".png"),
        (("jpeg", "jpg"), ".jpg"),
        (("spreadsheetml.sheet",), ".xlsx"),
        (("vnd.ms-excel",), ".xls"),
        (("audio/mpeg",), ".mp3"),
        (("application/pdf",), ".pdf"),
        (("text/x-python",), ".py"),
    )

    # Normalise inputs so a missing header, body or question cannot raise.
    mime = (content_type or "").lower()
    payload = content or b""
    mentions_excel = "excel" in (question or "").lower()

    if mentions_excel:
        if payload.startswith(ole2_signature):
            return ".xls"
        if payload.startswith(zip_signature):
            return ".xlsx"

    for markers, extension in mime_table:
        if any(marker in mime for marker in markers):
            return extension

    # Last resort: the question hinted at Excel and nothing else identified
    # the file. Placed after the MIME checks so that a recognised
    # Content-Type is not overridden by the hint.
    return ".xlsx" if mentions_excel else None
86
+
87
 
88
  async def fetch_questions(session: aiohttp.ClientSession, questions_url: str) -> list:
89
  """Fetch questions asynchronously."""