abdinkoo commited on
Commit
3be0a64
Β·
verified Β·
1 Parent(s): f280e6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -112
app.py CHANGED
@@ -1,150 +1,161 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- import tensorflow as tf
4
- import numpy as np
5
- from PIL import Image
 
 
6
  import io
7
- import uvicorn
8
- import tempfile
9
- import cv2
10
 
 
 
11
 
12
- # Initialize FastAPI app
13
- app = FastAPI(title="Plant Disease Detection API", version="1.0.0")
14
-
15
- # Add CORS middleware to allow requests from your frontend
16
  app.add_middleware(
17
  CORSMiddleware,
18
- allow_origins=["*"], # In production, replace with your frontend URL
19
- allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
- # Load your model
25
- try:
26
- model = tf.keras.models.load_model("trained_modela.keras")
27
- except Exception as e:
28
- raise RuntimeError(f"Failed to load model: {e}")
29
-
30
- # Define your class names (update with your actual classes)
31
- class_name = ['Apple___Apple_scab',
32
- 'Apple___Black_rot',
33
- 'Apple___Cedar_apple_rust',
34
- 'Apple___healthy',
35
- 'Blueberry___healthy',
36
- 'Cherry_(including_sour)___Powdery_mildew',
37
- 'Cherry_(including_sour)___healthy',
38
- 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
39
- 'Corn_(maize)___Common_rust_',
40
- 'Corn_(maize)___Northern_Leaf_Blight',
41
- 'Corn_(maize)___healthy',
42
- 'Grape___Black_rot',
43
- 'Grape___Esca_(Black_Measles)',
44
- 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
45
- 'Grape___healthy',
46
- 'Orange___Haunglongbing_(Citrus_greening)',
47
- 'Peach___Bacterial_spot',
48
- 'Peach___healthy',
49
- 'Pepper,_bell___Bacterial_spot',
50
- 'Pepper,_bell___healthy',
51
- 'Potato___Early_blight',
52
- 'Potato___Late_blight',
53
- 'Potato___healthy',
54
- 'Raspberry___healthy',
55
- 'Soybean___healthy',
56
- 'Squash___Powdery_mildew',
57
- 'Strawberry___Leaf_scorch',
58
- 'Strawberry___healthy',
59
- 'Tomato___Bacterial_spot',
60
- 'Tomato___Early_blight',
61
- 'Tomato___Late_blight',
62
- 'Tomato___Leaf_Mold',
63
- 'Tomato___Septoria_leaf_spot',
64
- 'Tomato___Spider_mites Two-spotted_spider_mite',
65
- 'Tomato___Target_Spot',
66
- 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
67
- 'Tomato___Tomato_mosaic_virus',
68
- 'Tomato___healthy']
69
 
 
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  @app.get("/")
72
- async def root():
73
- print("dfhkjfdshu")
74
- return {"message": "Plant Disease Detection API", "version": "1.0.0"}
75
 
76
- @app.post("/predict")
77
- async def predict_disease(file: UploadFile = File(...)):
78
-
79
- if not file.content_type.startswith('image/'):
80
- raise HTTPException(status_code=400, detail="File must be an image")
81
 
82
- try:
83
- # Validate file type
84
- # Validate file type
85
-
86
- # Save uploaded file temporarily
87
- with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp:
88
- temp_path = tmp.name
89
- tmp.write(await file.read())
90
- tmp.flush() # Ensure data is written
91
-
92
- # Read image using OpenCV
93
- # img = cv2.imread(temp_path)
94
- # if img is None:
95
- # raise HTTPException(status_code=400, detail="Invalid image file")
96
- # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97
-
98
- image = tf.keras.preprocessing.image.load_img(temp_path,target_size=(128, 128))
99
-
100
- input_arr = tf.keras.preprocessing.image.img_to_array(image)
101
- input_arr = np.array([input_arr]) # Convert single image to batch
102
-
103
- # Predict
104
- prediction = model.predict(input_arr)
105
- result_index = np.argmax(prediction)
106
- confidence = float(prediction[0][result_index])
107
- disease_name = class_name[result_index]
108
-
109
-
110
 
111
- return {
112
- "success": True,
113
- "disease": disease_name,
114
- "confidence": confidence
115
- }
116
 
117
-
118
-
119
- except Exception as e:
120
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
 
 
 
 
 
121
 
122
- @app.get("/health")
123
- async def health_check():
124
- return {"status": "healthy"}
125
 
126
- @app.get("/classes")
127
- async def get_classes():
128
- """Get all available disease classes"""
129
- return {"classes": class_name}
130
 
131
- if __name__ == "__main__":
132
- uvicorn.run(app, host="0.0.0.0", port=7860)
133
 
 
 
 
 
 
 
 
 
 
134
 
 
 
 
135
 
 
 
 
136
 
 
137
 
 
138
 
 
 
 
 
 
139
 
 
 
 
 
 
140
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
 
142
 
 
143
 
 
144
 
 
 
 
 
145
 
 
 
146
 
 
147
 
 
 
148
 
 
 
149
 
 
150
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, File, UploadFile
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import google.generativeai as genai
6
+ import pdfplumber
7
+ import pickle
8
  import io
9
+ import os
10
+ from datetime import datetime
 
11
 
12
+ # ================= APP =================
13
+ app = FastAPI(title="Jimma University Plagiarism API")
14
 
 
 
 
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
+ allow_origins=["*"],
 
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
+ # ================= CONFIG =================
23
+ MODEL_PATH = "plagiarism_model"
24
+ EMBEDDINGS_FILE = "reference_embeddings.pkl"
25
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_KEY_HERE")
26
+
27
+ SIMILARITY_THRESHOLD = 30.0
28
+
29
+ # ================= LOAD SBERT MODEL =================
30
+ model = SentenceTransformer(MODEL_PATH)
31
+ print("βœ… Model loaded:", MODEL_PATH)
32
+
33
+ # ================= LOAD REFERENCE DATA =================
34
+ with open(EMBEDDINGS_FILE, "rb") as f:
35
+ data = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ ref_embeddings = data["embeddings"]
38
+ df_ref = data["df_ref"]
39
 
40
+ print("βœ… Reference dataset loaded")
41
+
42
+ # ================= GEMINI =================
43
+ genai.configure(api_key=GEMINI_API_KEY)
44
+ gemini_model = genai.GenerativeModel("gemini-2.5-flash")
45
+
46
+ # ================= REQUEST MODEL =================
47
+ class PlagiarismRequest(BaseModel):
48
+ text: str
49
+ title: str = "Unknown"
50
+ student_name: str = "Unknown"
51
+ year: str = "2026"
52
+
53
+ # ================= HEALTH CHECK =================
54
  @app.get("/")
55
+ def home():
56
+ return {"message": "Plagiarism API Running πŸš€"}
 
57
 
58
+ # ================= TEXT CHECK API =================
59
+ @app.post("/check_plagiarism")
60
+ async def check_plagiarism(req: PlagiarismRequest):
 
 
61
 
62
+ text = req.text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ if len(text) < 100:
65
+ raise HTTPException(400, "Text too short")
 
 
 
66
 
67
+ if len(text) > 8000:
68
+ text = text[:8000]
69
+
70
+ # ================= SBERT =================
71
+ query_embedding = model.encode(
72
+ text,
73
+ convert_to_tensor=True,
74
+ normalize_embeddings=True
75
+ )
76
 
77
+ scores = util.cos_sim(query_embedding, ref_embeddings)[0]
78
+ scores = (scores * 100).cpu().numpy()
 
79
 
80
+ top_idx = int(scores.argmax())
81
+ top_score = float(scores[top_idx])
 
 
82
 
83
+ row = df_ref.iloc[top_idx]
 
84
 
85
+ # ================= LOW RISK =================
86
+ if top_score < SIMILARITY_THRESHOLD:
87
+ return {
88
+ "status": "low_risk",
89
+ "similarity_percent": round(top_score, 2),
90
+ "rating": 1,
91
+ "most_similar_source": str(row.get("title", "N/A")),
92
+ "message": "No significant plagiarism detected"
93
+ }
94
 
95
+ # ================= GEMINI REPORT =================
96
+ prompt = f"""
97
+ You are an academic plagiarism expert.
98
 
99
+ Title: {req.title}
100
+ Student: {req.student_name}
101
+ Year: {req.year}
102
 
103
+ Similarity: {top_score:.2f}%
104
 
105
+ Source: {row.get("title", "N/A")}
106
 
107
+ Give:
108
+ 1. Similarity explanation
109
+ 2. Risk level
110
+ 3. Recommendation
111
+ """
112
 
113
+ try:
114
+ response = gemini_model.generate_content(prompt)
115
+ report = response.text
116
+ except Exception as e:
117
+ report = f"Gemini error: {str(e)}"
118
 
119
+ # ================= RESPONSE =================
120
+ return {
121
+ "status": "suspicious",
122
+ "similarity_percent": round(top_score, 2),
123
+ "rating": 4 if top_score > 70 else 3,
124
+ "stars": "β˜…β˜…β˜…β˜…β˜†" if top_score > 70 else "β˜…β˜…β˜…β˜†β˜†",
125
+ "most_similar_source": str(row.get("title", "N/A")),
126
+ "source_student": str(row.get("student_name", "N/A")),
127
+ "gemini_report": report,
128
+ "timestamp": datetime.now().isoformat()
129
+ }
130
 
131
+ # ================= PDF UPLOAD API (OPTIONAL) =================
132
+ @app.post("/check_pdf")
133
+ async def check_pdf(file: UploadFile = File(...)):
134
 
135
+ content = await file.read()
136
 
137
+ text = ""
138
 
139
+ with pdfplumber.open(io.BytesIO(content)) as pdf:
140
+ for page in pdf.pages:
141
+ if page.extract_text():
142
+ text += page.extract_text() + "\n"
143
 
144
+ if len(text) < 100:
145
+ return {"error": "PDF too short"}
146
 
147
+ query_embedding = model.encode(text, convert_to_tensor=True, normalize_embeddings=True)
148
 
149
+ scores = util.cos_sim(query_embedding, ref_embeddings)[0]
150
+ scores = (scores * 100).cpu().numpy()
151
 
152
+ top_idx = int(scores.argmax())
153
+ top_score = float(scores[top_idx])
154
 
155
+ row = df_ref.iloc[top_idx]
156
 
157
+ return {
158
+ "status": "done",
159
+ "similarity_percent": round(top_score, 2),
160
+ "best_match": str(row.get("title", "N/A"))
161
+ }