shaheerawan3 committed on
Commit
c241ecd
·
verified ·
1 Parent(s): f8747cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -51
app.py CHANGED
@@ -289,7 +289,7 @@ class ImageScraper:
289
  ]
290
 
291
  def get_images(self, query: str, num_images: int = 15) -> Dict[str, List[Dict[str, str]]]:
292
- """Get images for either single word queries or extract keywords from long prompts"""
293
  try:
294
  # Initialize result structure
295
  result = {
@@ -298,46 +298,33 @@ class ImageScraper:
298
  'general': []
299
  }
300
 
301
- # Extract keywords if query is long
302
- if len(query.split()) > 3:
303
- keywords = self.extract_key_topics(query)
304
- print(f"Extracted keywords: {keywords}") # Debug log
305
- else:
306
- keywords = [query]
307
-
308
- # Fetch images for each keyword
309
- for keyword in keywords:
310
- base_url = "https://pixabay.com/api/"
311
- params = {
312
- 'key': self.PIXABAY_API_KEY,
313
- 'q': keyword,
314
- 'image_type': 'photo',
315
- 'per_page': max(3, num_images // len(keywords)), # Distribute images among keywords
316
- 'safesearch': True,
317
- 'lang': 'en'
318
- }
319
-
320
- response = requests.get(base_url, params=params, headers=self.headers)
321
-
322
- if response.status_code == 200:
323
- data = response.json()
324
- hits = data.get('hits', [])
325
-
326
- for hit in hits:
327
- image_data = {
328
- 'url': hit['largeImageURL'],
329
- 'keyword': keyword,
330
- 'relevance': 'Primary match',
331
- 'tags': hit.get('tags', '')
332
- }
333
-
334
- # Distribute images across categories
335
- if len(result['primary']) < num_images // 3:
336
- result['primary'].append(image_data)
337
- elif len(result['secondary']) < num_images // 3:
338
- result['secondary'].append(image_data)
339
- else:
340
- result['general'].append(image_data)
341
 
342
  # If no images found, use stock images
343
  if not any(result.values()):
@@ -346,23 +333,75 @@ class ImageScraper:
346
  'url': url,
347
  'keyword': 'technology',
348
  'relevance': 'Fallback',
349
- 'tags': 'technology'
 
350
  } for url in stock_images[:num_images]]
351
 
352
  return result
353
 
354
  except Exception as e:
355
  print(f"Error in get_images: {str(e)}")
356
- # Return stock images as fallback
357
- stock_images = self.get_stock_images()
358
- return {
359
- 'general': [{
360
- 'url': url,
361
- 'keyword': 'technology',
362
- 'relevance': 'Fallback',
363
- 'tags': 'technology'
364
- } for url in stock_images[:num_images]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
 
367
 
368
  def get_fallback_keywords(self) -> List[Dict[str, str]]:
 
289
  ]
290
 
291
  def get_images(self, query: str, num_images: int = 15) -> Dict[str, List[Dict[str, str]]]:
292
+ """Get images with AI-driven selection and ranking"""
293
  try:
294
  # Initialize result structure
295
  result = {
 
298
  'general': []
299
  }
300
 
301
+ # Extract and analyze keywords using AI
302
+ keywords = self.extract_key_topics(query)
303
+ print(f"AI extracted keywords: {keywords}")
304
+
305
+ # Score and rank keywords based on relevance to query
306
+ keyword_scores = self.score_keywords(query, keywords)
307
+ ranked_keywords = sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True)
308
+
309
+ # Fetch and analyze images for each keyword
310
+ all_images = []
311
+ for keyword, score in ranked_keywords:
312
+ images = self.get_images_for_keyword(keyword)
313
+ for img in images:
314
+ img['relevance_score'] = score * self.analyze_image_relevance(img, query)
315
+ all_images.append(img)
316
+
317
+ # Sort images by relevance score
318
+ sorted_images = sorted(all_images, key=lambda x: x['relevance_score'], reverse=True)
319
+
320
+ # Distribute images across categories
321
+ total_images = min(len(sorted_images), num_images)
322
+ primary_count = total_images // 2
323
+ secondary_count = total_images // 3
324
+
325
+ result['primary'] = sorted_images[:primary_count]
326
+ result['secondary'] = sorted_images[primary_count:primary_count + secondary_count]
327
+ result['general'] = sorted_images[primary_count + secondary_count:total_images]
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  # If no images found, use stock images
330
  if not any(result.values()):
 
333
  'url': url,
334
  'keyword': 'technology',
335
  'relevance': 'Fallback',
336
+ 'tags': 'technology',
337
+ 'relevance_score': 0.5
338
  } for url in stock_images[:num_images]]
339
 
340
  return result
341
 
342
  except Exception as e:
343
  print(f"Error in get_images: {str(e)}")
344
+ return self.get_fallback_images(num_images)
345
+
346
+ def score_keywords(self, query: str, keywords: List[str]) -> Dict[str, float]:
347
+ """Score keywords based on relevance to query"""
348
+ scores = {}
349
+ query_words = set(query.lower().split())
350
+
351
+ for keyword in keywords:
352
+ score = 0.0
353
+ keyword_words = set(keyword.lower().split())
354
+
355
+ # Direct word match
356
+ word_matches = len(keyword_words.intersection(query_words))
357
+ score += word_matches * 0.3
358
+
359
+ # Contextual relevance
360
+ context_terms = {
361
+ 'digital': 0.8,
362
+ 'security': 0.7,
363
+ 'legacy': 0.9,
364
+ 'protection': 0.6,
365
+ 'management': 0.5,
366
+ 'AI': 0.8,
367
+ 'technology': 0.6
368
  }
369
+
370
+ for term, weight in context_terms.items():
371
+ if term in keyword.lower():
372
+ score += weight
373
+
374
+ scores[keyword] = min(score, 1.0) # Normalize to 0-1
375
+
376
+ return scores
377
+
378
+ def analyze_image_relevance(self, image: Dict[str, str], query: str) -> float:
379
+ """Analyze image relevance based on tags and metadata"""
380
+ score = 0.0
381
+
382
+ # Analyze tags
383
+ tags = set(image['tags'].lower().split(','))
384
+ query_words = set(query.lower().split())
385
+
386
+ # Tag matching
387
+ matching_tags = len(tags.intersection(query_words))
388
+ score += matching_tags * 0.2
389
+
390
+ # Context relevance
391
+ relevant_terms = {
392
+ 'technology': 0.3,
393
+ 'digital': 0.3,
394
+ 'security': 0.3,
395
+ 'business': 0.2,
396
+ 'professional': 0.2,
397
+ 'modern': 0.1
398
+ }
399
+
400
+ for term, weight in relevant_terms.items():
401
+ if term in tags:
402
+ score += weight
403
+
404
+ return min(score, 1.0) # Normalize to 0-1
405
 
406
 
407
  def get_fallback_keywords(self) -> List[Dict[str, str]]: