ShahanMalik commited on
Commit
f7576cd
·
verified ·
1 Parent(s): baca7fd

Update TextGen/router.py

Browse files
Files changed (1) hide show
  1. TextGen/router.py +167 -31
TextGen/router.py CHANGED
@@ -1,35 +1,74 @@
1
  from pydantic import BaseModel
2
-
3
- from .ConfigEnv import config
4
  from fastapi.middleware.cors import CORSMiddleware
5
-
6
- from langchain_community.llms import Clarifai
7
- from langchain.chains import LLMChain
8
- from langchain.prompts import PromptTemplate
9
  from TextGen import app
10
 
11
- class Generate(BaseModel):
12
- text: str
13
-
14
- def generate_text(prompt: str):
15
- if prompt == "":
16
- return {"detail": "Please provide a prompt."}
17
- else:
18
- prompt = PromptTemplate(template=prompt, input_variables=['Prompt'])
19
- llm = Clarifai(
20
- pat=config.CLARIFAI_PAT,
21
- user_id=config.USER_ID,
22
- app_id=config.APP_ID,
23
- model_id=config.MODEL_ID,
24
- model_version_id=config.MODEL_VERSION_ID,
25
- )
26
- llmchain = LLMChain(
27
- prompt=prompt,
28
- llm=llm
29
- )
30
- llm_response = llmchain.run({"Prompt": prompt})
31
- return Generate(text=llm_response)
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  app.add_middleware(
34
  CORSMiddleware,
35
  allow_origins=["*"],
@@ -40,8 +79,105 @@ app.add_middleware(
40
 
41
  @app.get("/", tags=["Home"])
42
  def api_home():
43
- return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
 
 
 
 
 
 
 
 
 
44
 
45
- @app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
46
- def inference(input_prompt: str):
47
- return generate_text(prompt=input_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pydantic import BaseModel
2
+ from fastapi import File, UploadFile, HTTPException
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ import numpy as np
5
+ from tensorflow.keras.models import load_model
6
+ from PIL import Image
7
+ import io
8
  from TextGen import app
9
 
10
+ # Response models
11
+ class ImageDetectionResponse(BaseModel):
12
+ is_ai_generated: bool
13
+ confidence_score: float
14
+ confidence_percentage: float
15
+ prediction_score: float
16
+ message: str
17
+
18
+ class ErrorResponse(BaseModel):
19
+ error: str
20
+ detail: str
21
+
22
+ # Global model variable
23
+ model = None
24
+
25
+ def load_ai_detection_model():
26
+ """Load the AI detection model"""
27
+ global model
28
+ if model is None:
29
+ try:
30
+ model = load_model('src/best_model.keras')
31
+ print("✅ Model loaded successfully")
32
+ except Exception as e:
33
+ print(f"❌ Error loading model: {str(e)}")
34
+ model = None
35
+ return model
36
 
37
+ def preprocess_image(image_file):
38
+ """Preprocess the uploaded image for model prediction"""
39
+ try:
40
+ # Read file bytes
41
+ file_bytes = image_file.read()
42
+
43
+ # Open image using PIL from bytes
44
+ img = Image.open(io.BytesIO(file_bytes))
45
+
46
+ # Convert to RGB if necessary
47
+ if img.mode != 'RGB':
48
+ img = img.convert('RGB')
49
+
50
+ # Resize to model's expected input size (300x300)
51
+ img = img.resize((300, 300), Image.Resampling.LANCZOS)
52
+
53
+ # Convert to array and normalize
54
+ img_array = np.array(img, dtype=np.float32) / 255.0
55
+
56
+ # Add batch dimension
57
+ img_array = np.expand_dims(img_array, axis=0)
58
+
59
+ return img_array
60
+ except Exception as e:
61
+ raise HTTPException(status_code=400, detail=f"Error preprocessing image: {str(e)}")
62
+
63
+ def predict_image(model, img_array):
64
+ """Make prediction on the preprocessed image"""
65
+ try:
66
+ prediction = model.predict(img_array, verbose=0)
67
+ return prediction
68
+ except Exception as e:
69
+ raise HTTPException(status_code=500, detail=f"Error making prediction: {str(e)}")
70
+
71
+ # Add CORS middleware
72
  app.add_middleware(
73
  CORSMiddleware,
74
  allow_origins=["*"],
 
79
 
80
  @app.get("/", tags=["Home"])
81
  def api_home():
82
+ return {
83
+ 'message': 'AI Image Detection API',
84
+ 'description': 'Upload an image to detect if it is AI-generated or real',
85
+ 'endpoints': {
86
+ 'POST /detect': 'Upload image for AI detection',
87
+ 'GET /health': 'Check API health status'
88
+ },
89
+ 'usage': 'Send POST request to /detect with image file',
90
+ 'supported_formats': ['JPG', 'JPEG', 'PNG', 'BMP', 'TIFF']
91
+ }
92
 
93
+ @app.get("/health", tags=["Health"])
94
+ def health_check():
95
+ """Health check endpoint"""
96
+ model_status = "loaded" if load_ai_detection_model() is not None else "not_loaded"
97
+ return {
98
+ 'status': 'healthy',
99
+ 'model_status': model_status,
100
+ 'message': 'AI Image Detection API is running'
101
+ }
102
+
103
+ @app.post("/detect",
104
+ summary="Detect if image is AI-generated",
105
+ tags=["Detection"],
106
+ response_model=ImageDetectionResponse,
107
+ responses={
108
+ 400: {"model": ErrorResponse, "description": "Bad Request"},
109
+ 500: {"model": ErrorResponse, "description": "Internal Server Error"}
110
+ })
111
+ async def detect_ai_image(file: UploadFile = File(...)):
112
+ """
113
+ Upload an image to detect if it's AI-generated or real.
114
+
115
+ - **file**: Image file (JPG, JPEG, PNG, BMP, TIFF)
116
+ - Returns: Detection result with confidence score
117
+ """
118
+
119
+ # Validate file type
120
+ if not file.content_type or not file.content_type.startswith('image/'):
121
+ raise HTTPException(
122
+ status_code=400,
123
+ detail="Invalid file type. Please upload an image file."
124
+ )
125
+
126
+ # Check file size (5MB limit)
127
+ file_size = 0
128
+ content = await file.read()
129
+ file_size = len(content)
130
+ max_size = 5 * 1024 * 1024 # 5MB
131
+
132
+ if file_size > max_size:
133
+ raise HTTPException(
134
+ status_code=400,
135
+ detail=f"File size ({file_size/1024/1024:.2f}MB) exceeds 5MB limit"
136
+ )
137
+
138
+ # Reset file pointer
139
+ await file.seek(0)
140
+
141
+ # Load model
142
+ detection_model = load_ai_detection_model()
143
+ if detection_model is None:
144
+ raise HTTPException(
145
+ status_code=500,
146
+ detail="AI detection model not available. Please try again later."
147
+ )
148
+
149
+ try:
150
+ # Preprocess image
151
+ img_array = preprocess_image(file.file)
152
+
153
+ # Make prediction
154
+ prediction = predict_image(detection_model, img_array)
155
+
156
+ # Process results
157
+ confidence_score = float(prediction[0][0])
158
+ threshold = 0.5
159
+
160
+ is_ai_generated = confidence_score > threshold
161
+
162
+ if is_ai_generated:
163
+ confidence_percentage = confidence_score * 100
164
+ message = "This image appears to be AI-generated"
165
+ else:
166
+ confidence_percentage = (1 - confidence_score) * 100
167
+ message = "This image appears to be real/human-made"
168
+
169
+ return ImageDetectionResponse(
170
+ is_ai_generated=is_ai_generated,
171
+ confidence_score=confidence_score,
172
+ confidence_percentage=round(confidence_percentage, 2),
173
+ prediction_score=round(confidence_score, 6),
174
+ message=message
175
+ )
176
+
177
+ except HTTPException:
178
+ raise
179
+ except Exception as e:
180
+ raise HTTPException(
181
+ status_code=500,
182
+ detail=f"Unexpected error during image processing: {str(e)}"
183
+ )