kevalgajjar committed on
Commit
682cf73
·
verified ·
1 Parent(s): 980b8ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -111
app.py CHANGED
@@ -1,6 +1,3 @@
1
- from fastapi import FastAPI, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  import gradio as gr
5
  from huggingface_hub import InferenceClient
6
  from PIL import Image, ImageEnhance, ImageFilter
@@ -12,7 +9,6 @@ import os
12
  import base64
13
  from io import BytesIO
14
  import math
15
- import uvicorn
16
 
17
  # Global storage
18
  pdf_texts = {}
@@ -20,30 +16,6 @@ reader = None
20
  blip_processor = None
21
  blip_model = None
22
 
23
- # === FASTAPI APP ===
24
- app = FastAPI(title="AI Assistant API")
25
-
26
- # Add CORS
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"],
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- # Request model
36
- class AnalyzeRequest(BaseModel):
37
- message: str
38
- image_base64: str = None
39
- model: str = "Qwen/Qwen2.5-7B-Instruct"
40
-
41
- # Response model
42
- class AnalyzeResponse(BaseModel):
43
- analysis: str
44
-
45
- # === HELPER FUNCTIONS ===
46
-
47
  def load_pdfs():
48
  global pdf_texts
49
  pdf_texts.clear()
@@ -98,6 +70,14 @@ def decode_base64_image(image_data):
98
  pass
99
  return image_data
100
 
 
 
 
 
 
 
 
 
101
  def analyze_image(image):
102
  initialize_vision_models()
103
  try:
@@ -116,98 +96,125 @@ def analyze_image(image):
116
  except Exception as e:
117
  return "", str(e)
118
 
119
- # === FASTAPI ENDPOINTS ===
120
-
121
- @app.get("/")
122
- def read_root():
123
- return {"message": "AI Assistant API is running", "endpoints": ["/api/analyze", "/gradio"]}
 
 
124
 
125
- @app.get("/health")
126
- def health_check():
127
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- @app.post("/api/analyze", response_model=AnalyzeResponse)
130
- async def analyze_endpoint(request: AnalyzeRequest):
131
  """
132
- REST API endpoint for analyzing text/images
133
  """
134
- try:
135
- token = os.getenv('HF_TOKEN')
136
- message = request.message
137
- image_base64 = request.image_base64
138
- model = request.model
139
-
140
- # Decode image
141
- img = None
142
- if image_base64:
143
- img = decode_base64_image(image_base64)
144
-
145
- # Detect MCQ
146
- has_options = bool(re.search(r'[A-D][\.\)]\s', message))
147
-
148
- # Get context from image
149
- context = ""
150
- if img:
151
- try:
152
- ocr_text, _ = analyze_image(img)
153
- if ocr_text:
154
- context = f"\n\nExtracted text:\n{ocr_text[:400]}"
155
- except Exception as e:
156
- context = f"\n\n(OCR error: {str(e)})"
157
-
158
- # System message
159
- if has_options:
160
- sys_msg = "Exam assistant. Format: Answer: [letter]. Reason: [one sentence]."
161
- temp = 0.2
162
- tokens = 100
163
- else:
164
- sys_msg = "You are a helpful AI assistant."
165
- temp = 0.6
166
- tokens = 400
167
-
168
  try:
169
- client = InferenceClient(token=token, model=model)
170
- except:
171
- try:
172
- client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
173
- except Exception as e:
174
- raise HTTPException(status_code=500, detail=f"Model connection failed: {str(e)}")
175
-
176
- messages = [
177
- {"role": "system", "content": sys_msg},
178
- {"role": "user", "content": message + context}
179
- ]
180
-
181
- # Non-streaming response
182
- response = ""
 
 
 
 
 
183
  try:
184
- for msg in client.chat_completion(messages, max_tokens=tokens, stream=True, temperature=temp, top_p=0.9):
185
- if msg.choices and msg.choices[0].delta.content:
186
- response += msg.choices[0].delta.content
187
- if has_options and len(response) > 250:
188
- break
189
  except Exception as e:
190
- raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
191
-
192
- return AnalyzeResponse(analysis=response.strip())
193
-
194
- except HTTPException:
195
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  except Exception as e:
197
- raise HTTPException(status_code=500, detail=str(e))
198
-
199
- # === GRADIO UI ===
200
 
201
- def respond(message, history, system_message, max_tokens, temperature, top_p, model_selection, image, hf_token):
202
- """Gradio chat function"""
 
 
 
 
 
 
 
 
 
 
 
203
  if image is not None:
204
  image = decode_base64_image(image)
205
 
206
  token = os.getenv('HF_TOKEN') or (hf_token.strip() if hf_token else None)
207
  has_options = bool(re.search(r'[A-D][\.\)]\s', message))
 
 
 
 
 
 
208
 
209
  try:
210
- client = InferenceClient(token=token, model=model_selection)
211
  except:
212
  try:
213
  client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
@@ -225,7 +232,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, mo
225
  pass
226
 
227
  if has_options:
228
- system_message = "Exam assistant. Answer: [letter]. Reason: [one sentence]."
229
  temperature = 0.2
230
  max_tokens = 100
231
 
@@ -246,7 +253,8 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, mo
246
  # Load PDFs
247
  pdf_status = load_pdfs()
248
 
249
- # Create Gradio interface
 
250
  chat_interface = gr.ChatInterface(
251
  respond,
252
  type="messages",
@@ -256,7 +264,7 @@ chat_interface = gr.ChatInterface(
256
  gr.Slider(0.1, 1.2, 0.6, step=0.1, label="Temperature"),
257
  gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p"),
258
  gr.Dropdown(
259
- choices=["Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "HuggingFaceH4/zephyr-7b-beta"],
260
  value="Qwen/Qwen2.5-7B-Instruct",
261
  label="Model",
262
  ),
@@ -267,10 +275,26 @@ chat_interface = gr.ChatInterface(
267
  description=f"MCQ (short) • Math (steps) • General (detailed)\n\n{pdf_status}",
268
  )
269
 
270
- # Mount Gradio to FastAPI
271
- app = gr.mount_gradio_app(app, chat_interface, path="/gradio")
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- # === RUN SERVER ===
 
 
 
 
 
274
 
275
  if __name__ == "__main__":
276
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  from PIL import Image, ImageEnhance, ImageFilter
 
9
  import base64
10
  from io import BytesIO
11
  import math
 
12
 
13
  # Global storage
14
  pdf_texts = {}
 
16
  blip_processor = None
17
  blip_model = None
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def load_pdfs():
20
  global pdf_texts
21
  pdf_texts.clear()
 
70
  pass
71
  return image_data
72
 
73
def web_search(query):
    """Best-effort DuckDuckGo text search.

    Parameters
    ----------
    query : str
        Search terms.

    Returns
    -------
    str | None
        Up to two results formatted as "title: body" lines (body truncated to
        100 chars), or None when the search or the optional `ddgs` import
        fails for any reason.
    """
    try:
        from ddgs import DDGS  # optional third-party dep; local import keeps startup cheap
        results = DDGS().text(query, max_results=2)
        return "\n".join(f"{r['title']}: {r['body'][:100]}" for r in results)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate; callers treat None as "no web context available".
        return None
81
  def analyze_image(image):
82
  initialize_vision_models()
83
  try:
 
96
  except Exception as e:
97
  return "", str(e)
98
 
99
def extract_math_calcs(text):
    """Evaluate every combination expression ``C(n, k)`` found in *text*.

    Returns a list of strings like ``"C(5,2)=10"`` (result thousands-separated),
    in order of appearance; empty list when nothing matches.
    """
    combo_pattern = re.compile(r'C\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)')
    return [
        f"C({int(n)},{int(k)})={math.comb(int(n), int(k)):,}"
        for n, k in combo_pattern.findall(text)
    ]
106
 
107
def get_pdf_context(query):
    """Return the best-matching sentence fragment from the loaded PDFs.

    Scores the first 40 '.'-separated sentences of each PDF by how many
    4+-letter query words they contain; returns the top fragment (first 150
    chars) only when it matches at least two keywords, else None.
    """
    if not pdf_texts:
        return None

    terms = set(re.findall(r'\b\w{4,}\b', query.lower()))
    scored = []
    for body in pdf_texts.values():
        for sentence in body.split('.')[:40]:
            lowered = sentence.lower()
            hits = sum(1 for term in terms if term in lowered)
            if hits > 0:
                scored.append((hits, sentence[:150]))

    scored.sort(reverse=True)
    if scored and scored[0][0] >= 2:
        return scored[0][1]
    return None
121
 
122
# MAIN API FUNCTION - Simple interface for external calls
def api_analyze(message: str, image_base64: str = None, model: str = "Qwen/Qwen2.5-7B-Instruct"):
    """Answer a question (optionally grounded on a base64 image) via HF inference.

    Parameters
    ----------
    message : str
        The user's question. When it looks like an MCQ (an option letter A-D
        followed by "." or ")"), the prompt switches to a terse exam mode.
    image_base64 : str, optional
        Base64-encoded image; OCR text extracted from it is appended to the
        prompt as context.
    model : str
        Hugging Face model id; falls back to Qwen2.5-7B-Instruct if the
        requested model cannot be reached.

    Returns
    -------
    str
        The model's (stripped) answer, or a human-readable error string —
        this function never raises to its caller for inference failures.
    """
    token = os.getenv('HF_TOKEN')

    # Decode image
    img = None
    if image_base64:
        img = decode_base64_image(image_base64)

    # Detect MCQ: option letter A-D followed by "." or ")" and whitespace.
    has_options = bool(re.search(r'[A-D][\.\)]\s', message))

    # Get context from image via OCR; failures degrade to an inline note
    # instead of aborting the whole request.
    context = ""
    if img:
        try:
            ocr_text, _ = analyze_image(img)
            if ocr_text:
                context = f"\n\nExtracted text from image:\n{ocr_text[:400]}"
        except Exception as e:
            context = f"\n\n(Image processing error: {str(e)})"

    # System message: terse + low temperature for MCQs, chattier otherwise.
    if has_options:
        sys_msg = "You are an exam assistant. For MCQ, give: Answer: [letter]. Reason: [one sentence only]."
        temp = 0.2
        tokens = 100
    else:
        sys_msg = "You are a helpful AI assistant."
        temp = 0.6
        tokens = 400

    try:
        client = InferenceClient(token=token, model=model)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # propagate; fall back to the default model before giving up.
        try:
            client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
        except Exception as e:
            return f"Error connecting to model: {str(e)}"

    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": message + context}
    ]

    # Non-streaming result: accumulate the streamed deltas into one string.
    response = ""
    try:
        for msg in client.chat_completion(
            messages,
            max_tokens=tokens,
            stream=True,
            temperature=temp,
            top_p=0.9
        ):
            if msg.choices and msg.choices[0].delta.content:
                response += msg.choices[0].delta.content

            # Stop early for MCQ — the terse format never needs more than this.
            if has_options and len(response) > 200:
                break
    except Exception as e:
        return f"Error during inference: {str(e)}"

    return response.strip()
190
 
191
+ # Chat function for UI
192
+ def respond(
193
+ message,
194
+ history: list[dict[str, str]],
195
+ system_message,
196
+ max_tokens,
197
+ temperature,
198
+ top_p,
199
+ model_selection,
200
+ image,
201
+ hf_token,
202
+ ):
203
+ """UI chat function with streaming"""
204
  if image is not None:
205
  image = decode_base64_image(image)
206
 
207
  token = os.getenv('HF_TOKEN') or (hf_token.strip() if hf_token else None)
208
  has_options = bool(re.search(r'[A-D][\.\)]\s', message))
209
+ is_math_calc = any(w in message.lower() for w in ['calculate', 'factorial', 'combination'])
210
+
211
+ if is_math_calc and not has_options:
212
+ selected_model = "Qwen/Qwen2.5-Math-7B-Instruct"
213
+ else:
214
+ selected_model = model_selection
215
 
216
  try:
217
+ client = InferenceClient(token=token, model=selected_model)
218
  except:
219
  try:
220
  client = InferenceClient(token=token, model="Qwen/Qwen2.5-7B-Instruct")
 
232
  pass
233
 
234
  if has_options:
235
+ system_message = "Exam assistant. MCQ format: Answer: [letter]. Reason: [one sentence]."
236
  temperature = 0.2
237
  max_tokens = 100
238
 
 
253
  # Load PDFs
254
  pdf_status = load_pdfs()
255
 
256
+ # Create TWO separate interfaces
257
+ # 1. Chat UI for users
258
  chat_interface = gr.ChatInterface(
259
  respond,
260
  type="messages",
 
264
  gr.Slider(0.1, 1.2, 0.6, step=0.1, label="Temperature"),
265
  gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p"),
266
  gr.Dropdown(
267
+ choices=["Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "HuggingFaceH4/zephyr-7b-beta","openai/gpt-oss-20b","Qwen/Qwen2.5-Math-7B-Instruct"],
268
  value="Qwen/Qwen2.5-7B-Instruct",
269
  label="Model",
270
  ),
 
275
  description=f"MCQ (short) • Math (steps) • General (detailed)\n\n{pdf_status}",
276
  )
277
 
278
# 2. Simple API interface
# Thin Gradio wrapper around api_analyze so external clients can use the
# auto-generated REST endpoint instead of the chat UI.
api_interface = gr.Interface(
    fn=api_analyze,
    inputs=[
        gr.Textbox(label="Message", placeholder="Enter your question"),
        gr.Textbox(label="Image (base64)", placeholder="Optional base64 image"),
        gr.Textbox(label="Model", value="Qwen/Qwen2.5-7B-Instruct"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="API Endpoint",
    description="Direct API access",
    api_name="analyze"  # Creates /call/analyze endpoint
)

# Combine both in tabs: tab 1 is the interactive chat UI, tab 2 the plain
# request/response API form defined above.
demo = gr.TabbedInterface(
    [chat_interface, api_interface],
    ["Chat", "API"],
    title="🤖 AI Assistant"
)

if __name__ == "__main__":
    demo.launch()