Fred808 committed on
Commit
b5dfc9f
·
verified ·
1 Parent(s): 3d0c398

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -63
app.py CHANGED
@@ -1,78 +1,84 @@
1
- import io
2
  import os
 
3
  import torch
 
4
  from PIL import Image
5
- from fastapi import FastAPI, UploadFile, File
6
- from fastapi.responses import JSONResponse
7
  from transformers import AutoProcessor, AutoModelForCausalLM
8
 
9
- # Setup
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- app = FastAPI(title="Florence-2 Base Image Captioning API")
 
 
12
 
13
- # Load Florence-2 base model
14
- try:
15
- vision_model = AutoModelForCausalLM.from_pretrained(
16
- 'microsoft/Florence-2-base',
17
- trust_remote_code=True,
18
- attn_implementation="eager"
19
- ).to(device).eval()
20
 
21
- vision_processor = AutoProcessor.from_pretrained(
22
- 'microsoft/Florence-2-base',
23
- trust_remote_code=True
24
- )
25
- except Exception as e:
26
- vision_model = None
27
- vision_processor = None
28
- print(f"Model loading error: {e}")
29
 
30
- @app.post("/describe-image")
31
- async def describe_image(file: UploadFile = File(...)):
32
- if vision_model is None or vision_processor is None:
33
- return JSONResponse(status_code=500, content={"error": "Model not loaded"})
 
34
 
35
- try:
36
- contents = await file.read()
37
- image = Image.open(io.BytesIO(contents)).convert("RGB")
 
38
 
39
- # Preprocess
40
- inputs = vision_processor(
41
- text="<MORE_DETAILED_CAPTION>",
42
- images=image,
43
- return_tensors="pt"
44
- ).to(device)
 
 
 
 
 
 
 
 
45
 
46
- with torch.no_grad():
47
- generated_ids = vision_model.generate(
48
- input_ids=inputs["input_ids"],
49
- pixel_values=inputs["pixel_values"],
50
- max_new_tokens=1024,
51
- num_beams=3,
52
- )
53
 
54
- generated_text = vision_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
55
- processed = vision_processor.post_process_generation(
56
- generated_text,
57
- task="<MORE_DETAILED_CAPTION>",
58
- image_size=image.size
 
 
 
 
 
 
59
  )
60
- caption = processed["<MORE_DETAILED_CAPTION>"]
61
-
62
- return JSONResponse(content={
63
- "filename": file.filename,
64
- "description": caption
65
- })
66
-
67
- except Exception as e:
68
- return JSONResponse(status_code=500, content={"error": str(e)})
69
-
70
- @app.get("/")
71
- def root():
72
- return {"message": "Florence-2 Base Image Captioning API is running"}
73
 
74
- # Run the app when executed directly
75
  if __name__ == "__main__":
76
- import uvicorn
77
- port = int(os.getenv("PORT", 7860)) # Spaces set PORT env var
78
- uvicorn.run("app:app", host="0.0.0.0", port=port)
 
 
 
 
 
1
  import os
2
+ import cv2
3
  import torch
4
+ from pathlib import Path
5
  from PIL import Image
 
 
6
  from transformers import AutoProcessor, AutoModelForCausalLM
7
 
8
+ # ===== CONFIG =====
9
# Run-specific settings can be overridden through environment variables so the
# script does not have to be edited for each video; the defaults are unchanged.
VIDEO_PATH = os.getenv("VIDEO_PATH", "How.mp4")  # Local video file to process
FRAMES_DIR = os.getenv("FRAMES_DIR", "extracted")  # Where sampled frames are written
FPS = int(os.getenv("FPS", "3"))  # Frames to sample per second of video
DEVICE = "cpu"  # Force CPU to avoid NCCL GPU issue
13
 
14
+ # ===== Ensure Output Directory =====
15
def ensure_dir(path):
    """Create the directory ``path``, including missing parents.

    Nothing happens when the directory already exists.
    """
    target = Path(path)
    target.mkdir(parents=True, exist_ok=True)
 
 
 
 
17
 
18
+ # ===== Frame Extraction Function =====
19
def extract_frames(video_path, output_dir, fps=3):
    """Sample frames from a video at roughly ``fps`` frames per second.

    Each sampled frame is written to ``output_dir`` as a zero-padded PNG
    (``0001.png``, ``0002.png``, ...), numbered in the order it was saved.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory that receives the PNG files; created if missing.
        fps: Target number of frames to keep per second of video. Must be
            positive.

    Returns:
        The paths of the frames that were written, as strings, in order.
        The list is empty when ``fps`` is not positive or the video cannot
        be opened.
    """
    if fps <= 0:
        # A zero rate would divide by zero below; report it the same way as
        # the other unrecoverable input problem (an unreadable video).
        print(f"[ERROR] fps must be positive, got: {fps}")
        return []

    ensure_dir(output_dir)
    cap = cv2.VideoCapture(str(video_path))
    frame_paths = []
    try:
        if not cap.isOpened():
            print(f"[ERROR] Failed to open video file: {video_path}")
            return frame_paths

        video_fps = cap.get(cv2.CAP_PROP_FPS)
        # "not video_fps > 0" rejects zero and negative values and also NaN.
        if not video_fps or not video_fps > 0:
            print("[WARN] Using fallback FPS: 30")
            video_fps = 30

        # Clamp to 1: when the requested rate is more than about twice the
        # source rate the rounded ratio is 0, and "% 0" would raise
        # ZeroDivisionError. An interval of 1 keeps every frame.
        frame_interval = max(1, int(round(video_fps / fps)))

        frame_idx = 0
        saved_idx = 1
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_interval == 0:
                frame_name = f"{saved_idx:04d}.png"
                output_path = Path(output_dir) / frame_name
                # Only report frames that were really written, so callers
                # never receive a path to a file that does not exist.
                if cv2.imwrite(str(output_path), frame):
                    frame_paths.append(str(output_path))
                    print(f"[INFO] Saved frame {frame_idx} -> {frame_name}")
                    saved_idx += 1
                else:
                    print(f"[WARN] Could not write frame {frame_idx} -> {frame_name}")
            frame_idx += 1
    finally:
        # Release the capture even if reading or writing raised.
        cap.release()
    return frame_paths
51
 
52
+ # ===== Load Florence-2 Base Model =====
53
# Report the device that is actually configured instead of a hard-coded "CPU",
# so the log stays truthful if DEVICE is changed.
print(f"[INFO] Loading Florence-2-base model on {str(DEVICE).upper()}")
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
    attn_implementation="eager",
)
# Move the weights to the configured device and switch to evaluation mode.
model = model.to(DEVICE).eval()
 
 
 
56
 
57
+ # ===== Analyze a Frame =====
58
def analyze_frame(image_path):
    """Return the model's detailed caption for the image at ``image_path``.

    The image is loaded as RGB, passed through the module-level ``processor``
    and ``model`` with the ``<MORE_DETAILED_CAPTION>`` task prompt, and the
    post-processed caption for that task is returned.
    """
    task = "<MORE_DETAILED_CAPTION>"

    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    inputs = processor(text=task, images=image, return_tensors="pt").to(DEVICE)

    generation_args = {
        "input_ids": inputs["input_ids"],
        "pixel_values": inputs["pixel_values"],
        "max_new_tokens": 1024,
        "num_beams": 3,
        "do_sample": False,
    }
    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_ids = model.generate(**generation_args)

    # Special tokens are kept because the post-processing step receives the
    # raw decoded text.
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)
    parsed = processor.post_process_generation(
        decoded[0],
        task=task,
        image_size=(width, height),
    )
    return parsed[task]
 
 
 
 
 
 
76
 
77
+ # ===== Main Execution =====
78
if __name__ == "__main__":
    # Sample frames from the configured video, then caption them one by one.
    frame_list = extract_frames(VIDEO_PATH, FRAMES_DIR, FPS)
    print(f"[INFO] {len(frame_list)} frames extracted.")
    for frame_number, frame_path in enumerate(frame_list, start=1):
        print(f"\n[FRAME {frame_number}] Analyzing: {frame_path}")
        caption = analyze_frame(frame_path)
        print(f"[RESULT] {caption}")