from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import traceback
import tempfile
import torch
from PIL import Image
import av
import numpy as np
import os
from transformers import (
    BitsAndBytesConfig,
    LlavaNextVideoProcessor,
    LlavaNextVideoForConditionalGeneration,
)
from my_lib.preproces_video import read_video_pyav

app = FastAPI()

# Load model and processor
MODEL_ID = "llava-hf/LLaVA-NeXT-Video-7B-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model and processor...")
processor = LlavaNextVideoProcessor.from_pretrained(MODEL_ID)

# Optional: Pre-cache model on HF Spaces to avoid redownloading
# from huggingface_hub import snapshot_download
# snapshot_download(MODEL_ID)

if device.type == "cuda":
    try:
        # 4-bit quantization requires bitsandbytes and a GPU. bitsandbytes
        # places the quantized weights on the GPU itself; calling .to(device)
        # on a quantized model raises an error, so we rely on device_map.
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            device_map="auto",
        )
        print("Loaded model in 4-bit quantized mode.")
    except Exception as e:
        print("Failed to load in 4-bit mode:", e)
        print("Falling back to full-precision FP16.")
        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(device)
else:
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
    ).to(device)

print(f"Model and processor loaded on {device}.")


@app.get("/")
async def root():
    return {"message": "Welcome to the Summarization API. Use /summarize to summarize media files."}


@app.get("/health")
async def health():
    return {"status": "ok", "device": device.type}


@app.post("/summarize")
async def summarize_media(file: UploadFile = File(...)):
    tmp_path = None
    try:
        # Persist the upload to disk so PyAV/PIL can open it by path; keep
        # only the original extension so container formats are detected.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        content_type = file.content_type or ""
        is_video = content_type.startswith("video/")
        is_image = content_type.startswith("image/")

        if not (is_video or is_image):
            # The temp file is cleaned up in the finally block below.
            return JSONResponse(status_code=400, content={"error": f"Unsupported file type: {content_type}"})

        if is_video:
            # First pass: read the frame count, decoding the stream only when
            # the container metadata does not report it.
            container = av.open(tmp_path)
            total_frames = container.streams.video[0].frames or sum(1 for _ in container.decode(video=0))
            container.close()
            if total_frames == 0:
                raise ValueError("Could not extract frames: total frame count is zero.")

            # Second pass: reopen to reset the decode position, then sample
            # up to 8 evenly spaced frames.
            container = av.open(tmp_path)
            num_frames = min(8, total_frames)
            indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
            clip = read_video_pyav(container, indices)
            container.close()

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Summarize this video and explain the key highlights."},
                        {"type": "video"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, videos=clip, return_tensors="pt").to(device)
        elif is_image:
            image = Image.open(tmp_path).convert("RGB")
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and summarize its content."},
                        {"type": "image"},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

        output_ids = model.generate(**inputs, max_new_tokens=512)
        response_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        return JSONResponse(content={"summary": response_text})
    except Exception as e:
        print("Unhandled error:", e)
        print(traceback.format_exc())
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
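

# A minimal sketch of how to run and smoke-test the server locally; uvicorn
# is the usual choice for FastAPI, but the module path ("main:app" from the
# CLI) and the port are assumptions, so adjust them to your project layout
# and deployment target.
if __name__ == "__main__":
    import uvicorn

    # HF Spaces conventionally serves on port 7860; change as needed.
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Hypothetical smoke tests from the command line (file names are placeholders):
#   curl http://localhost:7860/health
#   curl -F "file=@clip.mp4;type=video/mp4" http://localhost:7860/summarize
#   curl -F "file=@photo.jpg;type=image/jpeg" http://localhost:7860/summarize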