Fred808 committed on
Commit
03901aa
·
verified ·
1 Parent(s): 92cce7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -6,10 +6,11 @@ from PIL import Image
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
 
8
  # ===== CONFIG =====
9
- VIDEO_PATH = "How.mp4" # Set to your local video file
10
- FRAMES_DIR = "extracted"
11
- FPS = 3
12
- DEVICE = "cpu" # Force CPU to avoid NCCL GPU issue
 
13
 
14
  # ===== Ensure Output Directory =====
15
  def ensure_dir(path):
@@ -50,14 +51,23 @@ def extract_frames(video_path, output_dir, fps=3):
50
  return frame_paths
51
 
52
  # ===== Load Florence-2 Base Model =====
53
- print("[INFO] Loading Florence-2-base model on CPU")
54
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
55
- model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True, attn_implementation="eager").to(DEVICE).eval()
 
 
 
 
56
 
57
  # ===== Analyze a Frame =====
58
  def analyze_frame(image_path):
59
  image = Image.open(image_path).convert("RGB")
60
- inputs = processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(DEVICE)
 
 
 
 
 
61
  with torch.no_grad():
62
  generated_ids = model.generate(
63
  input_ids=inputs["input_ids"],
@@ -70,18 +80,21 @@ def analyze_frame(image_path):
70
  result = processor.post_process_generation(
71
  generated_text,
72
  task="<MORE_DETAILED_CAPTION>",
73
- image_size=(image.width, image.height)
74
  )
75
  return result["<MORE_DETAILED_CAPTION>"]
76
 
77
  # ===== Main Execution =====
78
  if __name__ == "__main__":
79
  frame_list = extract_frames(VIDEO_PATH, FRAMES_DIR, FPS)
80
- print(f"[INFO] {len(frame_list)} frames extracted.")
 
81
  for idx, frame_path in enumerate(frame_list):
82
  print(f"\n[FRAME {idx+1}] Analyzing: {frame_path}")
83
  caption = analyze_frame(frame_path)
84
  print(f"[RESULT] {caption}")
 
 
85
  import uvicorn
86
- port = int(os.getenv("PORT", 7860)) # Spaces set PORT env var
87
- uvicorn.run("app:app", host="0.0.0.0", port=port)
 
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
 
8
  # ===== CONFIG =====
9
+ VIDEO_PATH = "How.mp4" # Local video file in root
10
+ FRAMES_DIR = "extracted" # Where frames are stored
11
+ FPS = 3 # Frames to extract per second
12
+ DEVICE = "cpu" # Use CPU for compatibility
13
+ RESIZE_DIM = (512, 512) # Resize images to this resolution
14
 
15
  # ===== Ensure Output Directory =====
16
  def ensure_dir(path):
 
51
  return frame_paths
52
 
53
  # ===== Load Florence-2 Base Model =====
54
+ print("[INFO] Loading Florence-2-base model on CPU...")
55
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
56
+ model = AutoModelForCausalLM.from_pretrained(
57
+ "microsoft/Florence-2-base",
58
+ trust_remote_code=True,
59
+ attn_implementation="eager"
60
+ ).to(DEVICE).eval()
61
 
62
  # ===== Analyze a Frame =====
63
  def analyze_frame(image_path):
64
  image = Image.open(image_path).convert("RGB")
65
+ image = image.resize(RESIZE_DIM, Image.BILINEAR) # Resize for speed
66
+ inputs = processor(
67
+ text="<MORE_DETAILED_CAPTION>",
68
+ images=image,
69
+ return_tensors="pt"
70
+ ).to(DEVICE)
71
  with torch.no_grad():
72
  generated_ids = model.generate(
73
  input_ids=inputs["input_ids"],
 
80
  result = processor.post_process_generation(
81
  generated_text,
82
  task="<MORE_DETAILED_CAPTION>",
83
+ image_size=RESIZE_DIM
84
  )
85
  return result["<MORE_DETAILED_CAPTION>"]
86
 
87
  # ===== Main Execution =====
88
  if __name__ == "__main__":
89
  frame_list = extract_frames(VIDEO_PATH, FRAMES_DIR, FPS)
90
+ print(f"[INFO] Extracted {len(frame_list)} frames.")
91
+
92
  for idx, frame_path in enumerate(frame_list):
93
  print(f"\n[FRAME {idx+1}] Analyzing: {frame_path}")
94
  caption = analyze_frame(frame_path)
95
  print(f"[RESULT] {caption}")
96
+
97
+ # Optional: Start a dummy Uvicorn server (if you want to expand into an API later)
98
  import uvicorn
99
+ port = int(os.getenv("PORT", 7860)) # for Gradio Spaces compatibility
100
+ uvicorn.run("main:app", host="0.0.0.0", port=port) if os.getenv("RUN_SERVER") else None