tayyab-077 commited on
Commit
a66fd49
·
1 Parent(s): 2b903c0

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -32
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # ------------------------------
2
- # AI Multi‑Modal Assistant — Enhanced Version
3
  # ------------------------------
4
 
5
  import gradio as gr
@@ -11,8 +11,9 @@ import pandas as pd
11
  from reportlab.lib.pagesizes import letter
12
  from reportlab.pdfgen import canvas
13
  import io
14
- import yake # keyword extraction
15
  import tempfile
 
16
 
17
  # ------------------------------
18
  # 1. Load Models & Labels
@@ -23,9 +24,6 @@ sentiment_model = pipeline(
23
  "sentiment-analysis",
24
  model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
25
  )
26
-
27
-
28
- # Summarization Model
29
  summarizer_model = pipeline("summarization", model="facebook/bart-large-cnn")
30
 
31
  # Image classification model
@@ -40,9 +38,8 @@ preprocess = transforms.Compose(
40
  ]
41
  )
42
 
43
- # Load ImageNet class labels mapping
44
- imagenet_labels = []
45
- with open("imagenet_classes.txt", "r") as f: # ensure this file is in your folder
46
  imagenet_labels = [s.strip() for s in f.readlines()]
47
 
48
  # Keyword extraction
@@ -52,7 +49,6 @@ kw_extractor = yake.KeywordExtractor(lan="en", top=5)
52
  # 2. Helper Functions
53
  # ------------------------------
54
 
55
-
56
  def analyze_text(text: str) -> dict:
57
  sentiment = sentiment_model(text)[0]
58
  summary = summarizer_model(
@@ -66,18 +62,18 @@ def analyze_text(text: str) -> dict:
66
  "Keywords": keywords,
67
  }
68
 
69
-
70
  def analyze_image(image: Image.Image) -> dict:
71
  img_t = preprocess(image).unsqueeze(0)
72
  with torch.no_grad():
73
  outputs = image_model(img_t)
74
  class_idx = outputs.argmax().item()
75
- if 0 <= class_idx < len(imagenet_labels):
76
- class_label = imagenet_labels[class_idx]
77
- else:
78
- class_label = f"Class index {class_idx}"
79
  return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
80
 
 
 
 
 
81
 
82
  def generate_pdf(results: dict) -> str:
83
  buffer = io.BytesIO()
@@ -96,19 +92,16 @@ def generate_pdf(results: dict) -> str:
96
  c.save()
97
  buffer.seek(0)
98
 
99
- # ✅ Save to a temp file and return path instead of buffer
100
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
101
  tmp.write(buffer.getvalue())
102
  tmp_path = tmp.name
103
 
104
- return tmp_path # ✅ returns file path (Gradio-friendly)
105
-
106
 
107
  # ------------------------------
108
  # 3. Multi‑Modal Analysis Function
109
  # ------------------------------
110
  def analyze(input_data):
111
- # Handles both text and image correctly in Gradio
112
  if isinstance(input_data, str) and input_data.strip():
113
  return analyze_text(input_data)
114
  elif isinstance(input_data, dict) and "image" in input_data:
@@ -118,7 +111,6 @@ def analyze(input_data):
118
  else:
119
  return {"Error": "Please enter text or upload an image."}
120
 
121
-
122
  # ------------------------------
123
  # 4. Gradio UI Layout
124
  # ------------------------------
@@ -126,47 +118,54 @@ def analyze(input_data):
126
  with gr.Blocks() as demo:
127
  gr.Markdown("## AI Multi‑Modal Assistant")
128
 
 
129
  with gr.Tab("Image Analysis"):
130
  image_input = gr.Image(type="pil", label="Upload an image for classification")
131
- analyze_image_button = gr.Button("Analyze Image") # ✅ MOVED INSIDE the tab
132
-
133
  image_output = gr.JSON(label="Image Analysis Results")
134
  pdf_button_image = gr.Button("Download Report (PDF)")
135
 
136
  analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
 
 
 
 
 
137
 
138
- pdf_button_image.click(
139
- fn=lambda x: generate_pdf(analyze(x)),
140
- inputs=image_input,
141
- outputs=gr.File(label="Download PDF Report"),
142
- )
143
-
144
-
145
  with gr.Tab("Text Analysis"):
146
  text_input = gr.Textbox(
147
  label="Enter text to analyze",
148
  placeholder="Type your text here...",
149
  lines=5
150
  )
151
-
152
  analyze_text_button = gr.Button("Analyze Text")
153
  text_output = gr.JSON(label="Text Analysis Results")
154
  pdf_button_text = gr.Button("Download Report (PDF)")
155
 
156
- # Text analysis events
157
-
158
  analyze_text_button.click(
159
  fn=analyze,
160
  inputs=text_input,
161
  outputs=text_output
162
  )
163
-
164
  pdf_button_text.click(
165
  fn=lambda x: generate_pdf(analyze(x)),
166
  inputs=text_input,
167
  outputs=gr.File(label="Download PDF Report")
168
  )
169
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  # ------------------------------
172
  # 5. Launch the App
 
1
  # ------------------------------
2
+ # AI Multi‑Modal Assistant — Phase 2 (OCR Added)
3
  # ------------------------------
4
 
5
  import gradio as gr
 
11
  from reportlab.lib.pagesizes import letter
12
  from reportlab.pdfgen import canvas
13
  import io
14
+ import yake
15
  import tempfile
16
+ import pytesseract # <-- OCR
17
 
18
  # ------------------------------
19
  # 1. Load Models & Labels
 
24
  "sentiment-analysis",
25
  model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
26
  )
 
 
 
27
  summarizer_model = pipeline("summarization", model="facebook/bart-large-cnn")
28
 
29
  # Image classification model
 
38
  ]
39
  )
40
 
41
+ # Load ImageNet class labels
42
+ with open("imagenet_classes.txt", "r") as f:
 
43
  imagenet_labels = [s.strip() for s in f.readlines()]
44
 
45
  # Keyword extraction
 
49
  # 2. Helper Functions
50
  # ------------------------------
51
 
 
52
  def analyze_text(text: str) -> dict:
53
  sentiment = sentiment_model(text)[0]
54
  summary = summarizer_model(
 
62
  "Keywords": keywords,
63
  }
64
 
 
65
  def analyze_image(image: Image.Image) -> dict:
66
  img_t = preprocess(image).unsqueeze(0)
67
  with torch.no_grad():
68
  outputs = image_model(img_t)
69
  class_idx = outputs.argmax().item()
70
+ class_label = imagenet_labels[class_idx] if 0 <= class_idx < len(imagenet_labels) else f"Class index {class_idx}"
 
 
 
71
  return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
72
 
73
+ def ocr_image(image: Image.Image) -> dict:
74
+ """Extract text from uploaded image using Tesseract OCR."""
75
+ text = pytesseract.image_to_string(image)
76
+ return {"Extracted Text": text}
77
 
78
  def generate_pdf(results: dict) -> str:
79
  buffer = io.BytesIO()
 
92
  c.save()
93
  buffer.seek(0)
94
 
 
95
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
96
  tmp.write(buffer.getvalue())
97
  tmp_path = tmp.name
98
 
99
+ return tmp_path
 
100
 
101
  # ------------------------------
102
  # 3. Multi‑Modal Analysis Function
103
  # ------------------------------
104
  def analyze(input_data):
 
105
  if isinstance(input_data, str) and input_data.strip():
106
  return analyze_text(input_data)
107
  elif isinstance(input_data, dict) and "image" in input_data:
 
111
  else:
112
  return {"Error": "Please enter text or upload an image."}
113
 
 
114
  # ------------------------------
115
  # 4. Gradio UI Layout
116
  # ------------------------------
 
118
  with gr.Blocks() as demo:
119
  gr.Markdown("## AI Multi‑Modal Assistant")
120
 
121
+ # ------------------ Image Analysis Tab ------------------
122
  with gr.Tab("Image Analysis"):
123
  image_input = gr.Image(type="pil", label="Upload an image for classification")
124
+ analyze_image_button = gr.Button("Analyze Image")
 
125
  image_output = gr.JSON(label="Image Analysis Results")
126
  pdf_button_image = gr.Button("Download Report (PDF)")
127
 
128
  analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
129
+ pdf_button_image.click(
130
+ fn=lambda x: generate_pdf(analyze(x)),
131
+ inputs=image_input,
132
+ outputs=gr.File(label="Download PDF Report"),
133
+ )
134
 
135
+ # ------------------ Text Analysis Tab ------------------
 
 
 
 
 
 
136
  with gr.Tab("Text Analysis"):
137
  text_input = gr.Textbox(
138
  label="Enter text to analyze",
139
  placeholder="Type your text here...",
140
  lines=5
141
  )
 
142
  analyze_text_button = gr.Button("Analyze Text")
143
  text_output = gr.JSON(label="Text Analysis Results")
144
  pdf_button_text = gr.Button("Download Report (PDF)")
145
 
 
 
146
  analyze_text_button.click(
147
  fn=analyze,
148
  inputs=text_input,
149
  outputs=text_output
150
  )
 
151
  pdf_button_text.click(
152
  fn=lambda x: generate_pdf(analyze(x)),
153
  inputs=text_input,
154
  outputs=gr.File(label="Download PDF Report")
155
  )
156
 
157
+ # ------------------ OCR Tab ------------------
158
+ with gr.Tab("OCR"):
159
+ ocr_input = gr.Image(type="pil", label="Upload image for OCR")
160
+ ocr_output = gr.JSON(label="OCR Results")
161
+ pdf_button_ocr = gr.Button("Download OCR PDF")
162
+
163
+ ocr_input.submit(fn=ocr_image, inputs=ocr_input, outputs=ocr_output)
164
+ pdf_button_ocr.click(
165
+ fn=lambda x: generate_pdf(x),
166
+ inputs=ocr_output,
167
+ outputs=gr.File(label="Download PDF Report"),
168
+ )
169
 
170
  # ------------------------------
171
  # 5. Launch the App