# main.py — FastAPI service exposing a fine-tuned Qwen2.5-VL model for image-to-JSON extraction.
import os
import tempfile
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import HTMLResponse
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Qwen2_5_VLForConditionalGeneration, AutoModel
from qwen_vl_utils import process_vision_info
# Hub identifier of the fine-tuned Qwen2.5-VL checkpoint to serve.
MODEL_PATH = "Ananthu01/qwen2.5_vl_finetuned_model"
# Initialize the Qwen2.5 VL model.
# NOTE(review): loaded eagerly at import time, so server startup blocks on the
# (potentially large) download — confirm this is acceptable for the deployment.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
device_map="cpu", # Ensures the model is loaded on CPU.
use_safetensors=True
)
# Load the processor from the BASE model repo, not the fine-tuned one —
# presumably the fine-tune did not ship its own processor config; verify.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True)
# Load the tokenizer from the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# Define generation parameters shared by all requests.
generation_config = GenerationConfig(
temperature=0.1, # Adjust temperature as needed.
top_p=0.8, # Nucleus sampling probability.
repetition_penalty=1.05, # Penalty to avoid repetitive outputs.
max_new_tokens=1024 # Maximum tokens to generate.
)
# Create FastAPI app instance used by the route decorators below.
app = FastAPI()
@app.get("/", response_class=HTMLResponse)
async def main():
    """
    GET endpoint serving a minimal HTML page with an image-upload form.

    The form posts the chosen file to /generate as multipart/form-data
    under the field name ``image_file``.
    """
    upload_form = """
<html>
<head>
<title>Qwen2.5 VL Image Upload</title>
</head>
<body>
<h2>Upload an Image</h2>
<form action="/generate" enctype="multipart/form-data" method="post">
<input name="image_file" type="file" accept="image/*">
<input type="submit" value="Submit">
</form>
</body>
</html>
"""
    return upload_form
@app.post("/generate")
async def generate_output(image_file: UploadFile = File(...)):
    """
    POST endpoint that runs the Qwen2.5-VL model on an uploaded image.

    The text prompt is fixed to "Extract JSON".

    - **image_file**: the image file uploaded by the user.

    Returns the decoded model output (newly generated tokens only) as text.
    """
    image_bytes = await image_file.read()

    # Persist the upload to disk so the vision processor can load it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(image_bytes)
        tmp_path = tmp.name

    try:
        # Construct chat messages with a hardcoded text instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": tmp_path},
                    {"type": "text", "text": "Extract JSON"},
                ],
            }
        ]
        # Render the conversation into the model's prompt format.
        prompt = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Extract the image/video tensors referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)
        # Build combined text+vision model inputs. transformers' generate()
        # has no `multi_modal_data` argument (that is vLLM's API), so the
        # vision features must be baked into the inputs by the processor —
        # otherwise the image is silently ignored.
        inputs = processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cpu")
        # Generate, then strip the echoed prompt tokens so only the model's
        # answer is returned (decoding generated_ids[0] directly would
        # include the full prompt).
        generated_ids = model.generate(**inputs, generation_config=generation_config)
        trimmed_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        generated_text = processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return generated_text
    finally:
        # Always remove the temp file, even when generation raises.
        os.remove(tmp_path)