esteele committed on
Commit
377d687
·
1 Parent(s): 86baf9d

Switch to GPT-2 and BLIP models

Browse files
app/services/ai_service.py CHANGED
@@ -1,9 +1,14 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
 
3
- MODEL_NAME = "bickett/meme-llama"
 
 
 
 
4
 
5
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
6
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
7
 
8
 
9
  # def generate_meme_caption(image_path: str, max_length: int = 40, num_return_sequences: int = 3):
@@ -45,8 +50,8 @@ def generate_captions(prompt: str, max_length: int = 50, num_return_sequences: i
45
  """
46
  Generate AI meme captions given a prompt using Meme-LLaMA
47
  """
48
- inputs = tokenizer(prompt, return_tensors="pt")
49
- outputs = model.generate(
50
  **inputs,
51
  max_length=max_length,
52
  num_return_sequences=num_return_sequences,
@@ -55,7 +60,7 @@ def generate_captions(prompt: str, max_length: int = 50, num_return_sequences: i
55
  temperature=0.8
56
  )
57
 
58
- captions = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
59
  return captions
60
 
61
  # def generate_captions(prompt: str):
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
 
3
# -----------------------------
# Load GPT-2 small (text generation)
# -----------------------------
GPT2_MODEL_NAME = "gpt2"
# Loaded once at module import; from_pretrained fetches/caches the weights.
# These module-level names are used by generate_captions() below.
gpt2_tokenizer = AutoTokenizer.from_pretrained(GPT2_MODEL_NAME)
gpt2_model = AutoModelForCausalLM.from_pretrained(GPT2_MODEL_NAME)
8
 
9
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ # model = AutoModelWithLMHead.from_pretrained(MODEL_NAME)
11
+ # model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
12
 
13
 
14
  # def generate_meme_caption(image_path: str, max_length: int = 40, num_return_sequences: int = 3):
 
50
  """
51
  Generate AI meme captions given a prompt using GPT-2
52
  """
53
+ inputs = gpt2_tokenizer(prompt, return_tensors="pt")
54
+ outputs = gpt2_model.generate(
55
  **inputs,
56
  max_length=max_length,
57
  num_return_sequences=num_return_sequences,
 
60
  temperature=0.8
61
  )
62
 
63
+ captions = [gpt2_tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
64
  return captions
65
 
66
  # def generate_captions(prompt: str):
app/services/image_service.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from transformers import AutoProcessor, BlipForConditionalGeneration
3
+
4
+
5
# Load BLIP (image captioning), once at module import.
BLIP_MODEL_NAME = "Salesforce/blip-image-captioning-base"
# Processor prepares image tensors and decodes generated token ids;
# the model produces the caption tokens. Both are used by
# generate_image_caption() below.
blip_processor = AutoProcessor.from_pretrained(BLIP_MODEL_NAME)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL_NAME)
8
+
9
+
10
def generate_image_caption(image_path: str) -> str:
    """Generate a natural-language caption for the image at *image_path*.

    Opens the image, converts it to 3-channel RGB, runs the BLIP
    captioning model, and decodes the generated tokens to text.

    Args:
        image_path: Filesystem path to an image readable by PIL.

    Returns:
        The decoded caption with special tokens stripped.

    Raises:
        FileNotFoundError: if the path does not exist.
        PIL.UnidentifiedImageError: if the file is not a valid image.
    """
    # Use a context manager so the underlying file handle is closed
    # promptly; the original left it open until garbage collection.
    # convert() forces the pixel data into memory, so the Image object
    # remains usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = blip_processor(images=image, return_tensors="pt")
    # NOTE(review): uses the model's default (greedy, short) generation
    # settings — pass generation kwargs here if richer captions are wanted.
    out = blip_model.generate(**inputs)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    return caption