abdinkoo commited on
Commit
2193e51
Β·
verified Β·
1 Parent(s): f280e6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -114
app.py CHANGED
@@ -1,150 +1,190 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- import tensorflow as tf
4
- import numpy as np
5
- from PIL import Image
6
- import io
7
- import uvicorn
8
- import tempfile
9
- import cv2
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Initialize FastAPI app
13
- app = FastAPI(title="Plant Disease Detection API", version="1.0.0")
 
 
14
 
15
- # Add CORS middleware to allow requests from your frontend
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["*"], # In production, replace with your frontend URL
19
- allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
- # Load your model
25
- try:
26
- model = tf.keras.models.load_model("trained_modela.keras")
27
- except Exception as e:
28
- raise RuntimeError(f"Failed to load model: {e}")
29
-
30
- # Define your class names (update with your actual classes)
31
- class_name = ['Apple___Apple_scab',
32
- 'Apple___Black_rot',
33
- 'Apple___Cedar_apple_rust',
34
- 'Apple___healthy',
35
- 'Blueberry___healthy',
36
- 'Cherry_(including_sour)___Powdery_mildew',
37
- 'Cherry_(including_sour)___healthy',
38
- 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
39
- 'Corn_(maize)___Common_rust_',
40
- 'Corn_(maize)___Northern_Leaf_Blight',
41
- 'Corn_(maize)___healthy',
42
- 'Grape___Black_rot',
43
- 'Grape___Esca_(Black_Measles)',
44
- 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
45
- 'Grape___healthy',
46
- 'Orange___Haunglongbing_(Citrus_greening)',
47
- 'Peach___Bacterial_spot',
48
- 'Peach___healthy',
49
- 'Pepper,_bell___Bacterial_spot',
50
- 'Pepper,_bell___healthy',
51
- 'Potato___Early_blight',
52
- 'Potato___Late_blight',
53
- 'Potato___healthy',
54
- 'Raspberry___healthy',
55
- 'Soybean___healthy',
56
- 'Squash___Powdery_mildew',
57
- 'Strawberry___Leaf_scorch',
58
- 'Strawberry___healthy',
59
- 'Tomato___Bacterial_spot',
60
- 'Tomato___Early_blight',
61
- 'Tomato___Late_blight',
62
- 'Tomato___Leaf_Mold',
63
- 'Tomato___Septoria_leaf_spot',
64
- 'Tomato___Spider_mites Two-spotted_spider_mite',
65
- 'Tomato___Target_Spot',
66
- 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
67
- 'Tomato___Tomato_mosaic_virus',
68
- 'Tomato___healthy']
69
 
 
 
 
70
 
71
- @app.get("/")
72
- async def root():
73
- print("dfhkjfdshu")
74
- return {"message": "Plant Disease Detection API", "version": "1.0.0"}
75
 
76
- @app.post("/predict")
77
- async def predict_disease(file: UploadFile = File(...)):
78
-
79
- if not file.content_type.startswith('image/'):
80
- raise HTTPException(status_code=400, detail="File must be an image")
81
 
82
- try:
83
- # Validate file type
84
- # Validate file type
85
-
86
- # Save uploaded file temporarily
87
- with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp:
88
- temp_path = tmp.name
89
- tmp.write(await file.read())
90
- tmp.flush() # Ensure data is written
91
-
92
- # Read image using OpenCV
93
- # img = cv2.imread(temp_path)
94
- # if img is None:
95
- # raise HTTPException(status_code=400, detail="Invalid image file")
96
- # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97
-
98
- image = tf.keras.preprocessing.image.load_img(temp_path,target_size=(128, 128))
99
-
100
- input_arr = tf.keras.preprocessing.image.img_to_array(image)
101
- input_arr = np.array([input_arr]) # Convert single image to batch
102
-
103
- # Predict
104
- prediction = model.predict(input_arr)
105
- result_index = np.argmax(prediction)
106
- confidence = float(prediction[0][result_index])
107
- disease_name = class_name[result_index]
108
-
109
-
110
 
111
- return {
112
- "success": True,
113
- "disease": disease_name,
114
- "confidence": confidence
115
- }
116
 
117
-
118
-
119
- except Exception as e:
120
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
121
 
122
- @app.get("/health")
123
- async def health_check():
124
- return {"status": "healthy"}
125
 
126
- @app.get("/classes")
127
- async def get_classes():
128
- """Get all available disease classes"""
129
- return {"classes": class_name}
130
 
131
- if __name__ == "__main__":
132
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
133
 
 
 
134
 
 
 
 
 
 
 
 
135
 
 
 
 
136
 
 
137
 
 
 
138
 
 
 
139
 
 
 
 
 
 
 
 
140
 
 
 
141
 
 
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
 
 
144
 
 
 
 
145
 
 
146
 
 
 
 
147
 
 
 
 
148
 
 
149
 
 
 
 
 
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import pickle
6
+ import google.generativeai as genai
7
+ from datetime import datetime
8
+ from typing import Dict
9
+ import os
10
+
11
+ app = FastAPI(title="Jimma University Plagiarism API")
12
+
13
+ # ====================== SAFE LIMITS ======================
14
+ MAX_TEXT_LENGTH = 8000
15
+ MAX_PROMPT_LENGTH = 4000
16
+ SIMILARITY_THRESHOLD = 30.0
17
+ # =========================================================
18
+
19
+ # ====================== RATING FUNCTION ======================
20
+ def convert_to_rating(similarity_percent: float) -> int:
21
+ if similarity_percent >= 80:
22
+ return 5
23
+ elif similarity_percent >= 60:
24
+ return 4
25
+ elif similarity_percent >= 40:
26
+ return 3
27
+ elif similarity_percent >= 20:
28
+ return 2
29
+ else:
30
+ return 1
31
+ # ============================================================
32
+
33
+ # ====================== ROOT ======================
34
+ @app.get("/")
35
+ def home():
36
+ return {"message": "Jimma University Plagiarism API is running πŸš€"}
37
 
38
+ @app.get("/health")
39
+ def health():
40
+ return {"status": "ok"}
41
+ # ==================================================
42
 
 
43
  app.add_middleware(
44
  CORSMiddleware,
45
+ allow_origins=["*"],
 
46
  allow_methods=["*"],
47
  allow_headers=["*"],
48
  )
49
 
50
+ # ====================== CONFIG ======================
51
+ GEMINI_API_KEY = os.getenv(
52
+ "GEMINI_API_KEY",
53
+ "AQ.Ab8RN6Id1IlRKgMi19Vmy7PGrY82ZxG5D34vsDOnsFOFdrRI6g"
54
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ MODEL_PATH = "plagiarism_sbert_model"
57
+ EMBEDDINGS_FILE = "reference_embeddings.pkl"
58
+ # ===================================================
59
 
60
+ # ====================== LOAD MODEL (FIXED) ======================
61
+ if not os.path.exists(MODEL_PATH):
62
+ raise RuntimeError("❌ Model folder not found")
 
63
 
64
+ model = SentenceTransformer(MODEL_PATH)
 
 
 
 
65
 
66
+ print("βœ… SBERT model loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # ====================== LOAD REFERENCE DATASET ======================
69
+ if not os.path.exists(EMBEDDINGS_FILE):
70
+ raise RuntimeError("❌ Reference embeddings file not found")
 
 
71
 
72
+ with open(EMBEDDINGS_FILE, "rb") as f:
73
+ data = pickle.load(f)
 
 
74
 
75
+ ref_embeddings = data["embeddings"]
76
+ df_ref = data["df_ref"]
 
77
 
78
+ print("βœ… Reference dataset loaded")
79
+ # ================================================================
 
 
80
 
81
+ # ====================== GEMINI ======================
82
+ genai.configure(api_key=GEMINI_API_KEY)
83
+ gemini_model = genai.GenerativeModel('gemini-2.5-flash')
84
 
85
+ print("βœ… System ready")
86
+ # ===================================================
87
 
88
+ # ====================== REQUEST MODEL ======================
89
+ class PlagiarismRequest(BaseModel):
90
+ text: str
91
+ title: str = "Submitted Document"
92
+ student_name: str = "Unknown Student"
93
+ year: str = "2026"
94
+ # ============================================================
95
 
96
+ # ====================== API ======================
97
+ @app.post("/check_plagiarism")
98
+ async def check_plagiarism(req: PlagiarismRequest) -> Dict:
99
 
100
+ text = req.text.strip()
101
 
102
+ if len(text) < 200:
103
+ raise HTTPException(400, "Text too short (minimum 200 characters)")
104
 
105
+ if len(text) > MAX_TEXT_LENGTH:
106
+ text = text[:MAX_TEXT_LENGTH]
107
 
108
+ # ================= SBERT ENCODING =================
109
+ try:
110
+ query_embedding = model.encode(
111
+ text,
112
+ convert_to_tensor=True,
113
+ normalize_embeddings=True
114
+ )
115
 
116
+ # IMPORTANT FIX: cosine similarity stays in 0–1 range
117
+ cosine_scores = util.cos_sim(query_embedding, ref_embeddings)[0]
118
 
119
+ # convert to percentage properly
120
+ similarities = (cosine_scores * 100).cpu().numpy()
121
 
122
+ except Exception as e:
123
+ raise HTTPException(status_code=500, detail=f"Embedding error: {str(e)}")
124
+
125
+ # ================= TOP MATCH =================
126
+ top_idx = int(similarities.argmax())
127
+ top_similarity = float(similarities[top_idx])
128
+
129
+ rating = convert_to_rating(top_similarity)
130
+ stars = "β˜…" * rating + "β˜†" * (5 - rating)
131
+
132
+ # ================= LOW RISK =================
133
+ if top_similarity <= SIMILARITY_THRESHOLD:
134
+ return {
135
+ "status": "low_risk",
136
+ "similarity_percent": round(top_similarity, 2),
137
+ "rating": rating,
138
+ "stars": stars,
139
+ "message": "No significant plagiarism detected."
140
+ }
141
+
142
+ # ================= SOURCE =================
143
+ row = df_ref.iloc[top_idx]
144
+ source_title = str(row.get("title", "Reference Project"))[:150]
145
+ source_student = str(row.get("student_name", "Original Student"))
146
+ source_year = str(row.get("year", "2023"))
147
 
148
+ category = "LOW" if top_similarity <= 30 else "MEDIUM" if top_similarity <= 70 else "HIGH"
149
+ emoji = "βœ…" if category == "LOW" else "βš–οΈ" if category == "MEDIUM" else "❌"
150
 
151
+ # ================= GEMINI PROMPT =================
152
+ prompt = f"""
153
+ You are a strict academic plagiarism supervisor at Jimma University.
154
 
155
+ {emoji} {category} SIMILARITY CASE
156
 
157
+ Source Title: {source_title}
158
+ Student Name: {source_student}
159
+ Year: {source_year}
160
 
161
+ Suspicious Title: {req.title}
162
+ Student Name: {req.student_name}
163
+ Year: {req.year}
164
 
165
+ Similarity Score: {top_similarity:.1f}%
166
 
167
+ 1. Conceptual Similarity:
168
+ 2. Conceptual Differences:
169
+ 3. Technology Differences:
170
+ 4. Supervisor Recommendation:
171
+ """
172
 
173
+ try:
174
+ prompt = prompt[:MAX_PROMPT_LENGTH]
175
+ response = gemini_model.generate_content(prompt)
176
+ report = response.text.strip()
177
+
178
+ except Exception as e:
179
+ report = f"Gemini error: {str(e)}"
180
+
181
+ return {
182
+ "status": "suspicious",
183
+ "similarity_percent": round(top_similarity, 2),
184
+ "rating": rating,
185
+ "stars": stars,
186
+ "most_similar_source": source_title,
187
+ "source_student": source_student,
188
+ "gemini_report": report,
189
+ "timestamp": datetime.now().isoformat()
190
+ }