MindLabUnimib committed on
Commit
cc0f36f
·
verified ·
1 Parent(s): 668feeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -30
app.py CHANGED
@@ -17,31 +17,6 @@ from transformers import (
17
 
18
  print("\n=== Environment Setup ===")
19
 
20
- if torch.cuda.is_available():
21
- print(f"GPU detected: {torch.cuda.get_device_name(0)}")
22
- try:
23
- subprocess.run(
24
- "pip install flash-attn --no-build-isolation",
25
- shell=True,
26
- check=True,
27
- )
28
- print("✅ flash-attn installed successfully")
29
- except subprocess.CalledProcessError as e:
30
- print("⚠️ flash-attn installation failed:", e)
31
- else:
32
- print("⚙️ CPU detected — skipping flash-attn installation")
33
- # Disable flash-attn references safely
34
- os.environ["DISABLE_FLASH_ATTN"] = "1"
35
- os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
36
- try:
37
- from transformers.utils import import_utils
38
-
39
- if "flash_attn" not in import_utils.PACKAGE_DISTRIBUTION_MAPPING: # type: ignore
40
- print(import_utils.PACKAGE_DISTRIBUTION_MAPPING)
41
- import_utils.PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] = "flash-attn" # type: ignore
42
- except Exception as e:
43
- print("⚠️ Patch skipped:", e)
44
-
45
  if torch.cuda.is_available():
46
  device = torch.device("cuda")
47
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
@@ -50,11 +25,6 @@ else:
50
  print("Using CPU")
51
 
52
  print("\n=== Model Loading ===")
53
- import torch
54
- import transformers
55
-
56
- from transformers.utils.import_utils import is_flash_attn_2_available
57
- print("is_flash_attn_2_available: ", is_flash_attn_2_available())
58
 
59
  chat_model_name = "sapienzanlp/Minerva-7B-instruct-v1.0"
60
  cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
 
17
 
18
  print("\n=== Environment Setup ===")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if torch.cuda.is_available():
21
  device = torch.device("cuda")
22
  print(f"Using GPU: {torch.cuda.get_device_name(device)}")
 
25
  print("Using CPU")
26
 
27
  print("\n=== Model Loading ===")
 
 
 
 
 
28
 
29
  chat_model_name = "sapienzanlp/Minerva-7B-instruct-v1.0"
30
  cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"