# Hugging Face Space (status: Sleeping) — FastAPI few-shot VQA service built on MiniCPM-V-2.6
| import os | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
| from typing import List | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| import io | |
| from transformers import AutoModel, AutoTokenizer | |
| import torch | |
| from huggingface_hub import login | |
| from dotenv import load_dotenv | |
# --- Application setup ------------------------------------------------------
# Load environment variables and authenticate with the Hugging Face Hub so the
# model weights can be downloaded. Skip login when no token is configured
# instead of calling login(token=None), which fails on headless hosts.
load_dotenv()
api_key = os.getenv("HF_TOKEN")
if api_key:
    login(token=api_key)

app = FastAPI()

# Load model and tokenizer once at import time. Fall back to CPU when CUDA is
# unavailable so the module can still be imported on non-GPU machines
# (the original unconditional .cuda() would raise there).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2_6',
    trust_remote_code=True,
    attn_implementation='sdpa',
    torch_dtype=torch.bfloat16,
)
model = model.eval().to(device)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6', trust_remote_code=True)
class FewshotExample(BaseModel):
    """One few-shot demonstration: an image with its question/answer pair.

    NOTE(review): this schema is not referenced by the endpoint below, which
    accepts multipart form fields instead — confirm whether external callers
    use it before removing.
    """
    image: bytes    # raw encoded image bytes (e.g. PNG/JPEG)
    question: str   # question asked about the image
    answer: str     # expected answer for this demonstration
class PredictRequest(BaseModel):
    """Full prediction request: few-shot examples plus the test query.

    NOTE(review): unused by the endpoint below (which takes form/file fields);
    verify against callers before deleting.
    """
    fewshot_examples: List[FewshotExample]  # ordered demonstrations
    test_image: bytes                       # raw encoded test image bytes
    test_question: str                      # question about the test image
@app.post("/predict")
async def predict_with_fewshot(
    fewshot_images: List[UploadFile] = File(...),
    fewshot_questions: List[str] = Form(...),
    fewshot_answers: List[str] = Form(...),
    test_image: UploadFile = File(...),
    test_question: str = Form(...)
):
    """Answer a question about ``test_image`` using few-shot VQA examples.

    Builds a multi-turn chat transcript from the (image, question, answer)
    demonstrations, appends the test image and question, and queries the
    module-level MiniCPM model.

    Raises:
        HTTPException 400: mismatched example counts, or an image that
            cannot be decoded.
        HTTPException 500: any other failure while running the model.
    """
    # Each demonstration needs exactly one image, question, and answer.
    if len(fewshot_images) != len(fewshot_questions) or len(fewshot_questions) != len(fewshot_answers):
        raise HTTPException(status_code=400, detail="Number of few-shot images, questions, and answers must match.")
    msgs = []
    try:
        for fs_img, fs_q, fs_a in zip(fewshot_images, fewshot_questions, fewshot_answers):
            img_content = await fs_img.read()
            try:
                img = Image.open(io.BytesIO(img_content)).convert('RGB')
            except OSError as e:  # PIL.UnidentifiedImageError subclasses OSError
                raise HTTPException(status_code=400, detail=f"Could not decode few-shot image '{fs_img.filename}': {e}")
            msgs.append({'role': 'user', 'content': [img, fs_q]})
            msgs.append({'role': 'assistant', 'content': [fs_a]})
        # Test example
        test_img_content = await test_image.read()
        try:
            test_img = Image.open(io.BytesIO(test_img_content)).convert('RGB')
        except OSError as e:
            raise HTTPException(status_code=400, detail=f"Could not decode test image: {e}")
        msgs.append({'role': 'user', 'content': [test_img, test_question]})
        # Get answer (model/tokenizer are module-level globals loaded at startup).
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer
        )
        return {"answer": answer}
    except HTTPException:
        # Re-raise deliberate 4xx responses instead of masking them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")