Mohansai2004 committed on
Commit
23c9421
·
1 Parent(s): 9a224ef

added the model

Browse files
app/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (196 Bytes). View file
 
app/__pycache__/app.cpython-313.pyc ADDED
Binary file (3.36 kB). View file
 
app/__pycache__/caption_model.cpython-313.pyc ADDED
Binary file (6.83 kB). View file
 
app/__pycache__/model.cpython-313.pyc ADDED
Binary file (6.43 kB). View file
 
app/__pycache__/utils.cpython-313.pyc ADDED
Binary file (1.87 kB). View file
 
app/app.py CHANGED
@@ -1,28 +1,57 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
  from starlette.responses import JSONResponse
 
4
  from app.model import analyze_image
5
  from app.utils import read_image
 
6
 
7
  app = FastAPI(title="Image Analyzer API", version="1.0.0")
8
 
9
- # CORS config
10
- app.add_middleware(
11
- CORSMiddleware,
12
- allow_origins=["*"],
13
- allow_credentials=True,
14
- allow_methods=["*"],
15
- allow_headers=["*"],
16
- )
17
 
18
  @app.post("/analyze")
19
  async def analyze(file: UploadFile = File(...)):
 
 
20
  try:
21
  image = read_image(file)
 
 
 
 
 
 
 
22
  result = analyze_image(image)
23
  return JSONResponse(content=result)
 
 
 
 
24
  except Exception as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  raise HTTPException(status_code=500, detail=str(e))
 
 
26
 
27
  @app.get("/")
28
  def read_root():
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
 
2
  from starlette.responses import JSONResponse
3
+ from starlette.requests import Request
4
  from app.model import analyze_image
5
  from app.utils import read_image
6
+ from app.caption_model import captioner
7
 
8
  app = FastAPI(title="Image Analyzer API", version="1.0.0")
9
 
10
+
 
 
 
 
 
 
 
11
 
12
@app.post("/analyze")
async def analyze(file: UploadFile = File(...)):
    """Classify the uploaded image and return the model's prediction as JSON.

    Raises:
        HTTPException 400: no file, non-image content type, or unreadable image.
        HTTPException 500: model/runtime failure during analysis.
    """
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded.")
    # Check the declared content type BEFORE decoding the body: content_type
    # may be None (the original called .startswith on it unconditionally,
    # turning a bad upload into an AttributeError/500), and rejecting early
    # avoids decoding a file we will refuse anyway.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        image = read_image(file)
    except HTTPException:
        # read_image raises HTTPException with a meaningful status code;
        # re-raise instead of collapsing it into a generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to read image: {str(e)}")

    try:
        result = analyze_image(image)
        return JSONResponse(content=result)
    except ValueError as e:
        # Invalid input reported by the model layer.
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        # Model execution failure.
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
33
+
34
@app.post("/caption")
async def generate_caption(file: UploadFile = File(...)):
    """Generate a caption for the uploaded image and return it as JSON.

    Raises:
        HTTPException 400: no file, non-image content type, or unreadable image.
        HTTPException 500: model/runtime failure during captioning.
    """
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="No file uploaded.")
    # Validate the declared content type BEFORE decoding: content_type may be
    # None (the original called .startswith on it unconditionally, so a bad
    # upload crashed with AttributeError/500 instead of a clean 400).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        image = read_image(file)
    except HTTPException:
        # read_image raises HTTPException with a meaningful status code;
        # re-raise instead of collapsing it into a generic 400.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to read image: {str(e)}")

    try:
        result = captioner.generate_caption(image)
        return JSONResponse(content=result)
    except ValueError as e:
        # Invalid input reported by the caption model layer.
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        # Model execution failure.
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
55
 
56
  @app.get("/")
57
  def read_root():
app/caption_model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
from PIL import Image
import logging
import time
from typing import Dict, Any, Optional
import gc

MODEL_NAME = "Salesforce/blip-image-captioning-base"
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds
MAX_LENGTH = 50  # Maximum length for generated captions


class ImageCaptioner:
    """Wraps the BLIP captioning model: retried loading plus caption generation."""

    def __init__(self):
        self.processor = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {self.device} for caption model")
        self._initialize_model()

    def _initialize_model(self):
        """Load processor and model, retrying up to MAX_RETRIES times.

        Raises RuntimeError if every attempt fails.
        """
        for attempt in range(MAX_RETRIES):
            try:
                # Clear CUDA cache if using GPU before a (re)load attempt.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()

                self.processor = BlipProcessor.from_pretrained(MODEL_NAME)
                self.model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME).to(self.device)

                # Verify model loaded correctly.
                if self.model is None or self.processor is None:
                    raise RuntimeError("Caption model or processor initialization failed")

                # Inference only: disable dropout etc.
                self.model.eval()

                logging.info(f"Caption model loaded successfully on {self.device} (attempt {attempt + 1})")
                return

            except Exception as e:
                logging.error(f"Attempt {attempt + 1} failed to load caption model: {str(e)}")
                if attempt < MAX_RETRIES - 1:
                    time.sleep(RETRY_DELAY)
                    continue
                raise RuntimeError(f"Failed to initialize the image captioning model after {MAX_RETRIES} attempts")

    def validate_image(self, image: Image.Image) -> Optional[str]:
        """Return an error message if *image* is unsuitable, else None."""
        if not isinstance(image, Image.Image):
            return "Input must be a PIL Image"
        # Only RGB and grayscale are accepted; read_image converts other modes.
        if image.mode not in ('RGB', 'L'):
            return "Image must be in RGB or grayscale format"
        return None

    def generate_caption(self, image: Image.Image) -> Dict[str, Any]:
        """Generate a caption for *image*.

        Returns a dict with keys "caption", "status", and "model_info".
        Raises ValueError for invalid input, RuntimeError on generation failure.
        """
        error = self.validate_image(image)
        if error:
            raise ValueError(error)

        # Try to reinitialize if the model was never loaded (or load failed).
        if self.model is None or self.processor is None:
            self._initialize_model()

        try:
            # Clear CUDA cache if using GPU.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            inputs = self.processor(image, return_tensors="pt")
            inputs = {k: v.to(self.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

            try:
                with torch.no_grad():
                    # Beam-search decoding. NOTE: the original also passed
                    # temperature/top_k/top_p, which are sampling-only options
                    # and are ignored (with warnings) when do_sample is False;
                    # they are dropped here.
                    out = self.model.generate(
                        **inputs,
                        max_length=MAX_LENGTH,
                        num_beams=5,  # Beam search for better quality
                        repetition_penalty=1.2,
                        length_penalty=1.0,
                        no_repeat_ngram_size=2,
                    )
                caption = self.processor.decode(out[0], skip_special_tokens=True).strip()

                # Normalize: capitalize and ensure terminal punctuation.
                # Guard against an empty string — the original indexed
                # caption[0] unconditionally and raised IndexError when the
                # model emitted only special tokens.
                if caption:
                    caption = caption[0].upper() + caption[1:]
                    if not caption.endswith(('.', '!', '?')):
                        caption += '.'

                return {
                    "caption": caption,
                    "status": "success",
                    "model_info": {
                        "device": self.device,
                        "model_name": MODEL_NAME
                    }
                }
            finally:
                # Clean up tensors after every call.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()

        except Exception as e:
            logging.error(f"Error during caption generation: {str(e)}")
            raise RuntimeError(f"Failed to generate caption: {str(e)}")


# Single shared instance used by the API.
captioner = ImageCaptioner()
app/model.py CHANGED
@@ -1,24 +1,108 @@
1
- from transformers import ViTImageProcessor, ViTForImageClassification
2
  import torch
3
  from PIL import Image
 
 
 
 
4
 
5
- MODEL_NAME = "google/vit-base-patch16-224"
 
 
 
6
 
7
- # Load once at startup
8
- processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
9
- model = ViTForImageClassification.from_pretrained(MODEL_NAME)
 
 
 
 
10
 
 
 
 
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def analyze_image(image: Image.Image):
13
- inputs = processor(images=image, return_tensors="pt")
14
- with torch.no_grad():
15
- outputs = model(**inputs)
16
- logits = outputs.logits
17
- predicted_class_idx = logits.argmax(-1).item()
18
- label = model.config.id2label[predicted_class_idx]
19
- confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
20
-
21
- return {
22
- "label": label,
23
- "confidence": round(confidence, 4)
24
- }
 
1
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import logging
import time
from typing import Dict, Any
import gc

MODEL_NAME = "openai/clip-vit-base-patch16"
CATEGORIES = ["food", "fitness", "healthcare"]
MAX_RETRIES = 3
RETRY_DELAY = 1  # seconds


class ImageAnalyzer:
    """Zero-shot image classifier over CATEGORIES using CLIP."""

    def __init__(self):
        self.processor = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {self.device}")
        self._initialize_model()

    def _initialize_model(self):
        """Load processor and model, retrying up to MAX_RETRIES times.

        Raises RuntimeError if every attempt fails.
        """
        for attempt in range(MAX_RETRIES):
            try:
                # Clear CUDA cache if using GPU before a (re)load attempt.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()

                self.processor = CLIPProcessor.from_pretrained(MODEL_NAME)
                self.model = CLIPModel.from_pretrained(MODEL_NAME).to(self.device)

                # Verify model loaded correctly.
                if self.model is None or self.processor is None:
                    raise RuntimeError("Model or processor initialization failed")

                # Inference only — consistent with ImageCaptioner, which the
                # original omitted here.
                self.model.eval()

                logging.info(f"Model loaded successfully on {self.device} (attempt {attempt + 1})")
                return

            except Exception as e:
                logging.error(f"Attempt {attempt + 1} failed to load model: {str(e)}")
                if attempt < MAX_RETRIES - 1:
                    time.sleep(RETRY_DELAY)
                    continue
                raise RuntimeError(f"Failed to initialize the image analysis model after {MAX_RETRIES} attempts")

    def analyze_image(self, image: Image.Image) -> Dict[str, Any]:
        """Classify *image* against CATEGORIES.

        Returns {"primary_prediction", "alternative_prediction", "status"}.
        Raises ValueError for non-PIL input, RuntimeError on model failure.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        # Try to reinitialize if the model was never loaded (or load failed).
        if self.model is None or self.processor is None:
            self._initialize_model()

        try:
            # Clear CUDA cache if using GPU.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            # Prepare joint text/image inputs for CLIP.
            inputs = self.processor(
                text=CATEGORIES,
                images=image,
                return_tensors="pt",
                padding=True
            )
            # Move tensors to the same device as the model.
            inputs = {k: v.to(self.device) if hasattr(v, 'to') else v
                      for k, v in inputs.items()}

            try:
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    logits_per_image = outputs.logits_per_image
                    probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

                # Top-2 predictions for more informative results. Defensive
                # against CATEGORIES shrinking: the original indexed
                # predictions[1] unconditionally.
                top_indices = probs.argsort()[-2:][::-1]
                predictions = [
                    {
                        "category": CATEGORIES[idx],
                        "confidence": round(float(probs[idx]), 4)
                    }
                    for idx in top_indices
                ]

                return {
                    "primary_prediction": predictions[0],
                    "alternative_prediction": predictions[1] if len(predictions) > 1 else None,
                    "status": "success"
                }
            finally:
                # Clean up tensors after every call.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    gc.collect()
        except Exception as e:
            logging.error(f"Error during image analysis: {str(e)}")
            raise RuntimeError(f"Failed to analyze image: {str(e)}")


# Create a single instance to be used by the API.
analyzer = ImageAnalyzer()


# Function to be used by the API (keeps the original module-level interface).
def analyze_image(image: Image.Image):
    return analyzer.analyze_image(image)
 
 
 
 
 
 
 
 
 
 
 
app/utils.py CHANGED
@@ -1,7 +1,34 @@
1
- from fastapi import UploadFile
2
  from PIL import Image
3
  import io
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  def read_image(upload_file: UploadFile) -> Image.Image:
6
- image_bytes = upload_file.file.read()
7
- return Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import UploadFile, HTTPException
2
  from PIL import Image
3
  import io
4
+ import logging
5
+ from typing import Tuple
6
+
7
def validate_image_size(image: Image.Image) -> Tuple[bool, str]:
    """Basic image validation.

    Returns (True, "") when the image exposes a size, otherwise
    (False, "Invalid image format").
    """
    try:
        # Simply verify that the size attribute is accessible.
        _ = image.size
    except Exception:
        return False, "Invalid image format"
    return True, ""
15
 
16
def read_image(upload_file: UploadFile) -> Image.Image:
    """Read and validate an image from an uploaded file.

    Returns a PIL image, converted to RGB unless it is already RGB or
    grayscale ('L').

    Raises:
        HTTPException 400: data is not a decodable image.
        HTTPException 500: unexpected failure while processing.
    """
    try:
        image_bytes = upload_file.file.read()
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy — force decoding now so corrupt or truncated
        # uploads fail here with a 400 instead of erroring later inside the
        # model call (the original deferred the failure).
        image.load()

        # Convert to RGB if needed (palette, CMYK, RGBA, ...).
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')

        return image

    except IOError as e:
        # Covers PIL's UnidentifiedImageError (an OSError subclass) too.
        logging.error(f"Failed to read image: {str(e)}")
        raise HTTPException(status_code=400, detail="Invalid image format")
    except Exception as e:
        logging.error(f"Unexpected error reading image: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to process image")
requirements.txt CHANGED
@@ -3,3 +3,4 @@ uvicorn
3
  transformers
4
  torch
5
  Pillow
 
 
3
  transformers
4
  torch
5
  Pillow
6
+ python-multipart