MYousafRana committed on
Commit
da5c620
·
verified ·
1 Parent(s): 542d360

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -97
main.py DELETED
@@ -1,97 +0,0 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse
3
- import traceback
4
- import tempfile
5
- import torch
6
- import mimetypes
7
- from PIL import Image
8
- import av
9
- import numpy as np
10
-
11
- from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
12
- from my_lib.preproces_video import read_video_pyav
13
-
14
- app = FastAPI()
15
-
16
- # Load model and processor
17
- MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
18
-
19
- print("Loading model and processor...")
20
- processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)
21
- model = LlavaNextVideoForConditionalGeneration.from_pretrained(
22
- MODEL_ID,
23
- torch_dtype=torch.float16,
24
- low_cpu_mem_usage=True,
25
- ).to("cuda" if torch.cuda.is_available() else "cpu")
26
- print("Model and processor loaded.")
27
-
28
- @app.get("/")
29
- async def root():
30
- return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}
31
-
32
- @app.post("/summarize")
33
- async def summarize_media(file: UploadFile = File(...)):
34
- try:
35
- with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp:
36
- tmp.write(await file.read())
37
- tmp_path = tmp.name
38
-
39
- content_type = file.content_type
40
- is_video = content_type.startswith("video/")
41
- is_image = content_type.startswith("image/")
42
-
43
- if not (is_video or is_image):
44
- return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})
45
-
46
- # Define conversation and prompt
47
- if is_video:
48
- container = av.open(tmp_path)
49
- total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
50
- container = av.open(tmp_path) # reopen to reset position
51
-
52
- if total_frames == 0:
53
- raise ValueError("Could not extract frames: total frame count is zero.")
54
-
55
- num_frames = min(8, total_frames)
56
- indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
57
- clip = read_video_pyav(container, indices)
58
-
59
- conversation = [
60
- {
61
- "role": "user",
62
- "content": [
63
- {"type": "text", "text": "Summarize this video and explain the key highlights."},
64
- {"type": "video"},
65
- ],
66
- },
67
- ]
68
- prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
69
- inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(model.device)
70
-
71
- elif is_image:
72
- image = Image.open(tmp_path).convert("RGB")
73
- conversation = [
74
- {
75
- "role": "user",
76
- "content": [
77
- {"type": "text", "text": "Describe the image and summarize its content."},
78
- {"type": "image"},
79
- ],
80
- },
81
- ]
82
- prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
83
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
84
-
85
- else:
86
- return JSONResponse(status_code=400, content={"error": "Unsupported media format."})
87
-
88
- # Generate output
89
- output_ids = model.generate(**inputs, max_new_tokens=512)
90
- response_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
91
-
92
- return JSONResponse(content={"summary": response_text})
93
-
94
- except Exception as e:
95
- print("Unhandled error:", e)
96
- print(traceback.format_exc())
97
- return JSONResponse(status_code=500, content={"error": str(e)})