Mountchicken commited on
Commit
1c5536c
·
verified ·
1 Parent(s): 0efd84f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -6,16 +6,26 @@ import json
6
  import os
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  import subprocess
11
  subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
- os.system("pip install transformers==4.51.1")
13
- os.system("pip install numpy==1.26.4")
14
- os.system("pip install vllm==0.8.2")
15
- os.system("pip install torch==2.4.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu124")
16
 
17
  import sys
18
  import threading
 
 
19
  import re
20
  from typing import Any, Dict, List
21
 
@@ -29,7 +39,7 @@ from rex_omni.tasks import KEYPOINT_CONFIGS, TASK_CONFIGS, get_task_config
29
 
30
 
31
  def parse_args():
32
- parser = argparse.ArgumentParser(description="Rex-Omni Gradio Demo")
33
  parser.add_argument(
34
  "--model_path",
35
  default="IDEA-Research/Rex-Omni",
@@ -45,7 +55,7 @@ def parse_args():
45
  parser.add_argument("--temperature", type=float, default=0.0)
46
  parser.add_argument("--top_p", type=float, default=0.05)
47
  parser.add_argument("--top_k", type=int, default=1)
48
- parser.add_argument("--max_tokens", type=int, default=4096)
49
  parser.add_argument("--repetition_penalty", type=float, default=1.05)
50
  parser.add_argument("--min_pixels", type=int, default=16 * 28 * 28)
51
  parser.add_argument("--max_pixels", type=int, default=2560 * 28 * 28)
@@ -942,4 +952,4 @@ if __name__ == "__main__":
942
  server_port=args.server_port,
943
  share=True,
944
  debug=True,
945
- )
 
6
  import os
7
 
8
 
9
+ os.system("pip install matplotlib==3.10.6")
10
+
11
+ os.system("pip install Pillow==11.3.0")
12
+ os.system("pip install qwen_vl_utils==0.0.14")
13
+ os.system("pip install transformers==4.51.3")
14
+ # os.system("pip install vllm==0.8.2")
15
+ os.system("pip install accelerate==1.10.1")
16
+
17
+ os.system("pip install gradio==4.44.1")
18
+ os.system("pip install gradio_image_prompter==0.1.0")
19
+ # os.system("pip install pydantic==2.10.6")
20
+ os.system("pip install torch==2.4.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu124")
21
 
22
  import subprocess
23
  subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 
 
 
24
 
25
  import sys
26
  import threading
27
+ # os.system("pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124")
28
+ os.system("pip install --no-cache-dir --force-reinstall 'numpy==1.26.4'")
29
  import re
30
  from typing import Any, Dict, List
31
 
 
39
 
40
 
41
  def parse_args():
42
+ parser = argparse.ArgumentParser(description="Rex Omni Gradio Demo")
43
  parser.add_argument(
44
  "--model_path",
45
  default="IDEA-Research/Rex-Omni",
 
55
  parser.add_argument("--temperature", type=float, default=0.0)
56
  parser.add_argument("--top_p", type=float, default=0.05)
57
  parser.add_argument("--top_k", type=int, default=1)
58
+ parser.add_argument("--max_tokens", type=int, default=2048)
59
  parser.add_argument("--repetition_penalty", type=float, default=1.05)
60
  parser.add_argument("--min_pixels", type=int, default=16 * 28 * 28)
61
  parser.add_argument("--max_pixels", type=int, default=2560 * 28 * 28)
 
952
  server_port=args.server_port,
953
  share=True,
954
  debug=True,
955
+ )