File size: 3,637 Bytes
da1cdd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c28d9da
da1cdd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c28d9da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import tempfile

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import HTMLResponse
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Qwen2_5_VLForConditionalGeneration, AutoModel
from qwen_vl_utils import process_vision_info

# Specify the model path or identifier.
# NOTE(review): a fine-tuned checkpoint is loaded for weights, but the processor
# below comes from the base "Qwen/Qwen2.5-VL-3B-Instruct" repo — presumably the
# fine-tune did not alter the processor config; confirm they are compatible.
MODEL_PATH = "Ananthu01/qwen2.5_vl_finetuned_model"

# Initialize the Qwen2.5 VL model.
# Loaded once at import time; every request handler below shares this instance.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="cpu",  # Ensures the model is loaded on CPU.
    use_safetensors=True
)

# Load the processor (handles both chat templating and image preprocessing).
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True)

# Load the tokenizer.
# NOTE(review): the processor already wraps a tokenizer; this separate instance
# is used for text-only tokenization in /generate — verify it is actually needed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Define generation parameters shared by all requests.
generation_config = GenerationConfig(
    temperature=0.1,          # Adjust temperature as needed.
    top_p=0.8,                # Nucleus sampling probability.
    repetition_penalty=1.05,  # Penalty to avoid repetitive outputs.
    max_new_tokens=1024       # Maximum tokens to generate.
)

# Create FastAPI app instance.
app = FastAPI()

@app.get("/", response_class=HTMLResponse)
async def main():
    """
    GET endpoint that renders an HTML form for the user to upload an image.
    """
    content = """
    <html>
        <head>
            <title>Qwen2.5 VL Image Upload</title>
        </head>
        <body>
            <h2>Upload an Image</h2>
            <form action="/generate" enctype="multipart/form-data" method="post">
                <input name="image_file" type="file" accept="image/*">
                <input type="submit" value="Submit">
            </form>
        </body>
    </html>
    """
    return content

@app.post("/generate")
async def generate_output(image_file: UploadFile = File(...)):
    """
    POST endpoint to generate model output using an uploaded image.
    The text prompt is fixed to "Extract JSON".

    - **image_file**: The image file uploaded by the user.
    """
    # Read the uploaded image.
    image_bytes = await image_file.read()

    # Save the image temporarily.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(image_bytes)
        tmp_path = tmp.name

    # Construct messages with a hardcoded text instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": tmp_path},
                {"type": "text", "text": "Extract JSON"},
            ],
        }
    ]

    # Apply the chat template using the processor.
    prompt = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process multimodal inputs.
    image_inputs, video_inputs = process_vision_info(messages)
    mm_data = {}
    if image_inputs is not None:
        mm_data["image"] = image_inputs

    # Tokenize the prompt.
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    # Generate output from the model.
    # Note: It is assumed that Qwen2.5 VL’s generate method accepts a multi_modal_data argument.
    generated_ids = model.generate(
        **inputs,
        generation_config=generation_config,
        multi_modal_data=mm_data
    )

    # Decode the generated text.
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Clean up the temporary file.
    os.remove(tmp_path)

    # Return the generated text.
    return generated_text