gopichandra commited on
Commit
b12e335
·
verified ·
1 Parent(s): b00fdd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -1,33 +1,24 @@
1
- import os
2
- import sys
3
- from subprocess import check_call
4
-
5
- # Ensure transformers is installed
6
- try:
7
- from transformers import DetrImageProcessor, DetrForObjectDetection, pipeline
8
- except ImportError:
9
- check_call([sys.executable, "-m", "pip", "install", "transformers==4.33.2"])
10
- from transformers import DetrImageProcessor, DetrForObjectDetection, pipeline
11
-
12
  from PIL import Image
13
  import gradio as gr
14
 
15
  # Load pre-trained models
16
  detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
17
  detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
18
- description_generator = pipeline("text-generation", model="gpt-2")
19
 
20
- # Function to recognize and describe a product
21
  def recognize_and_describe(image):
 
22
  inputs = detection_processor(images=image, return_tensors="pt")
23
  outputs = detection_model(**inputs)
24
- logits = outputs.logits.argmax(-1).item()
25
- product_label = f"Product Class: {logits}"
26
-
27
- # Generate description
28
- prompt = f"Describe the product: {product_label}"
29
- description = description_generator(prompt, max_length=50, num_return_sequences=1)
30
- return product_label, description[0]["generated_text"]
 
31
 
32
  # Gradio Interface
33
  interface = gr.Interface(
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection, pipeline
 
 
 
 
 
 
 
 
 
 
2
  from PIL import Image
3
  import gradio as gr
4
 
5
  # Load pre-trained models
6
  detection_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
7
  detection_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
 
8
 
9
+ # Function to process image and generate description
10
  def recognize_and_describe(image):
11
+ # Process the image with DETR
12
  inputs = detection_processor(images=image, return_tensors="pt")
13
  outputs = detection_model(**inputs)
14
+
15
+ # Get detected classes
16
+ logits = outputs.logits.argmax(-1).tolist()[0]
17
+ product_label = f"Detected Product Class: {logits}"
18
+
19
+ # Generate a description using a dummy model or hardcoded description
20
+ description = f"This is a product in class {logits}. Further information can be retrieved."
21
+ return product_label, description
22
 
23
  # Gradio Interface
24
  interface = gr.Interface(