lokeshloki143 commited on
Commit
0032334
·
verified ·
1 Parent(s): 7abc399

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -28
app.py CHANGED
@@ -1,57 +1,103 @@
1
 
2
- from flask import Flask, request, jsonify
3
  import logging
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
 
5
 
6
- app = Flask(__name__)
7
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
 
 
 
 
 
 
 
 
 
9
  try:
10
- # Load fine-tuned model and tokenizer
11
- model = AutoModelForSeq2SeqLM.from_pretrained("./fine_tuned_model")
12
- tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_model")
13
- logging.info("Model and tokenizer loaded successfully")
14
  except Exception as e:
15
  logging.error(f"Failed to load model or tokenizer: {e}")
16
  raise e
17
 
18
- @app.route("/generate", methods=["POST"])
19
- def generate_checklist():
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- data = request.get_json()
22
- if not data:
23
- logging.warning("No JSON data provided in request")
24
- return jsonify({"error": "Invalid or missing JSON data"}), 400
25
-
26
- role = data.get("role", "Supervisor")
27
- project_id = data.get("project_id", "Unknown")
28
- project_name = data.get("project_name", "Unknown Project")
29
- milestones = data.get("milestones", "No milestones provided")
30
-
31
  # Prepare input for the model
32
- inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
33
  logging.info(f"Generating checklist for inputs: {inputs[:100]}...") # Truncate for logging
34
-
35
  # Tokenize and generate
36
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
37
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
38
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
-
40
- # Static tips (can be enhanced with model output)
41
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
42
-
43
  logging.info("Checklist and tips generated successfully")
44
- return jsonify({
45
  "checklist": checklist,
46
  "tips": tips
47
- }), 200
48
  except Exception as e:
49
  logging.error(f"Error generating checklist: {e}")
50
- return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  if __name__ == "__main__":
53
  try:
54
- app.run(host="0.0.0.0", port=8000, debug=False)
 
 
55
  except Exception as e:
56
- logging.error(f"Failed to start Flask server: {e}")
57
  raise e
 
1
 
2
+ import gradio as gr
3
  import logging
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+ import os
6
+ from fastapi import FastAPI, Request
7
+ from pydantic import BaseModel
8
+ import uvicorn
9
 
10
+ # Configure logging
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
+ # Load Hugging Face token from environment variable
14
+ HF_TOKEN = os.getenv("HF_TOKEN")
15
+ if not HF_TOKEN:
16
+ logging.warning("HF_TOKEN not set. Private repositories may be inaccessible.")
17
+
18
+ # Model configuration
19
+ MODEL_PATH = "your_username/fine_tuned_bart_construction" # Replace with your model path
20
+
21
  try:
22
+ # Load model and tokenizer with authentication
23
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
24
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_auth_token=HF_TOKEN)
25
+ logging.info(f"Model and tokenizer loaded successfully from {MODEL_PATH}")
26
  except Exception as e:
27
  logging.error(f"Failed to load model or tokenizer: {e}")
28
  raise e
29
 
30
+ # Define input model for FastAPI
31
+ class ChecklistInput(BaseModel):
32
+ role: str = "Supervisor"
33
+ project_id: str = "Unknown"
34
+ project_name: str = "Unknown Project"
35
+ milestones: str = "No milestones provided"
36
+
37
+ # Initialize FastAPI for JSON endpoint (Gradio mounts it)
38
+ app = FastAPI()
39
+
40
+ @app.post("/generate")
41
+ async def generate_checklist(data: ChecklistInput):
42
  try:
 
 
 
 
 
 
 
 
 
 
43
  # Prepare input for the model
44
+ inputs = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {data.milestones}"
45
  logging.info(f"Generating checklist for inputs: {inputs[:100]}...") # Truncate for logging
46
+
47
  # Tokenize and generate
48
  input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
49
  outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
50
  checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
51
+
52
+ # Static tips (can be enhanced with model)
53
  tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
54
+
55
  logging.info("Checklist and tips generated successfully")
56
+ return {
57
  "checklist": checklist,
58
  "tips": tips
59
+ }
60
  except Exception as e:
61
  logging.error(f"Error generating checklist: {e}")
62
+ return {"error": str(e)}
63
+
64
+ # Gradio interface function
65
+ def gradio_generate_checklist(role, project_id, project_name, milestones):
66
+ try:
67
+ inputs = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
68
+ input_ids = tokenizer(inputs, return_tensors="pt", max_length=128, truncation=True).input_ids
69
+ outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
70
+ checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+ tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
72
+ return checklist, tips
73
+ except Exception as e:
74
+ return f"Error: {str(e)}", ""
75
+
76
+ # Define Gradio interface
77
+ iface = gr.Interface(
78
+ fn=gradio_generate_checklist,
79
+ inputs=[
80
+ gr.Textbox(label="Role", value="Supervisor"),
81
+ gr.Textbox(label="Project ID", value="P001"),
82
+ gr.Textbox(label="Project Name", value="Building A"),
83
+ gr.Textbox(label="Milestones", value="Complete foundation by 5/15")
84
+ ],
85
+ outputs=[
86
+ gr.Textbox(label="Checklist"),
87
+ gr.Textbox(label="Tips")
88
+ ],
89
+ title="AI Coach for Site Supervisors",
90
+ description="Generate daily checklists and tips for site supervisors."
91
+ )
92
+
93
+ # Mount FastAPI to Gradio
94
+ iface.app = app
95
 
96
  if __name__ == "__main__":
97
  try:
98
+ # Launch Gradio with public URL for Hugging Face Spaces
99
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
100
+ logging.info("Gradio interface launched successfully")
101
  except Exception as e:
102
+ logging.error(f"Failed to launch Gradio interface: {e}")
103
  raise e