Yassine854 committed on
Commit
551788c
Β·
1 Parent(s): 46d3077

Add application file

Browse files
Files changed (5) hide show
  1. Dockerfile +24 -0
  2. app.py +42 -0
  3. best_model.pth +3 -0
  4. predictor.py +198 -0
  5. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Official slim Python base keeps the image small
FROM python:3.11-slim

# Run as an unprivileged user (required by HF Spaces best practice)
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

# All app code lives under /app
WORKDIR /app

# Install Python dependencies first so this layer caches across code changes
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# Bring in the application source
COPY --chown=user . /app

# Hugging Face Spaces expects the server on port 7860
EXPOSE 7860

# Launch the FastAPI app with a single uvicorn worker
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import tempfile
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import JSONResponse
from predictor import SimilarityPredictor  # model/inference code lives in predictor.py

# Configuration is read from the environment with sensible defaults.
MODEL_PATH = os.getenv("MODEL_PATH", "best_model.pth")
THRESHOLD = float(os.getenv("THRESHOLD", 0.5))

# The model is loaded exactly once, at process startup.
predictor = SimilarityPredictor(MODEL_PATH, threshold=THRESHOLD)

app = FastAPI()
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    text: str = Form(...)
):
    """Score the similarity between an uploaded image and a text description.

    Args:
        file: Uploaded image (multipart form field).
        text: Text description to compare against the image.

    Returns:
        JSONResponse with the predictor's result dict on success, or an
        ``{"error": ...}`` payload with status 500 on failure.
    """
    tmp_path = None
    try:
        # Persist the upload to a temporary file so the predictor can
        # open it by filesystem path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)

        # Run prediction
        result = predictor.predict_similarity(tmp_path, text, verbose=False)

        if result is None:
            return JSONResponse({"error": "Prediction failed"}, status_code=500)

        return JSONResponse(result)

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # Bug fix: the original removed the temp file only on the success
        # path, leaking a file every time predict_similarity raised.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.get("/")
def home():
    """Health-check endpoint confirming the service is up."""
    return {"message": "Image-Text Similarity API is running"}
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e37719d6b5bb683844b34047ae06ab05738be9d3648e5a84ca4ec79bed30e4cd
3
+ size 823288223
predictor.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import AutoProcessor, SiglipModel
6
+ from PIL import Image
7
+ import re
8
+
class Config:
    """Hyper-parameters shared by the model and preprocessing code."""

    image_size = 224    # input image side length in pixels
    embed_dim = 512     # dimensionality of the joint embedding space
    temperature = 0.07  # initial value of the learnable temperature
    dropout_rate = 0.1  # dropout used inside the projection heads
def filter_single_characters(text):
    """Remove stand-alone letters (Latin or Arabic) from *text*.

    A word is dropped when, after stripping non-word characters, it is a
    single alphabetic character or a single character in the Arabic Unicode
    ranges; punctuation-only words are dropped too. If filtering would
    remove everything, the original text is returned unchanged.
    """
    source = text if isinstance(text, str) else str(text)

    # NOTE: the explicit Arabic-range check is kept separate from isalpha()
    # because the ranges also cover characters isalpha() rejects
    # (e.g. Arabic-Indic digits).
    arabic_char = r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]'

    kept = []
    for token in source.split():
        stripped = re.sub(r'[^\w]', '', token)

        if not stripped:
            # Token was pure punctuation/symbols: drop it.
            continue

        if len(stripped) == 1 and (stripped.isalpha() or re.match(arabic_char, stripped)):
            # Single Latin/Arabic character: drop it.
            continue

        kept.append(token)

    result = ' '.join(kept).strip()
    return result if result else source
class EnhancedSigLIP(nn.Module):
    """SigLIP backbone with a learnable temperature and MLP projection heads."""

    def __init__(self, model_name="google/siglip-base-patch16-224"):
        super().__init__()
        self.model = SiglipModel.from_pretrained(model_name)
        self.temperature = nn.Parameter(torch.tensor(Config.temperature))

        # Two-layer GELU MLP heads mapping each modality into the shared
        # embedding space. Submodule names and layout match the trained
        # checkpoint's state_dict keys.
        self.text_proj = self._make_head(self.model.config.text_config.hidden_size)
        self.vision_proj = self._make_head(self.model.config.vision_config.hidden_size)

    @staticmethod
    def _make_head(in_features):
        """Build a projection head: in_features -> 2*embed_dim -> embed_dim."""
        return nn.Sequential(
            nn.Linear(in_features, Config.embed_dim * 2),
            nn.GELU(),
            nn.Dropout(Config.dropout_rate),
            nn.Linear(Config.embed_dim * 2, Config.embed_dim),
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return L2-normalized (text_embeds, image_embeds) for the batch."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )

        text_embeds = F.normalize(self.text_proj(outputs.text_embeds), p=2, dim=-1)
        image_embeds = F.normalize(self.vision_proj(outputs.image_embeds), p=2, dim=-1)

        return text_embeds, image_embeds
class SimilarityPredictor:
    """Loads a trained EnhancedSigLIP checkpoint and scores image/text pairs."""

    def __init__(self, model_path, threshold=0.5):
        """
        Initialize the predictor

        Args:
            model_path: Path to the trained model (.pth file)
            threshold: Similarity threshold for classification (default: 0.5)
        """
        self.threshold = threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

        print("Loading model...")
        self.model = EnhancedSigLIP().to(self.device)

        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — fine for a trusted local checkpoint, but
        # worth hardening if model_path can come from untrusted sources.
        try:
            state = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state)
            print("βœ… Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

        self.model.eval()

    def predict_similarity(self, image_path, text, verbose=True):
        """
        Predict if image and text are similar

        Args:
            image_path: Path to the image file
            text: Text description
            verbose: Whether to print detailed results

        Returns:
            dict: Contains similarity score, prediction, and confidence
                  (or None when anything goes wrong)
        """
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image not found: {image_path}")
            image = Image.open(image_path).convert('RGB')

            # Normalize and filter the text before tokenization.
            original_text = str(text).strip()
            filtered_text = filter_single_characters(original_text)

            if verbose:
                print(f"πŸ“ Original text: '{original_text}'")
                if original_text != filtered_text:
                    print(f"πŸ” Filtered text: '{filtered_text}'")
                print(f"πŸ–ΌοΈ Image: {image_path}")

            # Tokenize text and preprocess the image in one call.
            batch = self.processor(
                text=filtered_text,
                images=image,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            )

            input_ids = batch['input_ids'].to(self.device)
            pixel_values = batch['pixel_values'].to(self.device)

            if 'attention_mask' in batch:
                attention_mask = batch['attention_mask'].to(self.device)
            else:
                # Fall back to masking pad tokens when the processor
                # did not return a mask.
                pad_id = self.processor.tokenizer.pad_token_id
                attention_mask = (input_ids != pad_id).long()

            with torch.no_grad():
                text_embeds, image_embeds = self.model(input_ids, attention_mask, pixel_values)

            # Cosine similarity: both embeddings are already L2-normalized.
            score = torch.dot(text_embeds[0], image_embeds[0]).item()
            is_match = score > self.threshold

            result = {
                'similarity_score': score,
                'prediction': 'MATCH' if is_match else 'NO MATCH',
                'is_match': is_match,
                'confidence': abs(score - self.threshold),
                'threshold': self.threshold
            }

            if verbose:
                self._report(result)

            return result

        except Exception as e:
            # Best-effort API: callers (e.g. the FastAPI endpoint) treat
            # None as "prediction failed".
            print(f"❌ Error during prediction: {e}")
            return None

    def _report(self, result):
        """Pretty-print a prediction result to stdout."""
        print(f"\n🎯 Results:")
        print(f" Similarity Score: {result['similarity_score']:.4f}")
        print(f" Threshold: {self.threshold}")
        print(f" Prediction: {result['prediction']}")
        print(f" Confidence: {result['confidence']:.4f}")

        if result['is_match']:
            print("βœ… The image and text are SIMILAR!")
        else:
            print("❌ The image and text are NOT similar.")
def quick_test(model_path, image_path, text, threshold=0.5):
    """
    Quick function to test a single image-text pair

    Args:
        model_path: Path to your trained model
        image_path: Path to the image
        text: Text description
        threshold: Similarity threshold (default: 0.5)
    """
    # Convenience wrapper: build a throwaway predictor and score one pair.
    return SimilarityPredictor(model_path, threshold).predict_similarity(image_path, text)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers
5
+ sentencepiece
6
+ protobuf
7
+ Pillow
8
+ python-multipart
9
+ aiofiles
10
+ regex
11
+ numpy