MYousafRana commited on
Commit
309f212
·
verified ·
1 Parent(s): 245a8c3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +126 -116
main.py CHANGED
@@ -1,116 +1,126 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse
3
- import traceback
4
- import tempfile
5
- import torch
6
- # import mimetypes
7
- from PIL import Image
8
- import av
9
- import numpy as np
10
- import os
11
-
12
- from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
13
- from my_lib.preproces_video import read_video_pyav
14
-
15
- app = FastAPI()
16
-
17
- # Load model and processor
18
- MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
-
21
- print("Loading model and processor...")
22
- processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
23
-
24
- # Optional: Pre-cache model on HF Spaces to avoid redownloading
25
- # from huggingface_hub import snapshot_download
26
- # snapshot_download(MODEL_ID)
27
-
28
- if device.type == "cuda":
29
- model = LlavaNextVideoForConditionalGeneration.from_pretrained(
30
- MODEL_ID,
31
- torch_dtype=torch.float16,
32
- low_cpu_mem_usage=True,
33
- load_in_4bit=True
34
- ).to(device)
35
- else:
36
- model = LlavaNextVideoForConditionalGeneration.from_pretrained(
37
- MODEL_ID,
38
- torch_dtype=torch.float32
39
- ).to(device)
40
-
41
- print(f"Model and processor loaded on {device}.")
42
-
43
- @app.get("/")
44
- async def root():
45
- return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}
46
-
47
- @app.get("/health")
48
- async def health():
49
- return {"status": "ok", "device": device.type}
50
-
51
- @app.post("/summarize")
52
- async def summarize_media(file: UploadFile = File(...)):
53
- try:
54
- with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp:
55
- tmp.write(await file.read())
56
- tmp_path = tmp.name
57
-
58
- content_type = file.content_type
59
- is_video = content_type.startswith("video/")
60
- is_image = content_type.startswith("image/")
61
-
62
- if not (is_video or is_image):
63
- os.unlink(tmp_path)
64
- return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
65
-
66
- if is_video:
67
- container = av.open(tmp_path)
68
- total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
69
- container = av.open(tmp_path) # reopen to reset position
70
-
71
- if total_frames == 0:
72
- raise ValueError("Could not extract frames: total frame count is zero.")
73
-
74
- num_frames = min(8, total_frames)
75
- indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
76
- clip = read_video_pyav(container, indices)
77
-
78
- conversation = [
79
- {
80
- "role": "user",
81
- "content": [
82
- {"type": "text", "text": "Summarize this video and explain the key highlights."},
83
- {"type": "video"},
84
- ],
85
- },
86
- ]
87
- prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
88
- inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
89
-
90
- elif is_image:
91
- image = Image.open(tmp_path).convert("RGB")
92
- conversation = [
93
- {
94
- "role": "user",
95
- "content": [
96
- {"type": "text", "text": "Describe the image and summarize its content."},
97
- {"type": "image"},
98
- ],
99
- },
100
- ]
101
- prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
102
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
103
-
104
- output_ids = model.generate(**inputs, max_new_tokens=512)
105
- response_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
106
-
107
- return JSONResponse(content={"summary": response_text})
108
-
109
- except Exception as e:
110
- print("Unhandled error:", e)
111
- print(traceback.format_exc())
112
- return JSONResponse(status_code=500, content={"error": str(e)})
113
-
114
- finally:
115
- if 'tmp_path' in locals() and os.path.exists(tmp_path):
116
- os.unlink(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import JSONResponse
3
+ import traceback
4
+ import tempfile
5
+ import torch
6
+ # import mimetypes
7
+ from PIL import Image
8
+ import av
9
+ import numpy as np
10
+ import os
11
+
12
+ from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
13
+ from my_lib.preproces_video import read_video_pyav
14
+
15
+ app = FastAPI()
16
+
17
+ # Load model and processor
18
+ MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ print("Loading model and processor...")
22
+ processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
23
+
24
+ # Optional: Pre-cache model on HF Spaces to avoid redownloading
25
+ # from huggingface_hub import snapshot_download
26
+ # snapshot_download(MODEL_ID)
27
+
28
+ if device.type == "cuda":
29
+ try:
30
+ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
31
+ MODEL_ID,
32
+ torch_dtype=torch.float16,
33
+ low_cpu_mem_usage=True,
34
+ load_in_4bit=True # Requires bitsandbytes and GPU
35
+ ).to(device)
36
+ print("Loaded model in 4-bit quantized mode.")
37
+ except Exception as e:
38
+ print("Failed to load in 4-bit mode:", e)
39
+ print("Falling back to full precision FP16.")
40
+ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
41
+ MODEL_ID,
42
+ torch_dtype=torch.float16,
43
+ low_cpu_mem_usage=True,
44
+ ).to(device)
45
+ else:
46
+ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
47
+ MODEL_ID,
48
+ torch_dtype=torch.float32
49
+ ).to(device)
50
+
51
+ print(f"Model and processor loaded on {device}.")
52
+
53
+ @app.get("/")
54
+ async def root():
55
+ return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}
56
+
57
+ @app.get("/health")
58
+ async def health():
59
+ return {"status": "ok", "device": device.type}
60
+
61
+ @app.post("/summarize")
62
+ async def summarize_media(file: UploadFile = File(...)):
63
+ try:
64
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp:
65
+ tmp.write(await file.read())
66
+ tmp_path = tmp.name
67
+
68
+ content_type = file.content_type
69
+ is_video = content_type.startswith("video/")
70
+ is_image = content_type.startswith("image/")
71
+
72
+ if not (is_video or is_image):
73
+ os.unlink(tmp_path)
74
+ return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
75
+
76
+ if is_video:
77
+ container = av.open(tmp_path)
78
+ total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
79
+ container = av.open(tmp_path) # reopen to reset position
80
+
81
+ if total_frames == 0:
82
+ raise ValueError("Could not extract frames: total frame count is zero.")
83
+
84
+ num_frames = min(8, total_frames)
85
+ indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
86
+ clip = read_video_pyav(container, indices)
87
+
88
+ conversation = [
89
+ {
90
+ "role": "user",
91
+ "content": [
92
+ {"type": "text", "text": "Summarize this video and explain the key highlights."},
93
+ {"type": "video"},
94
+ ],
95
+ },
96
+ ]
97
+ prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
98
+ inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
99
+
100
+ elif is_image:
101
+ image = Image.open(tmp_path).convert("RGB")
102
+ conversation = [
103
+ {
104
+ "role": "user",
105
+ "content": [
106
+ {"type": "text", "text": "Describe the image and summarize its content."},
107
+ {"type": "image"},
108
+ ],
109
+ },
110
+ ]
111
+ prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
112
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
113
+
114
+ output_ids = model.generate(**inputs, max_new_tokens=512)
115
+ response_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
116
+
117
+ return JSONResponse(content={"summary": response_text})
118
+
119
+ except Exception as e:
120
+ print("Unhandled error:", e)
121
+ print(traceback.format_exc())
122
+ return JSONResponse(status_code=500, content={"error": str(e)})
123
+
124
+ finally:
125
+ if 'tmp_path' in locals() and os.path.exists(tmp_path):
126
+ os.unlink(tmp_path)