Spaces:
Paused
Paused
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse,HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import cv2 | |
| import numpy as np | |
| from pillmodel import get_prediction | |
| import base64 | |
| from fastapi.staticfiles import StaticFiles | |
| import os | |
| import google.generativeai as genai | |
| from google.generativeai.types import HarmCategory, HarmBlockThreshold | |
| import google.ai.generativelanguage as glm | |
| from PIL import Image | |
| import io | |
| import random | |
| import re | |
| import json | |
def load_api_keys(raw):
    """Split a comma-separated key string into a list of usable keys.

    Args:
        raw: Comma-separated keys, or ``None`` when the variable is unset.

    Returns:
        The keys with surrounding whitespace stripped and blank entries
        dropped. An unset or empty value yields an empty list.
    """
    if not raw:
        return []
    return [key.strip() for key in raw.split(',') if key.strip()]


# Gemini API keys, comma-separated in the GEMINI_API_KEYS environment variable.
# A missing variable yields an empty list instead of crashing at import time.
api_keys = load_api_keys(os.getenv('GEMINI_API_KEYS'))
# Report only how many keys were loaded: the keys themselves are secrets and
# must not be written to the process logs.
print(f"Loaded {len(api_keys)} Gemini API key(s)")
| from inference_sdk import InferenceHTTPClient | |
# Application instance with a fully permissive CORS policy: every origin,
# method and header is accepted, and credentials are allowed.
# NOTE(review): this lets any origin send credentialed requests - confirm
# that this is intended for the deployment.
app = FastAPI()
cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
async def predict(image: UploadFile = File(...)):
    """Count capsules and tablets in an uploaded image.

    The upload is decoded with OpenCV and handed to ``get_prediction``, which
    returns an image and a dictionary of counts.

    Args:
        image: The uploaded image file.

    Returns:
        A ``JSONResponse`` with ``message`` (human-readable summary), ``count``
        (the dictionary returned by ``get_prediction``) and ``predicted_image``
        (the image returned by ``get_prediction`` as a base64-encoded JPEG).
        If the upload cannot be decoded as an image a 400 response carrying
        only ``message`` is returned; if the result cannot be encoded as JPEG
        a 500 response carrying only ``message`` is returned.
    """
    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    # imdecode raises on an empty buffer and returns None for data that is
    # not a decodable image, so both cases are normalised to None here.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"message": "The uploaded file could not be decoded as an image."},
        )

    # Prediction
    predicted_image, count_dict = get_prediction(img)

    # Encode predicted image to base64
    encoded_ok, buffer = cv2.imencode('.jpg', predicted_image)
    if not encoded_ok:
        return JSONResponse(
            status_code=500,
            content={"message": "The predicted image could not be encoded."},
        )
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')

    # Build the confirmation message; missing categories count as zero.
    capsules = count_dict.get('capsules', 0)
    tablets = count_dict.get('tablets', 0)
    message_to_send = (
        f"There are {capsules} capsules and {tablets} tablets. "
        f"A total of {capsules + tablets} pills."
    )
    return JSONResponse(
        content={
            "message": message_to_send,
            "count": count_dict,
            "predicted_image": predicted_image_str,
        }
    )
async def predict_wheat(image: UploadFile = File(...), model_id: str = "grian/1"):
    """Count wheat grains in an uploaded image using a hosted Roboflow model.

    The upload is decoded, written to a per-request temporary JPEG, and sent
    to the inference API. One box is drawn on the image per prediction.

    Args:
        image: The uploaded image file.
        model_id: Identifier of the model passed to the inference client.

    Returns:
        A ``JSONResponse`` with ``message`` (human-readable summary), ``count``
        (number of predictions) and ``predicted_image`` (the image with boxes
        drawn, as a base64-encoded JPEG). If the upload cannot be decoded as
        an image a 400 response carrying only ``message`` is returned.
    """
    import tempfile

    contents = await image.read()
    nparr = np.frombuffer(contents, np.uint8)
    # imdecode raises on an empty buffer and returns None for data that is
    # not a decodable image, so both cases are normalised to None here.
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if img is None:
        return JSONResponse(
            status_code=400,
            content={"message": "The uploaded file could not be decoded as an image."},
        )

    # A unique file per request: a shared fixed path would let concurrent
    # requests overwrite each other's image before inference runs.
    handle, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(handle)
    try:
        cv2.imwrite(temp_image_path, img)
        client = InferenceHTTPClient(
            api_url="https://detect.roboflow.com",
            # SECURITY: prefer the ROBOFLOW_API_KEY environment variable. The
            # literal fallback only keeps existing deployments working.
            # TODO: rotate this key and remove it from the source.
            api_key=os.getenv("ROBOFLOW_API_KEY", "PpEebXofNuob5VSx7YP3"),
        )
        result = client.infer(temp_image_path, model_id=model_id)
    finally:
        # Best-effort cleanup; a failure to delete must not mask the result.
        try:
            os.remove(temp_image_path)
        except OSError:
            pass

    # Prediction
    predictions = result['predictions']
    predicted_count = len(predictions)
    message_to_send = (
        f"There are {predicted_count} wheat grains."
    )

    for prediction in predictions:
        # Roboflow reports x/y as the centre of the box, so the corners are
        # half a width/height away on each side.
        half_width = prediction['width'] / 2
        half_height = prediction['height'] / 2
        top_left = (int(prediction['x'] - half_width), int(prediction['y'] - half_height))
        bottom_right = (int(prediction['x'] + half_width), int(prediction['y'] + half_height))
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 2)

    # Encode predicted image to base64
    _, buffer = cv2.imencode('.jpg', img)
    predicted_image_str = base64.b64encode(buffer).decode('utf-8')
    return JSONResponse(
        content={
            "message": message_to_send,
            "count": predicted_count,
            "predicted_image": predicted_image_str,
        }
    )
def process_image(file: UploadFile):
    """Convert an uploaded image into a JPEG ``glm.Blob``.

    The upload is opened with Pillow, converted to RGB when it is in any
    other mode, re-encoded as JPEG in memory and wrapped in a blob whose
    mime type is ``image/jpeg``.

    Args:
        file: The uploaded image; its underlying file object is read.

    Returns:
        A ``glm.Blob`` holding the JPEG bytes.
    """
    picture = Image.open(file.file)
    rgb_picture = picture if picture.mode == 'RGB' else picture.convert('RGB')

    # Serialise to JPEG in memory rather than on disk.
    jpeg_buffer = io.BytesIO()
    rgb_picture.save(jpeg_buffer, format='JPEG')

    return glm.Blob(mime_type='image/jpeg', data=jpeg_buffer.getvalue())
def _extract_fenced_json(text):
    """Return the stripped contents of the first ```json fenced block, or None.

    None is returned both when there is no opening fence and when the opening
    fence is never closed.
    """
    match = re.search(r'```json(.*?)```', text, re.DOTALL)
    return match.group(1).strip() if match else None


async def analyze_image(file: UploadFile = File(...)):
    """Ask Gemini for a marketplace safety assessment of an uploaded image.

    A key is picked at random from ``api_keys``, the image is converted with
    ``process_image`` and sent to the model together with a fixed prompt.

    Args:
        file: The uploaded product image.

    Returns:
        A ``JSONResponse`` whose content is a string: the model's fenced JSON
        re-serialised with an indent of 4 when it parses, otherwise the raw
        model text. A 503 response with ``message`` is returned when no API
        key is configured.
    """
    usable_keys = [key.strip() for key in api_keys if key and key.strip()]
    if not usable_keys:
        return JSONResponse(
            status_code=503,
            content={"message": "No Gemini API key is configured."},
        )
    # The chosen key is deliberately not logged: it is a secret.
    selected_api_key = random.choice(usable_keys)
    genai.configure(api_key=selected_api_key)

    generation_config = {
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 8192,
        "response_mime_type": "text/plain",
    }
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # Process the image
    blob = process_image(file)

    # Initialize the Generative Model
    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    # Prompt for content generation (sent to the model verbatim)
    prompt = """
give a safety score for a website called unipall which is a olx, now when a user is uploading a product,
tell me this in json like:
only give this json nothing else not be too harmful
when a picture contains some accessories in a scene focus on them and don't flag it
don't flag text on the product
{
useable_on_website: true/false,
safety_score: /100,
category: "",
reason: "",
suggested_product_title: "",
suggested_product_description: ""
}
"""

    # Generate content using the AI model
    response = model.generate_content([prompt, blob])
    response_text = response.text

    # NOTE: the content handed to JSONResponse is a string, so clients receive
    # a JSON-encoded string. That existing response shape is kept on purpose.
    json_string = _extract_fenced_json(response_text)
    if json_string is None:
        # No complete ```json block: hand back the raw model text.
        return JSONResponse(content=response_text, media_type="application/json")

    try:
        data = json.loads(json_string)
    except json.JSONDecodeError:
        # The model produced something that is not valid JSON (for example
        # unquoted keys); fall back to the raw text instead of failing.
        return JSONResponse(content=response_text, media_type="application/json")

    fd = json.dumps(data, indent=4)
    # Return the AI-generated response
    return JSONResponse(content=fd, media_type="application/json")
# Serve the contents of the "static" directory from the site root.
static_files = StaticFiles(directory="static")
app.mount("/", static_files, name="static")
async def home():
    """Return an HTML page that immediately redirects the browser to /index.html.

    The redirect is performed client-side with a zero-second meta refresh.
    """
    redirect_page = "<html><head><meta http-equiv='refresh' content='0; url=/index.html'></head></html>"
    return HTMLResponse(content=redirect_page)