Mayankuttam commited on
Commit
f80e6a6
·
verified ·
1 Parent(s): 348ccc5

Update model_pipeline.py

Browse files
Files changed (1) hide show
  1. model_pipeline.py +14 -12
model_pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import LlavaForConditionalGeneration, LlavaProcessor, pipeline, AutoProcessor, AutoModelForVision2Seq
2
  from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
3
  from gtts import gTTS
4
  import torch
@@ -11,11 +11,11 @@ import numpy as np
11
  model_id = "may-ur08/llava-commentary-gen"
12
  processor = AutoProcessor.from_pretrained(model_id)
13
 
 
14
  model = AutoModelForVision2Seq.from_pretrained(
15
  model_id,
16
- device_map="auto",
17
- torch_dtype=torch.float16,
18
- load_in_4bit=True
19
  )
20
 
21
  def run_model_on_video(video_path):
@@ -45,14 +45,16 @@ def run_model_on_video(video_path):
45
 
46
  for i, frame_path in enumerate(frames):
47
  image = Image.open(frame_path).convert("RGB")
48
- prompt = ("<image>\n"
49
- "USER: Analyze this image from a live cricket match.\n"
50
- "Identify two things:\n"
51
- "1. What specific type of cricket shot is being played?\n"
52
- "2. What is the likely outcome?\n"
53
- "Only use proper cricket terminology — avoid any football, baseball, or non-cricket references. "
54
- "Now write a short, exciting cricket-style commentary line as if it's being broadcast on TV.\n"
55
- "ASSISTANT:")
 
 
56
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
57
  output = model.generate(**inputs, max_new_tokens=50)
58
  caption = processor.decode(output[0], skip_special_tokens=True).strip()
 
1
+ from transformers import AutoProcessor, AutoModelForVision2Seq, pipeline
2
  from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
3
  from gtts import gTTS
4
  import torch
 
11
  model_id = "may-ur08/llava-commentary-gen"
12
  processor = AutoProcessor.from_pretrained(model_id)
13
 
14
+ use_gpu = torch.cuda.is_available()
15
  model = AutoModelForVision2Seq.from_pretrained(
16
  model_id,
17
+ device_map="auto" if use_gpu else None,
18
+ torch_dtype=torch.float16 if use_gpu else torch.float32
 
19
  )
20
 
21
  def run_model_on_video(video_path):
 
45
 
46
  for i, frame_path in enumerate(frames):
47
  image = Image.open(frame_path).convert("RGB")
48
+ prompt = (
49
+ "<image>\n"
50
+ "USER: Analyze this image from a live cricket match.\n"
51
+ "Identify two things:\n"
52
+ "1. What specific type of cricket shot is being played?\n"
53
+ "2. What is the likely outcome?\n"
54
+ "Only use proper cricket terminology — avoid any football, baseball, or non-cricket references. "
55
+ "Now write a short, exciting cricket-style commentary line as if it's being broadcast on TV.\n"
56
+ "ASSISTANT:"
57
+ )
58
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
59
  output = model.generate(**inputs, max_new_tokens=50)
60
  caption = processor.decode(output[0], skip_special_tokens=True).strip()