# llm-deepfake / app.py
from fastapi import FastAPI, Request
import base64
import io
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
app = FastAPI()
# Load Image Captioning Model (BLIP)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Load LLM (GPT-2 as an example)
llm_model = AutoModelForCausalLM.from_pretrained("gpt2")
llm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
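# Optional GPU placement (a sketch, not part of the original app, which runs
# on CPU; if enabled, the tensors built inside the endpoint would also need
# to be moved with .to(device) before generate()):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# caption_model.to(device)
# llm_model.to(device)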
@app.post("/process")
async def process(request: Request):
data = await request.json()
image_base64 = data["image"]
gradcam_base64 = data["gradcam"]
prediction = data["prediction"]
# Decode Images
image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
gradcam = Image.open(io.BytesIO(base64.b64decode(gradcam_base64)))
# Generate Caption from GradCAM
inputs = processor(gradcam, return_tensors="pt")
out = caption_model.generate(**inputs)
caption = processor.decode(out[0], skip_special_tokens=True)
# Prepare Input for LLM
llm_input = f"The image is predicted as {prediction}. Caption: {caption}. Explain why this might be the case."
# Generate LLM Response
inputs = llm_tokenizer(llm_input, return_tensors="pt")
outputs = llm_model.generate(**inputs, max_length=100)
llm_result = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"result": llm_result}