lydiasolomon committed on
Commit
83af178
·
verified ·
1 Parent(s): 5fbb27d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +91 -280
main.py CHANGED
@@ -1,58 +1,33 @@
1
- # main.py
2
  import os
3
- import tempfile
4
  import logging
5
- import traceback
6
- from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body, Request
7
  from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel
9
  from transformers import pipeline
10
- from langdetect import detect, DetectorFactory
11
  from PIL import Image
12
- import io
13
- from smebuilder_vector import retriever # your existing retriever module
14
- import spitch
15
 
16
  # ==============================
17
  # Logging Setup
18
  # ==============================
19
  logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger("DevAssist")
21
-
22
- # Debug log file for prompts + outputs
23
- DEBUG_LOG_FILE = os.getenv("LLM_DEBUG_LOG", "llm_debug.log")
24
-
25
- # ==============================
26
- # App Init
27
- # ==============================
28
- app = FastAPI(title="DevAssist / CuraAI Backend")
29
 
30
  # ==============================
31
- # Config
32
  # ==============================
33
- DetectorFactory.seed = 0
34
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
35
- SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
36
-
37
- # Models chosen per task (public/reasonable defaults)
38
- HF_MODELS = {
39
- "chat": os.getenv("CHAT_MODEL", "bigcode/starcoderbase"), # coding assistant
40
- "autodoc": os.getenv("AUTODOC_MODEL", "Salesforce/codegen-2B-mono"), # code -> docs
41
- "sme": os.getenv("SME_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct"), # frontend generation
42
- "image_caption": os.getenv("IMAGE_CAPTION_MODEL", "Salesforce/blip-image-captioning-base")
43
- }
44
-
45
- if not SPITCH_API_KEY:
46
- raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
47
 
48
- # Initialize Spitch client once
49
- spitch_client = spitch.Spitch()
50
- # Optionally set env var for Spitch API if required by client library
51
- os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
52
 
53
  # ==============================
54
- # Authentication helper
55
  # ==============================
 
 
56
  def check_auth(authorization: str | None):
57
  if not PROJECT_API_KEY:
58
  return
@@ -63,289 +38,125 @@ def check_auth(authorization: str | None):
63
  raise HTTPException(status_code=403, detail="Invalid token")
64
 
65
  # ==============================
66
- # Global exception handler
67
  # ==============================
68
  @app.exception_handler(Exception)
69
  async def global_exception_handler(request: Request, exc: Exception):
70
- logger.error(f"Unhandled error: {exc}", exc_info=True)
71
  return JSONResponse(status_code=500, content={"error": str(exc)})
72
 
73
  # ==============================
74
- # Request models
75
  # ==============================
76
  class ChatRequest(BaseModel):
77
- question: str
78
 
79
- class AutoDocRequest(BaseModel):
80
- code: str
81
 
82
- class SMERequest(BaseModel):
83
- user_prompt: str
84
 
85
- # For simple vector search API
86
  class VectorRequest(BaseModel):
87
  query: str
88
 
89
  # ==============================
90
- # Pipeline loader with fallback
91
  # ==============================
92
- def load_pipeline(task: str, model_name: str, fallback: str = None):
93
- """
94
- Load a HuggingFace pipeline with a fallback option.
95
- Keep the load minimal (no device_map here — set in env for production).
96
- """
97
- try:
98
- logger.info(f"Loading pipeline task={task} model={model_name}")
99
- return pipeline(task, model=model_name)
100
- except Exception as e:
101
- logger.warning(f"Failed to load {model_name} for task={task}: {e}")
102
- if fallback:
103
- logger.info(f"Falling back to {fallback} for task={task}")
104
- return pipeline(task, model=fallback)
105
- raise
106
 
107
- # ==============================
108
- # Pipelines (load on startup)
109
- # ==============================
110
- # text-generation pipelines for chat/autodoc/sme
111
- chat_pipe = load_pipeline("text-generation", HF_MODELS["chat"], fallback="gpt2")
112
- autodoc_pipe = load_pipeline("text-generation", HF_MODELS["autodoc"], fallback="gpt2")
113
- sme_pipe = load_pipeline("text-generation", HF_MODELS["sme"], fallback="gpt2")
 
114
 
115
- # image caption / image-to-text pipeline for crop/vision tasks
116
- image_caption_pipe = load_pipeline("image-to-text", HF_MODELS["image_caption"], fallback="Salesforce/blip-image-captioning-base")
 
 
 
 
 
 
 
 
117
 
118
  # ==============================
119
- # Helper / wrapper functions
120
  # ==============================
121
- def debug_log_prompt(prompt: str, output: str, tag: str = "LLM"):
122
- try:
123
- with open(DEBUG_LOG_FILE, "a", encoding="utf-8") as fh:
124
- fh.write(f"=== {tag} PROMPT START ===\n")
125
- fh.write(prompt + "\n")
126
- fh.write("--- MODEL OUTPUT ---\n")
127
- fh.write(output + "\n")
128
- fh.write(f"=== {tag} PROMPT END ===\n\n")
129
- except Exception:
130
- logger.exception("Failed to write debug log")
131
-
132
- def run_pipeline(pipe, prompt: str, max_new_tokens: int = 1024):
133
- """
134
- Run a text-generation pipeline and return text or structured error.
135
- Logs prompt + output to debug file.
136
- """
137
  try:
138
- # call pipeline (many models return list with 'generated_text')
139
- output_list = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True)
140
- text = ""
141
- if isinstance(output_list, list) and len(output_list) > 0:
142
- # handle generators that include 'generated_text'
143
- first = output_list[0]
144
- if isinstance(first, dict) and "generated_text" in first:
145
- text = first["generated_text"]
146
- else:
147
- text = str(first)
148
- else:
149
- text = str(output_list)
150
-
151
- text = text.strip()
152
- debug_log_prompt(prompt, text, tag="TEXT-GEN")
153
- logger.info("Prompt executed successfully")
154
-
155
- if not text:
156
- return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt}
157
- return text
158
  except Exception as e:
159
- logger.error("Pipeline execution error", exc_info=True)
160
- trace = traceback.format_exc()
161
- debug_log_prompt(prompt, f"EXCEPTION:\n{trace}", tag="TEXT-GEN")
162
- return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": trace, "prompt": prompt}
163
 
164
- def run_image_to_text(pipe, image_bytes: bytes, prompt: str):
165
  """
166
- Run image-to-text pipelines (image captioning / multimodal).
167
- Returns generated_text or error structure.
168
  """
 
 
169
  try:
170
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
171
- output_list = pipe(image, prompt=prompt)
172
- text = ""
173
- if isinstance(output_list, list) and len(output_list) > 0 and isinstance(output_list[0], dict):
174
- text = output_list[0].get("generated_text", "")
175
- else:
176
- text = str(output_list)
177
- text = text.strip()
178
- debug_log_prompt(prompt, text, tag="IMG-TO-TEXT")
179
- if not text:
180
- return {"success": False, "error": "⚠️ Vision model returned empty output", "prompt": prompt}
181
- return text
182
  except Exception as e:
183
- logger.exception("Image-to-text pipeline error")
184
- trace = traceback.format_exc()
185
- debug_log_prompt(prompt, f"EXCEPTION:\n{trace}", tag="IMG-TO-TEXT")
186
- return {"success": False, "error": f"⚠️ Vision model error: {str(e)}", "trace": trace, "prompt": prompt}
187
 
188
  # ==============================
189
- # Audio processing (Spitch) helper
190
  # ==============================
191
- async def process_audio(file: UploadFile, lang_hint: str | None = None):
192
- """
193
- Save audio temporarily, transcribe via Spitch client, detect language and optionally translate to English.
194
- Returns (transcription, detected_lang, translation)
195
- """
196
- suffix = os.path.splitext(file.filename)[1] or ".wav"
197
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
198
- tf.write(await file.read())
199
- tmp_path = tf.name
200
-
201
- with open(tmp_path, "rb") as f:
202
- audio_bytes = f.read()
203
-
204
- try:
205
- if lang_hint:
206
- resp = spitch_client.speech.transcribe(language=lang_hint, content=audio_bytes)
207
- else:
208
- resp = spitch_client.speech.transcribe(content=audio_bytes)
209
- except Exception:
210
- # fallback to english if Spitch fails with the given hint
211
- resp = spitch_client.speech.transcribe(language="en", content=audio_bytes)
212
-
213
- transcription = getattr(resp, "text", "") or (resp.get("text", "") if isinstance(resp, dict) else "")
214
- detected_lang = "en"
215
- try:
216
- detected_lang = detect(transcription) if transcription.strip() else "en"
217
- except Exception:
218
- detected_lang = "en"
219
-
220
- translation = transcription
221
- if detected_lang != "en":
222
- try:
223
- translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
224
- translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "") or transcription
225
- except Exception:
226
- translation = transcription
227
-
228
- return transcription, detected_lang, translation
229
-
230
- # ==============================
231
- # Endpoints
232
- # ==============================
233
- @app.get("/")
234
- async def root_endpoint():
235
- return {"status": "✅ DevAssist / CuraAI Backend running"}
236
-
237
- # ----- Chat: coding assistant -----
238
- @app.post("/chat")
239
- async def chat_endpoint(req: ChatRequest, authorization: str | None = Header(None)):
240
  check_auth(authorization)
241
- # prompt template tuned for coding Q&A
242
- prompt = (
243
- "You are DevAssist — a helpful, concise coding assistant. "
244
- f"Answer clearly with code samples if relevant.\n\nQuestion:\n{req.question}\n\nAnswer:"
245
- )
246
- result = run_pipeline(chat_pipe, prompt, max_new_tokens=512)
247
- return result if isinstance(result, dict) else {"reply": result}
248
 
249
- # ----- Autodoc: code -> documentation -----
250
- @app.post("/autodoc")
251
- async def autodoc_endpoint(req: AutoDocRequest, authorization: str | None = Header(None)):
252
  check_auth(authorization)
253
- prompt = (
254
- "You are DevAssist DocBot. Produce professional Markdown documentation for the provided code.\n\n"
255
- f"Code:\n{req.code}\n\nDocumentation:"
256
- )
257
- result = run_pipeline(autodoc_pipe, prompt, max_new_tokens=512)
258
- return result if isinstance(result, dict) else {"documentation": result}
259
 
260
- # ----- SME: production-ready frontend generation (with retriever context) -----
261
- @app.post("/sme/generate")
262
- async def sme_generate_endpoint(req: SMERequest, authorization: str | None = Header(None)):
263
  check_auth(authorization)
264
- try:
265
- # Use retriever for context injection (keep old method for compatibility)
266
- try:
267
- context_docs = retriever.get_relevant_documents(req.user_prompt)
268
- except AttributeError:
269
- # if newer retriever API uses .invoke
270
- context_docs = retriever.invoke(req.user_prompt)
271
-
272
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
273
- prompt = (
274
- "You are a senior full-stack engineer. "
275
- "Generate production-ready frontend code (index.html, styles.css, script.js) "
276
- f"based on the prompt:\n{req.user_prompt}\n\nContext:\n{context}\n\nOutput:"
277
- )
278
- result = run_pipeline(sme_pipe, prompt, max_new_tokens=1500)
279
- return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
280
- except Exception as e:
281
- logger.exception("SME generate endpoint error")
282
- return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}
283
 
284
- # ----- SME Speech generate: STT -> SME -----
285
- @app.post("/sme/speech-generate")
286
- async def sme_speech_endpoint(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
287
- check_auth(authorization)
288
- transcription, detected_lang, translation = await process_audio(file, lang_hint)
289
- try:
290
- try:
291
- context_docs = retriever.get_relevant_documents(translation)
292
- except AttributeError:
293
- context_docs = retriever.invoke(translation)
294
-
295
- context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
296
- prompt = (
297
- "You are a senior full-stack engineer. Generate production-ready frontend code "
298
- f"based on the prompt:\n{translation}\n\nContext:\n{context}\n\nOutput:"
299
- )
300
- result = run_pipeline(sme_pipe, prompt, max_new_tokens=1500)
301
- return {
302
- "success": True,
303
- "transcription": transcription,
304
- "detected_language": detected_lang,
305
- "translation": translation,
306
- "sme_site": result if isinstance(result, str) else result.get("reply", "")
307
- }
308
- except Exception as e:
309
- logger.exception("SME speech-generate error")
310
- return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}
311
-
312
- # ----- Vision/crop doctor style endpoint (image + text -> diagnosis / explanation) -----
313
- @app.post("/vision/diagnose")
314
- async def vision_diagnose(symptoms: str = Header(...), image: UploadFile = File(...), authorization: str | None = Header(None)):
315
- """
316
- Use an image-to-text model (BLIP) to analyze an image + farmer description, then produce
317
- a simple diagnosis & treatment plan. Returns a string or error object.
318
- """
319
- check_auth(authorization)
320
- image_bytes = await image.read()
321
- prompt = (
322
- f"Farmer reports: {symptoms}. Analyze this plant image, diagnose the likely disease, "
323
- "provide simple treatment steps and short prevention advice in plain language."
324
- )
325
- result = run_image_to_text(image_caption_pipe, image_bytes, prompt)
326
- return {"diagnosis": result} if isinstance(result, str) else result
327
-
328
- # ----- Vector search wrapper endpoint -----
329
  @app.post("/vector-search")
330
  async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
331
  check_auth(authorization)
332
  try:
333
- # call your existing vector query function in smebuilder_vector (query_vector)
334
- try:
335
- results = retriever.get_relevant_documents(req.query)
336
- except AttributeError:
337
- # fallback to invoke if retriever API differs
338
- results = retriever.invoke(req.query)
339
- # normalize a simple list response
340
- brief = [{"page_content": getattr(r, "page_content", str(r)), "meta": getattr(r, "metadata", {})} for r in results]
341
- return {"results": brief}
342
  except Exception as e:
343
- logger.exception("Vector search error")
344
- return {"error": f"Vector search error: {str(e)}", "trace": traceback.format_exc()}
345
-
346
- # ==============================
347
- # Run App
348
- # ==============================
349
- if __name__ == "__main__":
350
- import uvicorn
351
- uvicorn.run("main:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)
 
 
 
 
 
 
1
  import os
 
2
  import logging
3
+ import io
4
+ from fastapi import FastAPI, Request, Header, HTTPException, UploadFile, File
5
  from fastapi.responses import JSONResponse
6
  from pydantic import BaseModel
7
  from transformers import pipeline
 
8
  from PIL import Image
9
+ from vector import query_vector
 
 
10
 
11
# ==============================
# Logging Setup
# ==============================
# Root logger at INFO so startup/model-load messages are visible by default.
logging.basicConfig(level=logging.INFO)
# Named application logger used throughout this module.
logger = logging.getLogger("AgriCopilot")
 
 
 
 
 
 
 
 
16
 
17
# ==============================
# App Initialization
# ==============================
# Single FastAPI application instance; all routes below attach to it.
app = FastAPI(title="AgriCopilot AI API", version="2.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
@app.get("/")
async def root():
    """Health-check endpoint: confirms the backend process is serving."""
    status_payload = {"status": "AgriCopilot AI Backend is running smoothly ✅"}
    return status_payload
 
25
 
26
# ==============================
# AUTH CONFIGURATION
# ==============================
# Shared bearer token expected by check_auth().
# NOTE(review): the fallback "agricopilot404" means auth is effectively a
# known default when PROJECT_API_KEY is unset — consider requiring the env
# var in production. TODO confirm intended deployment behavior.
PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "agricopilot404")
30
+
31
  def check_auth(authorization: str | None):
32
  if not PROJECT_API_KEY:
33
  return
 
38
  raise HTTPException(status_code=403, detail="Invalid token")
39
 
40
# ==============================
# Exception Handling
# ==============================
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log any unhandled exception and return a JSON 500.

    The previous revision logged with exc_info=True; that was dropped in this
    rewrite, losing the traceback. Restored here, together with lazy
    %-formatting so the message is only built when the log record is emitted.
    """
    logger.error("Unhandled error: %s", exc, exc_info=True)
    return JSONResponse(status_code=500, content={"error": str(exc)})
47
 
48
  # ==============================
49
+ # Request Models
50
  # ==============================
51
  class ChatRequest(BaseModel):
52
+ query: str
53
 
54
+ class DisasterRequest(BaseModel):
55
+ report: str
56
 
57
+ class MarketRequest(BaseModel):
58
+ product: str
59
 
 
60
  class VectorRequest(BaseModel):
61
  query: str
62
 
63
# ==============================
# Load Hugging Face Pipelines
# ==============================
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

if not HF_TOKEN:
    logger.warning("⚠️ No Hugging Face token found. Gated models may fail.")
else:
    logger.info("✅ Hugging Face token loaded successfully.")

# General text-generation model for chat, disaster, market endpoints
default_model = "meta-llama/Llama-3.1-8B-Instruct"
vision_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# Instantiate the text-generation pipeline ONCE and share it.  The previous
# code called pipeline(...) three times with the same model, loading three
# copies of an 8B model into memory for identical behavior.  The three
# module-level names are kept so existing endpoint code is unchanged.
_text_pipe = pipeline("text-generation", model=default_model, token=HF_TOKEN)
chat_pipe = _text_pipe
disaster_pipe = _text_pipe
market_pipe = _text_pipe

# Multimodal crop diagnostic model; optional — endpoints degrade gracefully
# via run_crop_doctor() when it fails to load.
try:
    crop_pipe = pipeline("image-text-to-text", model=vision_model, token=HF_TOKEN)
except Exception as e:
    logger.warning(f"Crop model load failed: {e}")
    crop_pipe = None
87
 
88
# ==============================
# Helper Functions
# ==============================
def run_conversational(pipe, prompt: str):
    """Run a text-generation pipeline on *prompt* and return the output text.

    Expects the usual HF shape [{"generated_text": ...}]; anything else is
    stringified. On failure a human-readable error string is returned
    instead of raising, so endpoints never 500 on model errors.
    """
    try:
        raw = pipe(prompt, max_new_tokens=200)
        if isinstance(raw, list) and raw:
            return raw[0].get("generated_text", str(raw))
        return str(raw)
    except Exception as exc:
        logger.error(f"Pipeline error: {exc}")
        return f"⚠️ Model error: {str(exc)}"
 
 
100
 
101
def run_crop_doctor(image_bytes: bytes, symptoms: str):
    """
    Diagnose crop issues using Meta's multimodal LLaMA Vision model.

    Returns the model's generated text, or a human-readable error string
    when the vision model is unavailable or the call fails.
    """
    # crop_pipe is None when the vision model failed to load at startup.
    if not crop_pipe:
        return "⚠️ Crop analysis temporarily unavailable (model not loaded)."
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        question = (
            f"The farmer reports: {symptoms}. "
            "Analyze the image and diagnose the likely crop disease. "
            "Then explain it simply and recommend possible treatment steps."
        )
        result = crop_pipe(img, question)
        # Usual HF shape is [{"generated_text": ...}]; stringify anything else.
        if isinstance(result, list) and result:
            return result[0].get("generated_text", str(result))
        return str(result)
    except Exception as exc:
        logger.error(f"Crop Doctor pipeline error: {exc}")
        return f"⚠️ Unexpected model error: {str(exc)}"
 
 
121
 
122
# ==============================
# API ROUTES
# ==============================
@app.post("/multilingual-chat")
async def multilingual_chat(req: ChatRequest, authorization: str | None = Header(None)):
    """Forward the user's query to the shared text-generation pipeline."""
    check_auth(authorization)
    return {"reply": run_conversational(chat_pipe, req.query)}
 
 
 
 
 
130
 
131
@app.post("/disaster-summarizer")
async def disaster_summarizer(req: DisasterRequest, authorization: str | None = Header(None)):
    """Summarize a disaster report via the shared text-generation pipeline."""
    check_auth(authorization)
    return {"summary": run_conversational(disaster_pipe, req.report)}
 
 
 
 
136
 
137
@app.post("/marketplace")
async def marketplace(req: MarketRequest, authorization: str | None = Header(None)):
    """Produce a marketplace recommendation for the given product text."""
    check_auth(authorization)
    return {"recommendation": run_conversational(market_pipe, req.product)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
@app.post("/vector-search")
async def vector_search(req: VectorRequest, authorization: str | None = Header(None)):
    """Run a semantic search via query_vector(); errors come back as JSON."""
    check_auth(authorization)
    try:
        return {"results": query_vector(req.query)}
    except Exception as exc:
        logger.error(f"Vector search error: {exc}")
        return {"error": f"Vector search error: {str(exc)}"}
152
+
153
@app.post("/crop-doctor")
async def crop_doctor(
    symptoms: str = Header(...),
    image: UploadFile = File(...),
    authorization: str | None = Header(None)
):
    """Accept a plant photo plus a symptom description; return a diagnosis.

    NOTE(review): `symptoms` arrives as an HTTP header, which restricts it to
    latin-1 text — a form field may be more robust; confirm with API clients.
    """
    check_auth(authorization)
    photo_bytes = await image.read()
    return {"diagnosis": run_crop_doctor(photo_bytes, symptoms)}